Skip to main content

atomr_accel_cuda/kernel/cudnn/
graph.rs

//! cuDNN v9 frontend graph builder — Rust-level "spec" objects that
//! describe a backend descriptor DAG (tensors → ops → operation graph
//! → engine config → execution plan) plus a plan cache keyed on
//! op-shape signatures.
//!
//! The spec layer is fully host-buildable. The actual
//! `cudnnBackendCreateDescriptor` / `cudnnBackendFinalize` calls are meant
//! to live in `OperationGraphSpec::build_into`, which is only compiled with
//! the `cudnn` feature and is currently a stub that always returns an error.
//! Unit tests round-trip the spec without touching FFI.
//!
//! # What we build
//!
//! ```text
//! TensorSpec*  ──► OpSpec  ──► OperationGraphSpec
//!                                   │
//!                                   ▼
//!                       EngineHeurSpec ──► EnginecfgSpec
//!                                              │
//!                                              ▼
//!                                       ExecutionPlanSpec
//!                                              │
//!                                              ▼
//!                                       VariantPackSpec
//! ```
//!
//! The plan-cache key is op-kind + dtype + a digest of the graph's tensor
//! and op specs, so two requests with identical shapes hit the same cached
//! plan.
28
29#![allow(dead_code)]
30
31use std::collections::hash_map::DefaultHasher;
32use std::hash::{Hash, Hasher};
33use std::num::NonZeroUsize;
34
35use lru::LruCache;
36
37#[cfg(feature = "cudnn")]
38use cudarc::cudnn::sys as cudnn_sys;
39
40use crate::error::GpuError;
41
/// Default number of plans kept by [`PlanCache`] (matches the existing
/// cuDNN ConvForward cache + cuBLASLt heuristic cache).
pub const DEFAULT_PLAN_CACHE_SIZE: usize = 256;

/// Memory layout from which a tensor's strides are derived.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum TensorLayout {
    /// Channel-first (NCHW, or NCDHW in 3D) with densely packed strides.
    NchwPacked,
    /// Channel-last (NHWC, or NDHWC in 3D) with densely packed strides.
    NhwcPacked,
    /// The caller provides the strides explicitly.
    Strided,
}
56
/// Scalar element type understood by cuDNN.
///
/// Deliberately independent of any `T: CudaDtype` type parameter so that
/// spec-level objects stay dyn-friendly.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DtypeTag {
    F32,
    F64,
    F16,
    Bf16,
    I8,
    I32,
    U8,
}

