1#![allow(dead_code)]
30
31use std::collections::hash_map::DefaultHasher;
32use std::hash::{Hash, Hasher};
33use std::num::NonZeroUsize;
34
35use lru::LruCache;
36
37#[cfg(feature = "cudnn")]
38use cudarc::cudnn::sys as cudnn_sys;
39
40use crate::error::GpuError;
41
/// Capacity used by `PlanCache::default()`: the number of plans kept before
/// the least-recently-used entry is evicted.
pub const DEFAULT_PLAN_CACHE_SIZE: usize = 256;
45
/// Layout selector used by `packed_strides` to derive the default strides of
/// a tensor from its dimensions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TensorLayout {
    /// Densely packed with the last dimension contiguous (stride 1).
    NchwPacked,
    /// Densely packed with dimension 1 contiguous (stride 1); `packed_strides`
    /// requires at least three dimensions for this layout.
    NhwcPacked,
    /// `packed_strides` produces the same strides as for `NchwPacked`;
    /// presumably the caller then overrides them via
    /// `TensorSpec::with_strides` — confirm against callers.
    Strided,
}
56
/// Element type tag used for tensor storage and for compute precision.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DtypeTag {
    F32,
    F64,
    F16,
    Bf16,
    I8,
    I32,
    U8,
}

