atomr_accel_cuda/kernel/tensor/
compute_desc.rs1use cudarc::cutensor::sys as ct_sys;
9
10use crate::sys::cutensor as ct_local;
11
/// Selector for the compute descriptor used with cuTENSOR.
///
/// A variant carries no data of its own: [`resolve_compute_desc`] turns it
/// into the matching `cutensorComputeDescriptor_t`, [`compute_desc_tag`] gives
/// it a numeric tag, and [`ComputeDesc::tag`] gives it a printable name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ComputeDesc {
    /// Resolved through `ct_local::r_32f()`; numeric tag `0x01`.
    F32,
    /// Resolved through `ct_local::r_64f()`; numeric tag `0x02`.
    F64,
    /// Resolved through `ct_local::r_min_32f()`; numeric tag `0x04`.
    MinF32,
    /// Resolved through `ct_local::r_min_64f()`; numeric tag `0x08`.
    MinF64,
    /// Resolved through `ct_local::r_min_16f()`; numeric tag `0x10`.
    MinF16,
    /// Resolved through `ct_local::r_min_16bf()`; numeric tag `0x20`.
    MinBf16,
    /// Resolved through `ct_local::r_min_tf32()` (the `min` helper, even
    /// though the variant name has no `Min` prefix); numeric tag `0x40`.
    Tf32,
    /// Resolved through `ct_local::c_32f()`; numeric tag `0x80`.
    C32F,
}
34
35impl ComputeDesc {
36 pub fn tag(self) -> &'static str {
37 match self {
38 ComputeDesc::F32 => "F32",
39 ComputeDesc::F64 => "F64",
40 ComputeDesc::MinF32 => "MinF32",
41 ComputeDesc::MinF64 => "MinF64",
42 ComputeDesc::MinF16 => "MinF16",
43 ComputeDesc::MinBf16 => "MinBf16",
44 ComputeDesc::Tf32 => "Tf32",
45 ComputeDesc::C32F => "C32F",
46 }
47 }
48}
49
50pub fn compute_desc_tag(c: ComputeDesc) -> u32 {
54 match c {
55 ComputeDesc::F32 => 0x01,
56 ComputeDesc::F64 => 0x02,
57 ComputeDesc::MinF32 => 0x04,
58 ComputeDesc::MinF64 => 0x08,
59 ComputeDesc::MinF16 => 0x10,
60 ComputeDesc::MinBf16 => 0x20,
61 ComputeDesc::Tf32 => 0x40,
62 ComputeDesc::C32F => 0x80,
63 }
64}
65
66pub fn resolve_compute_desc(c: ComputeDesc) -> ct_sys::cutensorComputeDescriptor_t {
69 match c {
70 ComputeDesc::F32 => ct_local::r_32f(),
71 ComputeDesc::F64 => ct_local::r_64f(),
72 ComputeDesc::MinF32 => ct_local::r_min_32f(),
73 ComputeDesc::MinF64 => ct_local::r_min_64f(),
74 ComputeDesc::MinF16 => ct_local::r_min_16f(),
75 ComputeDesc::MinBf16 => ct_local::r_min_16bf(),
76 ComputeDesc::Tf32 => ct_local::r_min_tf32(),
77 ComputeDesc::C32F => ct_local::c_32f(),
78 }
79}
80
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Every `ComputeDesc` variant paired with the name `tag()` must report.
    ///
    /// Both tests iterate this one table, so a newly added variant only has
    /// to be listed here to be covered by both checks.
    const EXPECTED_NAMES: [(ComputeDesc, &str); 8] = [
        (ComputeDesc::F32, "F32"),
        (ComputeDesc::F64, "F64"),
        (ComputeDesc::MinF32, "MinF32"),
        (ComputeDesc::MinF64, "MinF64"),
        (ComputeDesc::MinF16, "MinF16"),
        (ComputeDesc::MinBf16, "MinBf16"),
        (ComputeDesc::Tf32, "Tf32"),
        (ComputeDesc::C32F, "C32F"),
    ];

    /// No two variants may share a numeric tag.
    #[test]
    fn compute_desc_tags_are_unique() {
        let distinct: HashSet<u32> = EXPECTED_NAMES
            .iter()
            .map(|&(desc, _)| compute_desc_tag(desc))
            .collect();
        assert_eq!(
            distinct.len(),
            EXPECTED_NAMES.len(),
            "tags must all be distinct"
        );
    }

    /// `tag()` must spell every variant's name exactly, not just a sample.
    #[test]
    fn compute_desc_tag_strs() {
        for &(desc, expected) in EXPECTED_NAMES.iter() {
            assert_eq!(desc.tag(), expected, "wrong tag string for {:?}", desc);
        }
    }
}