impl DtypeTag {
    /// Short lowercase name of this dtype, e.g. `"f32"` or `"bf16"`.
    pub fn name(self) -> &'static str {
        match self {
            Self::F32 => "f32",
            Self::F64 => "f64",
            Self::F16 => "f16",
            Self::Bf16 => "bf16",
            Self::I8 => "i8",
            Self::I32 => "i32",
            Self::U8 => "u8",
        }
    }

    /// The matching `cudnnDataType_t` value.
    #[cfg(feature = "cudnn")]
    pub fn cudnn(self) -> cudnn_sys::cudnnDataType_t {
        use cudnn_sys::cudnnDataType_t as Cu;
        match self {
            Self::F32 => Cu::CUDNN_DATA_FLOAT,
            Self::F64 => Cu::CUDNN_DATA_DOUBLE,
            Self::F16 => Cu::CUDNN_DATA_HALF,
            Self::Bf16 => Cu::CUDNN_DATA_BFLOAT16,
            Self::I8 => Cu::CUDNN_DATA_INT8,
            Self::I32 => Cu::CUDNN_DATA_INT32,
            Self::U8 => Cu::CUDNN_DATA_UINT8,
        }
    }
}
98
99/// One tensor in a backend graph: unique id, dtype, dims, strides,
100/// alignment.
101#[derive(Debug, Clone, PartialEq, Eq, Hash)]
102pub struct TensorSpec {
103    pub uid: i64,
104    pub dtype: DtypeTag,
105    pub dims: Vec<i64>,
106    pub strides: Vec<i64>,
107    /// Byte alignment of the data pointer. cuDNN requires ≥ 4.
108    pub alignment: i64,
109    /// Whether the tensor is virtual (intermediate result, no
110    /// device-pointer binding required).
111    pub is_virtual: bool,
112}
113
114impl TensorSpec {
115    /// Build a TensorSpec for `dims` under `layout`. For `Strided`
116    /// the caller must call `with_strides` afterwards.
117    pub fn new(uid: i64, dtype: DtypeTag, dims: Vec<i64>, layout: TensorLayout) -> Self {
118        let strides = packed_strides(&dims, layout);
119        Self {
120            uid,
121            dtype,
122            dims,
123            strides,
124            alignment: 16,
125            is_virtual: false,
126        }
127    }
128
129    pub fn with_strides(mut self, strides: Vec<i64>) -> Self {
130        debug_assert_eq!(strides.len(), self.dims.len());
131        self.strides = strides;
132        self
133    }
134
135    pub fn with_alignment(mut self, alignment: i64) -> Self {
136        self.alignment = alignment;
137        self
138    }
139
140    pub fn virtualized(mut self) -> Self {
141        self.is_virtual = true;
142        self
143    }
144
145    pub fn rank(&self) -> usize {
146        self.dims.len()
147    }
148}
149
150/// Compute packed strides for `dims` under `layout`. NchwPacked is
151/// row-major over `[N,C,...]`; NhwcPacked is row-major over `[N,...,C]`.
152fn packed_strides(dims: &[i64], layout: TensorLayout) -> Vec<i64> {
153    let n = dims.len();
154    if n == 0 {
155        return Vec::new();
156    }
157    match layout {
158        TensorLayout::NchwPacked | TensorLayout::Strided => {
159            let mut strides = vec![1i64; n];
160            for i in (0..n - 1).rev() {
161                strides[i] = strides[i + 1] * dims[i + 1];
162            }
163            strides
164        }
165        TensorLayout::NhwcPacked => {
166            // NHWC: order on disk is N, H, W, ..., C. We model as
167            // strides such that channel has stride 1 and the leading
168            // batch is the slowest-moving.
169            // For dims [N, C, S1, ..., Sk] return strides such that
170            // channel stride = 1, S_k stride = C, ..., N stride = C * prod(S).
171            assert!(n >= 3, "NHWC layout requires at least N,C,S1");
172            let mut strides = vec![0i64; n];
173            // stride for channel (index 1) is 1
174            strides[1] = 1;
175            // last spatial dim has stride = channels
176            let c = dims[1];
177            strides[n - 1] = c;
178            // walk spatial dims right-to-left
179            for i in (2..n - 1).rev() {
180                strides[i] = strides[i + 1] * dims[i + 1];
181            }
182            // batch stride
183            strides[0] = strides[2] * dims[2];
184            strides
185        }
186    }
187}
188
/// One op in a backend graph. Each `OpSpec` references TensorSpecs
/// by their `uid`; the actual TensorSpec values live on the parent
/// [`OperationGraphSpec`].
///
/// `Hash` is implemented manually because `f64` implements neither `Hash`
/// nor `Eq`. Float fields (`alpha`, `beta`, `epsilon`, …) are hashed via
/// their `to_bits()` bit pattern. Two specs therefore hash equal only when
/// those floats are bit-identical; for instance, `0.0` and `-0.0` hash
/// differently.
///
/// The padding / stride / dilation / window vectors are presumably one entry
/// per spatial dimension. Nothing in this type validates their lengths.
#[derive(Debug, Clone)]
pub enum OpSpec {
    /// Convolution forward: y = conv(x, w).
    ConvFwd {
        x: i64,
        w: i64,
        y: i64,
        spatial_dims: usize,
        pre_padding: Vec<i64>,
        post_padding: Vec<i64>,
        stride: Vec<i64>,
        dilation: Vec<i64>,
        compute_dtype: DtypeTag,
        /// NOTE(review): presumably cuDNN's blend factors
        /// (`out = alpha * result + beta * out`) — confirm once the
        /// descriptor build path is wired.
        alpha: f64,
        beta: f64,
    },
    /// Convolution backward data: dx = conv_bwd_data(w, dy).
    ConvBwdData {
        dy: i64,
        w: i64,
        dx: i64,
        spatial_dims: usize,
        pre_padding: Vec<i64>,
        post_padding: Vec<i64>,
        stride: Vec<i64>,
        dilation: Vec<i64>,
        compute_dtype: DtypeTag,
        alpha: f64,
        beta: f64,
    },
    /// Convolution backward filter: dw = conv_bwd_filter(x, dy).
    ConvBwdFilter {
        x: i64,
        dy: i64,
        dw: i64,
        spatial_dims: usize,
        pre_padding: Vec<i64>,
        post_padding: Vec<i64>,
        stride: Vec<i64>,
        dilation: Vec<i64>,
        compute_dtype: DtypeTag,
        alpha: f64,
        beta: f64,
    },
    /// Pointwise op (activation, scale, bias-add, …).
    Pointwise {
        mode: PointwiseMode,
        x: i64,
        /// Optional second input uid; presumably set for binary modes
        /// (`Add`, `Mul`, …) and `None` for unary ones.
        b: Option<i64>,
        y: i64,
        compute_dtype: DtypeTag,
        alpha1: f64,
        alpha2: f64,
    },
    /// Pooling/resample forward.
    PoolFwd {
        kind: PoolKind,
        x: i64,
        y: i64,
        window: Vec<i64>,
        pre_padding: Vec<i64>,
        post_padding: Vec<i64>,
        stride: Vec<i64>,
        compute_dtype: DtypeTag,
    },
    /// Pooling/resample backward.
    PoolBwd {
        kind: PoolKind,
        dy: i64,
        x: i64,
        y: i64,
        dx: i64,
        window: Vec<i64>,
        pre_padding: Vec<i64>,
        post_padding: Vec<i64>,
        stride: Vec<i64>,
        compute_dtype: DtypeTag,
    },
    /// Normalisation forward (batch / layer / instance / group).
    NormFwd {
        mode: NormMode,
        phase: NormPhase,
        x: i64,
        scale: i64,
        bias: i64,
        /// Optional statistics tensors; presumably only bound in the
        /// phases that read or write them — confirm per `phase`.
        mean: Option<i64>,
        var: Option<i64>,
        y: i64,
        compute_dtype: DtypeTag,
        epsilon: f64,
        exp_avg_factor: f64,
    },
    /// Normalisation backward.
    NormBwd {
        mode: NormMode,
        x: i64,
        dy: i64,
        scale: i64,
        mean: i64,
        var: i64,
        dx: i64,
        dscale: i64,
        dbias: i64,
        compute_dtype: DtypeTag,
    },
    /// Matmul (2D) — used by attention fusion.
    Matmul {
        a: i64,
        b: i64,
        c: i64,
        compute_dtype: DtypeTag,
    },
    /// Reduction (sum / max / min / mul / norm).
    Reduce {
        op: ReduceOp,
        x: i64,
        y: i64,
        compute_dtype: DtypeTag,
    },
    /// Reshape (no-copy view change).
    Reshape { x: i64, y: i64 },
}
318
/// Element-wise function applied by an `OpSpec::Pointwise` node.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PointwiseMode {
    Relu,
    Sigmoid,
    Tanh,
    Gelu,
    GeluApprox,
    Swish,
    Elu,
    Softplus,
    Identity,
    Add,
    Mul,
    Sub,
    Div,
    Min,
    Max,
    Sqrt,
    Rsqrt,
    Exp,
    Log,
    Neg,
    Abs,
}