impl DtypeTag {
    /// Returns the short lowercase name of the tag (e.g. `"f32"`, `"bf16"`).
    pub fn name(self) -> &'static str {
        match self {
            Self::F32 => "f32",
            Self::F64 => "f64",
            Self::F16 => "f16",
            Self::Bf16 => "bf16",
            Self::I8 => "i8",
            Self::I32 => "i32",
            Self::U8 => "u8",
        }
    }

    /// Maps the tag to the matching `cudnnDataType_t` value.
    #[cfg(feature = "cudnn")]
    pub fn cudnn(self) -> cudnn_sys::cudnnDataType_t {
        use cudnn_sys::cudnnDataType_t as CudnnDtype;
        match self {
            Self::F32 => CudnnDtype::CUDNN_DATA_FLOAT,
            Self::F64 => CudnnDtype::CUDNN_DATA_DOUBLE,
            Self::F16 => CudnnDtype::CUDNN_DATA_HALF,
            Self::Bf16 => CudnnDtype::CUDNN_DATA_BFLOAT16,
            Self::I8 => CudnnDtype::CUDNN_DATA_INT8,
            Self::I32 => CudnnDtype::CUDNN_DATA_INT32,
            Self::U8 => CudnnDtype::CUDNN_DATA_UINT8,
        }
    }
}
98
99#[derive(Debug, Clone, PartialEq, Eq, Hash)]
102pub struct TensorSpec {
103 pub uid: i64,
104 pub dtype: DtypeTag,
105 pub dims: Vec<i64>,
106 pub strides: Vec<i64>,
107 pub alignment: i64,
109 pub is_virtual: bool,
112}
113
114impl TensorSpec {
115 pub fn new(uid: i64, dtype: DtypeTag, dims: Vec<i64>, layout: TensorLayout) -> Self {
118 let strides = packed_strides(&dims, layout);
119 Self {
120 uid,
121 dtype,
122 dims,
123 strides,
124 alignment: 16,
125 is_virtual: false,
126 }
127 }
128
129 pub fn with_strides(mut self, strides: Vec<i64>) -> Self {
130 debug_assert_eq!(strides.len(), self.dims.len());
131 self.strides = strides;
132 self
133 }
134
135 pub fn with_alignment(mut self, alignment: i64) -> Self {
136 self.alignment = alignment;
137 self
138 }
139
140 pub fn virtualized(mut self) -> Self {
141 self.is_virtual = true;
142 self
143 }
144
145 pub fn rank(&self) -> usize {
146 self.dims.len()
147 }
148}
149
150fn packed_strides(dims: &[i64], layout: TensorLayout) -> Vec<i64> {
153 let n = dims.len();
154 if n == 0 {
155 return Vec::new();
156 }
157 match layout {
158 TensorLayout::NchwPacked | TensorLayout::Strided => {
159 let mut strides = vec![1i64; n];
160 for i in (0..n - 1).rev() {
161 strides[i] = strides[i + 1] * dims[i + 1];
162 }
163 strides
164 }
165 TensorLayout::NhwcPacked => {
166 assert!(n >= 3, "NHWC layout requires at least N,C,S1");
172 let mut strides = vec![0i64; n];
173 strides[1] = 1;
175 let c = dims[1];
177 strides[n - 1] = c;
178 for i in (2..n - 1).rev() {
180 strides[i] = strides[i + 1] * dims[i + 1];
181 }
182 strides[0] = strides[2] * dims[2];
184 strides
185 }
186 }
187}
188
/// One operation node of an [`OperationGraphSpec`].
///
/// The `i64` fields name tensors by uid (the tests in this file pass the uids
/// returned by `OperationGraphSpec::add_tensor`). The enum carries `f64`
/// scalars, so `Hash` is implemented by hand (hashing the float bit patterns)
/// instead of being derived.
#[derive(Debug, Clone)]
pub enum OpSpec {
    /// Convolution forward: inputs `x` and `w`, output `y`.
    ConvFwd {
        x: i64,
        w: i64,
        y: i64,
        spatial_dims: usize,
        pre_padding: Vec<i64>,
        post_padding: Vec<i64>,
        stride: Vec<i64>,
        dilation: Vec<i64>,
        compute_dtype: DtypeTag,
        alpha: f64,
        beta: f64,
    },
    /// Convolution backward with respect to the data: `dy` and `w` in, `dx` out.
    ConvBwdData {
        dy: i64,
        w: i64,
        dx: i64,
        spatial_dims: usize,
        pre_padding: Vec<i64>,
        post_padding: Vec<i64>,
        stride: Vec<i64>,
        dilation: Vec<i64>,
        compute_dtype: DtypeTag,
        alpha: f64,
        beta: f64,
    },
    /// Convolution backward with respect to the filter: `x` and `dy` in, `dw` out.
    ConvBwdFilter {
        x: i64,
        dy: i64,
        dw: i64,
        spatial_dims: usize,
        pre_padding: Vec<i64>,
        post_padding: Vec<i64>,
        stride: Vec<i64>,
        dilation: Vec<i64>,
        compute_dtype: DtypeTag,
        alpha: f64,
        beta: f64,
    },
    /// Elementwise operation selected by `mode`; `b` is an optional second
    /// operand (the ReLU example in the tests passes `None`).
    Pointwise {
        mode: PointwiseMode,
        x: i64,
        b: Option<i64>,
        y: i64,
        compute_dtype: DtypeTag,
        alpha1: f64,
        alpha2: f64,
    },
    /// Pooling forward over `window`, reading `x` and writing `y`.
    PoolFwd {
        kind: PoolKind,
        x: i64,
        y: i64,
        window: Vec<i64>,
        pre_padding: Vec<i64>,
        post_padding: Vec<i64>,
        stride: Vec<i64>,
        compute_dtype: DtypeTag,
    },
    /// Pooling backward: takes `dy` together with the forward `x`/`y` and
    /// produces `dx`.
    PoolBwd {
        kind: PoolKind,
        dy: i64,
        x: i64,
        y: i64,
        dx: i64,
        window: Vec<i64>,
        pre_padding: Vec<i64>,
        post_padding: Vec<i64>,
        stride: Vec<i64>,
        compute_dtype: DtypeTag,
    },
    /// Normalization forward; `mean` and `var` are optional tensors.
    NormFwd {
        mode: NormMode,
        phase: NormPhase,
        x: i64,
        scale: i64,
        bias: i64,
        mean: Option<i64>,
        var: Option<i64>,
        y: i64,
        compute_dtype: DtypeTag,
        epsilon: f64,
        exp_avg_factor: f64,
    },
    /// Normalization backward producing `dx`, `dscale` and `dbias`.
    NormBwd {
        mode: NormMode,
        x: i64,
        dy: i64,
        scale: i64,
        mean: i64,
        var: i64,
        dx: i64,
        dscale: i64,
        dbias: i64,
        compute_dtype: DtypeTag,
    },
    /// Matrix multiplication of `a` and `b` into `c`.
    Matmul {
        a: i64,
        b: i64,
        c: i64,
        compute_dtype: DtypeTag,
    },
    /// Reduction of `x` into `y` using `op`.
    Reduce {
        op: ReduceOp,
        x: i64,
        y: i64,
        compute_dtype: DtypeTag,
    },
    /// Reshape of `x` into `y`.
    Reshape { x: i64, y: i64 },
}
318
/// Selects the elementwise function applied by `OpSpec::Pointwise`.
///
/// Variant order determines the derived `Hash` discriminant and therefore
/// graph signatures; append new variants rather than reordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PointwiseMode {
    Relu,
    Sigmoid,
    Tanh,
    Gelu,
    GeluApprox,
    Swish,
    Elu,
    Softplus,
    Identity,
    Add,
    Mul,
    Sub,
    Div,
    Min,
    Max,
    Sqrt,
    Rsqrt,
    Exp,
    Log,
    Neg,
    Abs,
}
344
/// Pooling flavour carried by `OpSpec::PoolFwd` and `OpSpec::PoolBwd`.
///
/// NOTE(review): the `*Fwd` / `*Bwd` variants overlap with the direction
/// already encoded by the `OpSpec` variant; nothing in this file checks that
/// the two agree.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PoolKind {
    MaxFwd,
    AvgFwd,
    MaxBwd,
    AvgBwd,
}
353
/// Normalization family used by `OpSpec::NormFwd` and `OpSpec::NormBwd`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NormMode {
    BatchNorm,
    LayerNorm,
    InstanceNorm,
    GroupNorm,
    RmsNorm,
}
363
/// Execution phase of an `OpSpec::NormFwd` operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NormPhase {
    Inference,
    Training,
    PersistentTraining,
}
372
/// Reduction operator used by `OpSpec::Reduce`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReduceOp {
    Add,
    Mul,
    Min,
    Max,
    Mean,
    Norm1,
    Norm2,
}
384
385impl Hash for OpSpec {
387 fn hash<H: Hasher>(&self, h: &mut H) {
388 match self {
389 OpSpec::ConvFwd {
390 x,
391 w,
392 y,
393 spatial_dims,
394 pre_padding,
395 post_padding,
396 stride,
397 dilation,
398 compute_dtype,
399 alpha,
400 beta,
401 } => {
402 0u8.hash(h);
403 x.hash(h);
404 w.hash(h);
405 y.hash(h);
406 spatial_dims.hash(h);
407 pre_padding.hash(h);
408 post_padding.hash(h);
409 stride.hash(h);
410 dilation.hash(h);
411 compute_dtype.hash(h);
412 alpha.to_bits().hash(h);
413 beta.to_bits().hash(h);
414 }
415 OpSpec::ConvBwdData {
416 dy,
417 w,
418 dx,
419 spatial_dims,
420 pre_padding,
421 post_padding,
422 stride,
423 dilation,
424 compute_dtype,
425 alpha,
426 beta,
427 } => {
428 1u8.hash(h);
429 dy.hash(h);
430 w.hash(h);
431 dx.hash(h);
432 spatial_dims.hash(h);
433 pre_padding.hash(h);
434 post_padding.hash(h);
435 stride.hash(h);
436 dilation.hash(h);
437 compute_dtype.hash(h);
438 alpha.to_bits().hash(h);
439 beta.to_bits().hash(h);
440 }
441 OpSpec::ConvBwdFilter {
442 x,
443 dy,
444 dw,
445 spatial_dims,
446 pre_padding,
447 post_padding,
448 stride,
449 dilation,
450 compute_dtype,
451 alpha,
452 beta,
453 } => {
454 2u8.hash(h);
455 x.hash(h);
456 dy.hash(h);
457 dw.hash(h);
458 spatial_dims.hash(h);
459 pre_padding.hash(h);
460 post_padding.hash(h);
461 stride.hash(h);
462 dilation.hash(h);
463 compute_dtype.hash(h);
464 alpha.to_bits().hash(h);
465 beta.to_bits().hash(h);
466 }
467 OpSpec::Pointwise {
468 mode,
469 x,
470 b,
471 y,
472 compute_dtype,
473 alpha1,
474 alpha2,
475 } => {
476 3u8.hash(h);
477 mode.hash(h);
478 x.hash(h);
479 b.hash(h);
480 y.hash(h);
481 compute_dtype.hash(h);
482 alpha1.to_bits().hash(h);
483 alpha2.to_bits().hash(h);
484 }
485 OpSpec::PoolFwd {
486 kind,
487 x,
488 y,
489 window,
490 pre_padding,
491 post_padding,
492 stride,
493 compute_dtype,
494 } => {
495 4u8.hash(h);
496 kind.hash(h);
497 x.hash(h);
498 y.hash(h);
499 window.hash(h);
500 pre_padding.hash(h);
501 post_padding.hash(h);
502 stride.hash(h);
503 compute_dtype.hash(h);
504 }
505 OpSpec::PoolBwd {
506 kind,
507 dy,
508 x,
509 y,
510 dx,
511 window,
512 pre_padding,
513 post_padding,
514 stride,
515 compute_dtype,
516 } => {
517 5u8.hash(h);
518 kind.hash(h);
519 dy.hash(h);
520 x.hash(h);
521 y.hash(h);
522 dx.hash(h);
523 window.hash(h);
524 pre_padding.hash(h);
525 post_padding.hash(h);
526 stride.hash(h);
527 compute_dtype.hash(h);
528 }
529 OpSpec::NormFwd {
530 mode,
531 phase,
532 x,
533 scale,
534 bias,
535 mean,
536 var,
537 y,
538 compute_dtype,
539 epsilon,
540 exp_avg_factor,
541 } => {
542 6u8.hash(h);
543 mode.hash(h);
544 phase.hash(h);
545 x.hash(h);
546 scale.hash(h);
547 bias.hash(h);
548 mean.hash(h);
549 var.hash(h);
550 y.hash(h);
551 compute_dtype.hash(h);
552 epsilon.to_bits().hash(h);
553 exp_avg_factor.to_bits().hash(h);
554 }
555 OpSpec::NormBwd {
556 mode,
557 x,
558 dy,
559 scale,
560 mean,
561 var,
562 dx,
563 dscale,
564 dbias,
565 compute_dtype,
566 } => {
567 7u8.hash(h);
568 mode.hash(h);
569 x.hash(h);
570 dy.hash(h);
571 scale.hash(h);
572 mean.hash(h);
573 var.hash(h);
574 dx.hash(h);
575 dscale.hash(h);
576 dbias.hash(h);
577 compute_dtype.hash(h);
578 }
579 OpSpec::Matmul {
580 a,
581 b,
582 c,
583 compute_dtype,
584 } => {
585 8u8.hash(h);
586 a.hash(h);
587 b.hash(h);
588 c.hash(h);
589 compute_dtype.hash(h);
590 }
591 OpSpec::Reduce {
592 op,
593 x,
594 y,
595 compute_dtype,
596 } => {
597 9u8.hash(h);
598 op.hash(h);
599 x.hash(h);
600 y.hash(h);
601 compute_dtype.hash(h);
602 }
603 OpSpec::Reshape { x, y } => {
604 10u8.hash(h);
605 x.hash(h);
606 y.hash(h);
607 }
608 }
609 }
610}
611
612#[derive(Debug, Clone)]
614pub struct OperationGraphSpec {
615 pub tensors: Vec<TensorSpec>,
616 pub ops: Vec<OpSpec>,
617 pub name: String,
619}
620
621impl OperationGraphSpec {
622 pub fn new(name: impl Into<String>) -> Self {
623 Self {
624 tensors: Vec::new(),
625 ops: Vec::new(),
626 name: name.into(),
627 }
628 }
629
630 pub fn add_tensor(&mut self, t: TensorSpec) -> i64 {
631 let uid = t.uid;
632 self.tensors.push(t);
633 uid
634 }
635
636 pub fn add_op(&mut self, op: OpSpec) {
637 self.ops.push(op);
638 }
639
640 pub fn signature(&self) -> u64 {
642 let mut h = DefaultHasher::new();
643 self.tensors.hash(&mut h);
644 self.ops.hash(&mut h);
645 h.finish()
646 }
647
648 #[cfg(feature = "cudnn")]
655 pub fn build_into(
656 &self,
657 _handle: cudnn_sys::cudnnHandle_t,
658 ) -> Result<crate::sys::cudnn::BackendDescriptor, GpuError> {
659 Err(GpuError::LibraryError {
666 lib: "cudnn",
667 msg: "OperationGraphSpec::build_into not yet wired (Phase 2 \
668 cuDNN frontend skeleton)"
669 .to_string(),
670 })
671 }
672}
673
/// Entry stored in a [`PlanCache`].
#[derive(Debug)]
pub struct CachedPlan {
    /// Graph signature the plan belongs to (see `OperationGraphSpec::signature`).
    pub signature: u64,
    /// Operation kind label, e.g. `"conv_fwd"` in the tests.
    pub op_kind: &'static str,
    /// Data type the plan was created for.
    pub dtype: DtypeTag,
    /// Workspace size in bytes.
    pub workspace_bytes: usize,
    /// Backend descriptor of the plan, when one has been built. Only present
    /// with the `cudnn` feature.
    #[cfg(feature = "cudnn")]
    pub plan: Option<crate::sys::cudnn::BackendDescriptor>,
}

