Skip to main content

atomr_accel_cuda/kernel/tensor/
compute_desc.rs

//! Mixed-precision compute descriptors used by every cuTENSOR op.
//!
//! cuTENSOR's `cutensorComputeDescriptor_t` is opaque; the library
//! exports a fixed set of named globals (`CUTENSOR_R_MIN_32F`, …).
//! We expose them as a Rust enum and resolve to the underlying
//! pointer at the call site.

8use cudarc::cutensor::sys as ct_sys;
9
10use crate::sys::cutensor as ct_local;
11
/// Selects the cuTENSOR mixed-precision compute path.
///
/// `MinF32` is the natural pairing for f16/bf16 inputs that accumulate
/// in f32. Use [`resolve_compute_desc`] to obtain the descriptor handle
/// cuTENSOR expects, and [`compute_desc_tag`] to obtain a `u32` suitable
/// for plan-cache keys.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum ComputeDesc {
    /// `CUTENSOR_R_32F` — full fp32, no min-precision.
    F32,
    /// `CUTENSOR_R_64F` — full fp64.
    F64,
    /// `CUTENSOR_R_MIN_32F` — fp32 accumulation with min-precision
    /// kernels. Default for f32 / f16 / bf16 inputs.
    MinF32,
    /// `CUTENSOR_R_MIN_64F` — fp64 accumulation, min-precision.
    MinF64,
    /// `CUTENSOR_R_MIN_16F` — pure half-precision compute.
    MinF16,
    /// `CUTENSOR_R_MIN_16BF` — pure bf16 compute.
    MinBf16,
    /// `CUTENSOR_R_MIN_TF32` — TF32 accumulation (Ampere+).
    Tf32,
    /// `CUTENSOR_C_32F` — complex fp32 inputs/compute.
    C32F,
}

impl ComputeDesc {
    /// Short, human-readable name of this compute descriptor.
    ///
    /// The returned string is identical to the variant name (e.g.
    /// `"MinF32"`) and is meant for logging and diagnostics. For
    /// plan-cache keys prefer [`compute_desc_tag`], which yields a `u32`.
    ///
    /// This is a `const fn`, so it can also be evaluated in const contexts.
    #[must_use]
    pub const fn tag(self) -> &'static str {
        match self {
            ComputeDesc::F32 => "F32",
            ComputeDesc::F64 => "F64",
            ComputeDesc::MinF32 => "MinF32",
            ComputeDesc::MinF64 => "MinF64",
            ComputeDesc::MinF16 => "MinF16",
            ComputeDesc::MinBf16 => "MinBf16",
            ComputeDesc::Tf32 => "Tf32",
            ComputeDesc::C32F => "C32F",
        }
    }
}
49
50/// Stable u32 fingerprint of a `ComputeDesc` for plan-cache keys.
51/// Different from `tag()` so two compute descs are guaranteed to
52/// hash distinctly without relying on string hashing.
53pub fn compute_desc_tag(c: ComputeDesc) -> u32 {
54    match c {
55        ComputeDesc::F32 => 0x01,
56        ComputeDesc::F64 => 0x02,
57        ComputeDesc::MinF32 => 0x04,
58        ComputeDesc::MinF64 => 0x08,
59        ComputeDesc::MinF16 => 0x10,
60        ComputeDesc::MinBf16 => 0x20,
61        ComputeDesc::Tf32 => 0x40,
62        ComputeDesc::C32F => 0x80,
63    }
64}
65
66/// Resolve a [`ComputeDesc`] to the corresponding extern global
67/// pointer cuTENSOR expects.
68pub fn resolve_compute_desc(c: ComputeDesc) -> ct_sys::cutensorComputeDescriptor_t {
69    match c {
70        ComputeDesc::F32 => ct_local::r_32f(),
71        ComputeDesc::F64 => ct_local::r_64f(),
72        ComputeDesc::MinF32 => ct_local::r_min_32f(),
73        ComputeDesc::MinF64 => ct_local::r_min_64f(),
74        ComputeDesc::MinF16 => ct_local::r_min_16f(),
75        ComputeDesc::MinBf16 => ct_local::r_min_16bf(),
76        ComputeDesc::Tf32 => ct_local::r_min_tf32(),
77        ComputeDesc::C32F => ct_local::c_32f(),
78    }
79}
80
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant paired with its pinned string tag and `u32`
    /// fingerprint. Extend this table whenever a variant is added, so the
    /// checks below keep covering the whole enum.
    const TABLE: [(ComputeDesc, &str, u32); 8] = [
        (ComputeDesc::F32, "F32", 0x01),
        (ComputeDesc::F64, "F64", 0x02),
        (ComputeDesc::MinF32, "MinF32", 0x04),
        (ComputeDesc::MinF64, "MinF64", 0x08),
        (ComputeDesc::MinF16, "MinF16", 0x10),
        (ComputeDesc::MinBf16, "MinBf16", 0x20),
        (ComputeDesc::Tf32, "Tf32", 0x40),
        (ComputeDesc::C32F, "C32F", 0x80),
    ];

    #[test]
    fn compute_desc_tags_are_unique() {
        let mut tags: Vec<u32> = TABLE
            .iter()
            .map(|&(desc, _, _)| compute_desc_tag(desc))
            .collect();
        tags.sort_unstable();
        tags.dedup();
        assert_eq!(tags.len(), TABLE.len(), "tags must all be distinct");

        let mut names: Vec<&str> = TABLE.iter().map(|&(desc, _, _)| desc.tag()).collect();
        names.sort_unstable();
        names.dedup();
        assert_eq!(names.len(), TABLE.len(), "string tags must all be distinct");
    }

    #[test]
    fn compute_desc_tags_are_stable() {
        // Fingerprints are documented as stable; renumbering must be caught.
        for &(desc, _, fingerprint) in TABLE.iter() {
            assert_eq!(
                compute_desc_tag(desc),
                fingerprint,
                "fingerprint changed for {:?}",
                desc
            );
        }
    }

    #[test]
    fn compute_desc_tag_strs() {
        for &(desc, name, _) in TABLE.iter() {
            assert_eq!(desc.tag(), name);
        }
    }
}
106        assert_eq!(ComputeDesc::F32.tag(), "F32");
107        assert_eq!(ComputeDesc::MinF32.tag(), "MinF32");
108        assert_eq!(ComputeDesc::Tf32.tag(), "Tf32");
109    }
110}