/// Pooling flavour: max or average, forward or backward.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PoolKind {
    MaxFwd,
    AvgFwd,
    MaxBwd,
    AvgBwd,
}

/// Which normalisation a `NormFwd` / `NormBwd` node performs.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum NormMode {
    BatchNorm,
    LayerNorm,
    InstanceNorm,
    GroupNorm,
    RmsNorm,
}

/// Training phase of a forward normalisation.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum NormPhase {
    Inference,
    Training,
    /// Persistent batchnorm (CUDNN_BN_FINALIZE_STATISTICS).
    PersistentTraining,
}

/// Reduction applied by an `OpSpec::Reduce` node. `Norm1` / `Norm2` are
/// presumably the L1 / L2 norms.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ReduceOp {
    Add,
    Mul,
    Min,
    Max,
    Mean,
    Norm1,
    Norm2,
}
384
// Manual Hash for OpSpec so f64 fields hash via to_bits().
//
// Each arm first writes a distinct `u8` tag (0..=10, following the variant
// declaration order). Different variants whose fields happen to be equal
// therefore never feed the hasher the same byte stream. The arm then hashes
// every field in declaration order. `to_bits()` makes float hashing bit-exact:
// `0.0` and `-0.0`, or NaNs with different payloads, hash differently.
//
// Changing a tag or the field order changes every
// `OperationGraphSpec::signature()` value.
impl Hash for OpSpec {
    fn hash<H: Hasher>(&self, h: &mut H) {
        match self {
            OpSpec::ConvFwd {
                x,
                w,
                y,
                spatial_dims,
                pre_padding,
                post_padding,
                stride,
                dilation,
                compute_dtype,
                alpha,
                beta,
            } => {
                0u8.hash(h);
                x.hash(h);
                w.hash(h);
                y.hash(h);
                spatial_dims.hash(h);
                pre_padding.hash(h);
                post_padding.hash(h);
                stride.hash(h);
                dilation.hash(h);
                compute_dtype.hash(h);
                alpha.to_bits().hash(h);
                beta.to_bits().hash(h);
            }
            OpSpec::ConvBwdData {
                dy,
                w,
                dx,
                spatial_dims,
                pre_padding,
                post_padding,
                stride,
                dilation,
                compute_dtype,
                alpha,
                beta,
            } => {
                1u8.hash(h);
                dy.hash(h);
                w.hash(h);
                dx.hash(h);
                spatial_dims.hash(h);
                pre_padding.hash(h);
                post_padding.hash(h);
                stride.hash(h);
                dilation.hash(h);
                compute_dtype.hash(h);
                alpha.to_bits().hash(h);
                beta.to_bits().hash(h);
            }
            OpSpec::ConvBwdFilter {
                x,
                dy,
                dw,
                spatial_dims,
                pre_padding,
                post_padding,
                stride,
                dilation,
                compute_dtype,
                alpha,
                beta,
            } => {
                2u8.hash(h);
                x.hash(h);
                dy.hash(h);
                dw.hash(h);
                spatial_dims.hash(h);
                pre_padding.hash(h);
                post_padding.hash(h);
                stride.hash(h);
                dilation.hash(h);
                compute_dtype.hash(h);
                alpha.to_bits().hash(h);
                beta.to_bits().hash(h);
            }
            OpSpec::Pointwise {
                mode,
                x,
                b,
                y,
                compute_dtype,
                alpha1,
                alpha2,
            } => {
                3u8.hash(h);
                mode.hash(h);
                x.hash(h);
                b.hash(h);
                y.hash(h);
                compute_dtype.hash(h);
                alpha1.to_bits().hash(h);
                alpha2.to_bits().hash(h);
            }
            OpSpec::PoolFwd {
                kind,
                x,
                y,
                window,
                pre_padding,
                post_padding,
                stride,
                compute_dtype,
            } => {
                4u8.hash(h);
                kind.hash(h);
                x.hash(h);
                y.hash(h);
                window.hash(h);
                pre_padding.hash(h);
                post_padding.hash(h);
                stride.hash(h);
                compute_dtype.hash(h);
            }
            OpSpec::PoolBwd {
                kind,
                dy,
                x,
                y,
                dx,
                window,
                pre_padding,
                post_padding,
                stride,
                compute_dtype,
            } => {
                5u8.hash(h);
                kind.hash(h);
                dy.hash(h);
                x.hash(h);
                y.hash(h);
                dx.hash(h);
                window.hash(h);
                pre_padding.hash(h);
                post_padding.hash(h);
                stride.hash(h);
                compute_dtype.hash(h);
            }
            OpSpec::NormFwd {
                mode,
                phase,
                x,
                scale,
                bias,
                mean,
                var,
                y,
                compute_dtype,
                epsilon,
                exp_avg_factor,
            } => {
                6u8.hash(h);
                mode.hash(h);
                phase.hash(h);
                x.hash(h);
                scale.hash(h);
                bias.hash(h);
                mean.hash(h);
                var.hash(h);
                y.hash(h);
                compute_dtype.hash(h);
                epsilon.to_bits().hash(h);
                exp_avg_factor.to_bits().hash(h);
            }
            OpSpec::NormBwd {
                mode,
                x,
                dy,
                scale,
                mean,
                var,
                dx,
                dscale,
                dbias,
                compute_dtype,
            } => {
                7u8.hash(h);
                mode.hash(h);
                x.hash(h);
                dy.hash(h);
                scale.hash(h);
                mean.hash(h);
                var.hash(h);
                dx.hash(h);
                dscale.hash(h);
                dbias.hash(h);
                compute_dtype.hash(h);
            }
            OpSpec::Matmul {
                a,
                b,
                c,
                compute_dtype,
            } => {
                8u8.hash(h);
                a.hash(h);
                b.hash(h);
                c.hash(h);
                compute_dtype.hash(h);
            }
            OpSpec::Reduce {
                op,
                x,
                y,
                compute_dtype,
            } => {
                9u8.hash(h);
                op.hash(h);
                x.hash(h);
                y.hash(h);
                compute_dtype.hash(h);
            }
            OpSpec::Reshape { x, y } => {
                10u8.hash(h);
                x.hash(h);
                y.hash(h);
            }
        }
    }
}
611
612/// Top-level operation graph.
613#[derive(Debug, Clone)]
614pub struct OperationGraphSpec {
615    pub tensors: Vec<TensorSpec>,
616    pub ops: Vec<OpSpec>,
617    /// Optional name for diagnostics.
618    pub name: String,
619}
620
621impl OperationGraphSpec {
622    pub fn new(name: impl Into<String>) -> Self {
623        Self {
624            tensors: Vec::new(),
625            ops: Vec::new(),
626            name: name.into(),
627        }
628    }
629
630    pub fn add_tensor(&mut self, t: TensorSpec) -> i64 {
631        let uid = t.uid;
632        self.tensors.push(t);
633        uid
634    }
635
636    pub fn add_op(&mut self, op: OpSpec) {
637        self.ops.push(op);
638    }
639
640    /// Stable signature digest for plan-cache keying.
641    pub fn signature(&self) -> u64 {
642        let mut h = DefaultHasher::new();
643        self.tensors.hash(&mut h);
644        self.ops.hash(&mut h);
645        h.finish()
646    }
647
648    /// Drive `cudnnBackendCreateDescriptor` for every tensor and op,
649    /// then build an `OPERATION_GRAPH_DESCRIPTOR`. Returns the
650    /// finalised graph descriptor.
651    ///
652    /// On hosts without cuDNN, this short-circuits with
653    /// `LibraryError("cudnn-frontend graph build path requires a real handle")`.
654    #[cfg(feature = "cudnn")]
655    pub fn build_into(
656        &self,
657        _handle: cudnn_sys::cudnnHandle_t,
658    ) -> Result<crate::sys::cudnn::BackendDescriptor, GpuError> {
659        // The full backend-descriptor build path is non-trivial — it
660        // would walk every tensor / op kind, allocate sub-descriptors,
661        // and finalise. The skeleton here keeps the entry point so
662        // request-side dispatch and tests compile against the same
663        // surface that real GPU runs use; runtime tests fill in the
664        // body when a real handle is available.
665        Err(GpuError::LibraryError {
666            lib: "cudnn",
667            msg: "OperationGraphSpec::build_into not yet wired (Phase 2 \
668                  cuDNN frontend skeleton)"
669                .to_string(),
670        })
671    }
672}
673
674/// Cached execution-plan handle. On a host without cuDNN this is
675/// just a marker that "this signature was prepared" — the actual
676/// `BackendDescriptor` lives only on a real GPU build.
677#[derive(Debug)]
678pub struct CachedPlan {
679    pub signature: u64,
680    pub op_kind: &'static str,
681    pub dtype: DtypeTag,
682    pub workspace_bytes: usize,
683    /// `None` on host-only build.
684    #[cfg(feature = "cudnn")]
685    pub plan: Option<crate::sys::cudnn::BackendDescriptor>,
686}
687
688unsafe impl Send for CachedPlan {}
689
690/// Plan-cache key: op-kind + dtype + signature digest. Compact
691/// (24 bytes) so the LRU lookup is cheap.
692#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
693pub struct PlanCacheKey {
694    pub op_kind: &'static str,
695    pub dtype: DtypeTag,
696    pub signature: u64,
697}
698
699/// LRU plan cache. The cuDNN actor wraps this in a `Mutex` and shares
700/// one instance across all op kinds — entries are tagged by op_kind
701/// in the key.
702pub struct PlanCache {
703    inner: LruCache<PlanCacheKey, CachedPlan>,
704}
705
706impl PlanCache {
707    pub fn new(cap: usize) -> Self {
708        Self {
709            inner: LruCache::new(NonZeroUsize::new(cap.max(1)).unwrap()),
710        }
711    }
712
713    pub fn get(&mut self, key: &PlanCacheKey) -> Option<&CachedPlan> {
714        self.inner.get(key)
715    }
716
717    pub fn put(&mut self, key: PlanCacheKey, plan: CachedPlan) {
718        self.inner.put(key, plan);
719    }
720
721    pub fn len(&self) -> usize {
722        self.inner.len()
723    }
724
725    pub fn cap(&self) -> usize {
726        self.inner.cap().get()
727    }
728
729    pub fn clear(&mut self) {
730        self.inner.clear();
731    }
732}
733
734impl Default for PlanCache {
735    fn default() -> Self {
736        Self::new(DEFAULT_PLAN_CACHE_SIZE)
737    }
738}
739
740/// Build a `PlanCacheKey` from an op-kind tag + dtype + an
741/// [`OperationGraphSpec`]. The signature is the full graph digest.
742pub fn cache_key(
743    op_kind: &'static str,
744    dtype: DtypeTag,
745    graph: &OperationGraphSpec,
746) -> PlanCacheKey {
747    PlanCacheKey {
748        op_kind,
749        dtype,
750        signature: graph.signature(),
751    }
752}
753
#[cfg(test)]
mod tests {
    use super::*;