// SAFETY: NOTE(review) — without the `cudnn` feature every field is already
// `Send`, so this impl is redundant there. With the feature enabled it asserts
// that `BackendDescriptor` may be moved across threads; that type is defined
// outside this file, so confirm it before relying on this impl.
unsafe impl Send for CachedPlan {}
689
/// Lookup key for a [`PlanCache`]: operation kind, data type and graph
/// signature. Build one with `cache_key`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PlanCacheKey {
    /// Operation kind label, e.g. `"conv_fwd"` in the tests.
    pub op_kind: &'static str,
    /// Data type the plan is specialised for.
    pub dtype: DtypeTag,
    /// Value of `OperationGraphSpec::signature` for the graph.
    pub signature: u64,
}
698
699pub struct PlanCache {
703 inner: LruCache<PlanCacheKey, CachedPlan>,
704}
705
706impl PlanCache {
707 pub fn new(cap: usize) -> Self {
708 Self {
709 inner: LruCache::new(NonZeroUsize::new(cap.max(1)).unwrap()),
710 }
711 }
712
713 pub fn get(&mut self, key: &PlanCacheKey) -> Option<&CachedPlan> {
714 self.inner.get(key)
715 }
716
717 pub fn put(&mut self, key: PlanCacheKey, plan: CachedPlan) {
718 self.inner.put(key, plan);
719 }
720
721 pub fn len(&self) -> usize {
722 self.inner.len()
723 }
724
725 pub fn cap(&self) -> usize {
726 self.inner.cap().get()
727 }
728
729 pub fn clear(&mut self) {
730 self.inner.clear();
731 }
732}
733
734impl Default for PlanCache {
735 fn default() -> Self {
736 Self::new(DEFAULT_PLAN_CACHE_SIZE)
737 }
738}
739
740pub fn cache_key(
743 op_kind: &'static str,
744 dtype: DtypeTag,
745 graph: &OperationGraphSpec,
746) -> PlanCacheKey {
747 PlanCacheKey {
748 op_kind,
749 dtype,
750 signature: graph.signature(),
751 }
752}
753
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an f32 tensor spec with default NCHW-packed strides.
    fn f32_nchw(uid: i64, dims: &[i64]) -> TensorSpec {
        TensorSpec::new(uid, DtypeTag::F32, dims.to_vec(), TensorLayout::NchwPacked)
    }

    /// 2-D forward convolution with zero padding, unit stride and dilation,
    /// `alpha = 1` and `beta = 0`.
    fn plain_conv2d(x: i64, w: i64, y: i64) -> OpSpec {
        OpSpec::ConvFwd {
            x,
            w,
            y,
            spatial_dims: 2,
            pre_padding: vec![0, 0],
            post_padding: vec![0, 0],
            stride: vec![1, 1],
            dilation: vec![1, 1],
            compute_dtype: DtypeTag::F32,
            alpha: 1.0,
            beta: 0.0,
        }
    }

    /// Small three-tensor convolution graph carrying the given name.
    fn conv_graph_named(name: &str) -> OperationGraphSpec {
        let mut g = OperationGraphSpec::new(name);
        g.add_tensor(f32_nchw(1, &[1, 3, 8, 8]));
        g.add_tensor(f32_nchw(2, &[16, 3, 3, 3]));
        g.add_tensor(f32_nchw(3, &[1, 16, 6, 6]));
        g.add_op(plain_conv2d(1, 2, 3));
        g
    }

    #[test]
    fn nchw_packed_strides_4d() {
        let strides = packed_strides(&[2, 3, 4, 5], TensorLayout::NchwPacked);
        assert_eq!(strides, vec![60, 20, 5, 1]);
    }

    #[test]
    fn nhwc_packed_strides_4d() {
        let strides = packed_strides(&[2, 3, 4, 5], TensorLayout::NhwcPacked);
        // Order is [N, C, H, W]: C is contiguous, then W, H, N.
        assert_eq!(strides, vec![60, 1, 15, 3]);
    }

    #[test]
    fn tensor_spec_round_trip() {
        let spec = f32_nchw(1, &[1, 3, 8, 8]).with_alignment(32);
        assert_eq!(spec.dims, vec![1, 3, 8, 8]);
        assert_eq!(spec.strides, vec![192, 64, 8, 1]);
        assert_eq!(spec.alignment, 32);
        assert!(!spec.is_virtual);
    }

    #[test]
    fn op_graph_signature_is_deterministic() {
        // Same tensors and ops under different names must hash identically.
        let first = conv_graph_named("conv").signature();
        let second = conv_graph_named("conv-renamed").signature();
        assert_eq!(first, second);
    }

    #[test]
    fn plan_cache_lru_eviction() {
        let key = |signature| PlanCacheKey {
            op_kind: "conv_fwd",
            dtype: DtypeTag::F32,
            signature,
        };
        let plan = |signature| CachedPlan {
            signature,
            op_kind: "conv_fwd",
            dtype: DtypeTag::F32,
            workspace_bytes: 0,
            #[cfg(feature = "cudnn")]
            plan: None,
        };

        let mut cache = PlanCache::new(2);
        for sig in 1..=3 {
            cache.put(key(sig), plan(sig));
        }

        // Capacity is two, so the oldest entry (signature 1) is gone.
        assert_eq!(cache.len(), 2);
        assert!(cache.get(&key(1)).is_none());
        assert!(cache.get(&key(2)).is_some());
        assert!(cache.get(&key(3)).is_some());
    }

    #[test]
    fn dtype_tags_have_names() {
        let expected = [
            (DtypeTag::F32, "f32"),
            (DtypeTag::F16, "f16"),
            (DtypeTag::Bf16, "bf16"),
            (DtypeTag::I8, "i8"),
        ];
        for (tag, name) in expected {
            assert_eq!(tag.name(), name);
        }
    }

    #[test]
    fn backend_descriptor_builder_round_trip() {
        let mut graph = OperationGraphSpec::new("test-graph");
        let x = graph.add_tensor(f32_nchw(1, &[2, 3, 4, 4]));
        let w = graph.add_tensor(f32_nchw(2, &[8, 3, 3, 3]));
        let y = graph.add_tensor(f32_nchw(3, &[2, 8, 2, 2]).virtualized());
        graph.add_op(plain_conv2d(x, w, y));

        let act_out = graph.add_tensor(f32_nchw(4, &[2, 8, 2, 2]));
        graph.add_op(OpSpec::Pointwise {
            mode: PointwiseMode::Relu,
            x: y,
            b: None,
            y: act_out,
            compute_dtype: DtypeTag::F32,
            alpha1: 1.0,
            alpha2: 0.0,
        });

        assert_eq!(graph.tensors.len(), 4);
        assert_eq!(graph.ops.len(), 2);

        // A clone hashes the same as the original.
        let cloned = graph.clone();
        assert_eq!(graph.signature(), cloned.signature());

        // Changing a stride must change the signature.
        let mut altered = graph.clone();
        altered.tensors[0].strides = vec![999, 1, 1, 1];
        assert_ne!(graph.signature(), altered.signature());
    }
}