    /// Dense NCHW f32 tensor with the given uid and shape.
    fn f32_nchw(uid: i64, dims: &[i64]) -> TensorSpec {
        TensorSpec::new(uid, DtypeTag::F32, dims.to_vec(), TensorLayout::NchwPacked)
    }

    /// 2-D f32 forward convolution: no padding, unit stride and dilation,
    /// alpha = 1, beta = 0.
    fn plain_conv2d(x: i64, w: i64, y: i64) -> OpSpec {
        OpSpec::ConvFwd {
            x,
            w,
            y,
            spatial_dims: 2,
            pre_padding: vec![0, 0],
            post_padding: vec![0, 0],
            stride: vec![1, 1],
            dilation: vec![1, 1],
            compute_dtype: DtypeTag::F32,
            alpha: 1.0,
            beta: 0.0,
        }
    }

    /// A 3x3 conv graph (1x3x8x8 -> 1x16x6x6) labelled `name`.
    fn small_conv_graph(name: &str) -> OperationGraphSpec {
        let mut graph = OperationGraphSpec::new(name);
        let x = graph.add_tensor(f32_nchw(1, &[1, 3, 8, 8]));
        let w = graph.add_tensor(f32_nchw(2, &[16, 3, 3, 3]));
        let y = graph.add_tensor(f32_nchw(3, &[1, 16, 6, 6]));
        graph.add_op(plain_conv2d(x, w, y));
        graph
    }

    #[test]
    fn nchw_packed_strides_4d() {
        // Row-major over [N, C, H, W] = [2, 3, 4, 5].
        let strides = packed_strides(&[2, 3, 4, 5], TensorLayout::NchwPacked);
        assert_eq!(strides, vec![60, 20, 5, 1]);
    }

    #[test]
    fn nhwc_packed_strides_4d() {
        // dims are [N, C, H, W] = [2, 3, 4, 5]; memory order is N, H, W, C,
        // so N -> 4*5*3, C -> 1, H -> 5*3, W -> 3.
        let strides = packed_strides(&[2, 3, 4, 5], TensorLayout::NhwcPacked);
        assert_eq!(strides, vec![60, 1, 15, 3]);
    }

    #[test]
    fn tensor_spec_round_trip() {
        let t = f32_nchw(1, &[1, 3, 8, 8]).with_alignment(32);
        assert_eq!(t.dims, vec![1, 3, 8, 8]);
        assert_eq!(t.strides, vec![192, 64, 8, 1]);
        assert_eq!(t.alignment, 32);
        assert!(!t.is_virtual);
    }

    #[test]
    fn op_graph_signature_is_deterministic() {
        // The name is diagnostic metadata only and must not affect the digest.
        let original = small_conv_graph("conv").signature();
        let renamed = small_conv_graph("conv-renamed").signature();
        assert_eq!(original, renamed);
    }

    #[test]
    fn plan_cache_lru_eviction() {
        let key = |signature| PlanCacheKey {
            op_kind: "conv_fwd",
            dtype: DtypeTag::F32,
            signature,
        };
        let plan = |signature| CachedPlan {
            signature,
            op_kind: "conv_fwd",
            dtype: DtypeTag::F32,
            workspace_bytes: 0,
            #[cfg(feature = "cudnn")]
            plan: None,
        };

        let mut cache = PlanCache::new(2);
        for sig in 1..=3 {
            cache.put(key(sig), plan(sig));
        }

        // Capacity 2: the third insert evicts the oldest entry.
        assert_eq!(cache.len(), 2);
        assert!(cache.get(&key(1)).is_none());
        assert!(cache.get(&key(2)).is_some());
        assert!(cache.get(&key(3)).is_some());
    }

    #[test]
    fn dtype_tags_have_names() {
        for (tag, expected) in [
            (DtypeTag::F32, "f32"),
            (DtypeTag::F16, "f16"),
            (DtypeTag::Bf16, "bf16"),
            (DtypeTag::I8, "i8"),
        ] {
            assert_eq!(tag.name(), expected);
        }
    }

    /// Build a conv -> ReLU graph purely at the spec level and check the
    /// signature: stable under clone, sensitive to a stride change. No FFI
    /// is touched.
    #[test]
    fn backend_descriptor_builder_round_trip() {
        let mut graph = OperationGraphSpec::new("test-graph");
        let x = graph.add_tensor(f32_nchw(1, &[2, 3, 4, 4]));
        let w = graph.add_tensor(f32_nchw(2, &[8, 3, 3, 3]));
        // The conv output only feeds the ReLU, so it is a virtual
        // intermediate.
        let y = graph.add_tensor(f32_nchw(3, &[2, 8, 2, 2]).virtualized());
        graph.add_op(plain_conv2d(x, w, y));

        // Fused ReLU writing to the graph's real (non-virtual) output.
        let act_out = graph.add_tensor(f32_nchw(4, &[2, 8, 2, 2]));
        graph.add_op(OpSpec::Pointwise {
            mode: PointwiseMode::Relu,
            x: y,
            b: None,
            y: act_out,
            compute_dtype: DtypeTag::F32,
            alpha1: 1.0,
            alpha2: 0.0,
        });

        assert_eq!(graph.tensors.len(), 4);
        assert_eq!(graph.ops.len(), 2);

        let cloned = graph.clone();
        assert_eq!(graph.signature(), cloned.signature());

        let mut restrided = graph.clone();
        restrided.tensors[0].strides = vec![999, 1, 1, 1];
        assert_ne!(graph.signature(), restrided.signature());
    }
}