Skip to main content

atomr_accel_cuda/hopper/
tma.rs

1//! Tensor Memory Accelerator (TMA) host-side descriptor builder.
2//!
//! On Hopper (sm_90+) and Blackwell, TMA offloads global → shared
//! memory tile copies to dedicated hardware, decoupling them from the
//! threads that issue them: a kernel issues
//! `cp.async.bulk.tensor.<N>d.shared::cluster.global` (N = 1..=5) against
//! an opaque [`CUtensorMap`](cudarc::driver::sys::CUtensorMap) that the
//! host built once via `cuTensorMapEncodeTiled`, then waits on an
//! `mbarrier` (`mbarrier.try_wait`) for the copy to land.
9//!
//! [`TensorMapDescriptor`] is a host-side builder for the *tiled*
//! flavour of the encode call. Its `encode` method (available with the
//! `hopper` cargo feature) validates the descriptor and returns a
//! `TensorMap` wrapping the 128-byte `cudarc::driver::sys::CUtensorMap`.
//! This module makes no attempt to cover the `im2col` / `im2col-wide`
//! flavours — those have shape-specific descriptor sets that fit poorly
//! into a uniform builder.
16
17use std::fmt;
18
/// Element dtype consumed by the TMA.
///
/// Mirrors `cudarc::driver::sys::CUtensorMapDataType_enum` so callers do not
/// have to depend on cudarc's `sys` module directly (cudarc gates that module
/// on a CUDA-version feature). Each variant's explicit discriminant is the
/// value [`TensorMapDataType::as_u32`] hands to the driver.
///
/// NOTE(review): `UInt8` through `TFloat32Ftz` (0..=12) follow the
/// `CU_TENSOR_MAP_DATA_TYPE_*` ordering in `cuda.h`. The Blackwell variants
/// (13..=17) may not line up with the driver enum — recent CUDA headers
/// append sub-byte packing types (`16U4_ALIGN8B`, `16U4_ALIGN16B`,
/// `16U6_ALIGN16B`) at 13..=15 and have no dedicated fp8 entry. Confirm
/// against the targeted CUDA headers before encoding with these.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorMapDataType {
    UInt8 = 0,
    UInt16 = 1,
    UInt32 = 2,
    Int32 = 3,
    UInt64 = 4,
    Int64 = 5,
    Float16 = 6,
    Float32 = 7,
    Float64 = 8,
    BFloat16 = 9,
    Float32Ftz = 10,
    /// TF32 / `__nv_tf32`.
    TFloat32 = 11,
    TFloat32Ftz = 12,
    /// Blackwell-only fp8/fp4/fp6 variants. The variants themselves are not
    /// feature-gated, so unit tests can round-trip their values; kernels
    /// that consume them need a Blackwell driver.
    Float8E4m3 = 13,
    Float8E5m2 = 14,
    Float4E2m1 = 15,
    Float6E2m3 = 16,
    Float6E3m2 = 17,
}

impl TensorMapDataType {
    /// Numeric code passed to the CUDA driver for this dtype (the variant's
    /// explicit discriminant).
    pub fn as_u32(self) -> u32 {
        // Fieldless enum: `as` yields the discriminant written above.
        self as u32
    }
}
76
/// Interleave layout of the global tensor (`CUtensorMapInterleave`).
///
/// Interleaved layouts (NC/8HWC8-style) group a fixed number of bytes of
/// the innermost dimension together. Per the CUDA driver docs, any value
/// other than `None` requires rank >= 3, and `Bytes32` requires a
/// 32-byte-aligned global address.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorMapInterleave {
    /// No interleave — the layout is fully described by `global_strides`.
    None = 0,
    /// 16-byte interleave (e.g. four fp32 or eight fp16 values per chunk).
    Bytes16 = 1,
    /// 32-byte interleave (e.g. eight fp32 or sixteen fp16 values per chunk).
    Bytes32 = 2,
}

impl TensorMapInterleave {
    /// Numeric `CU_TENSOR_MAP_INTERLEAVE_*` code for this layout.
    pub fn as_u32(self) -> u32 {
        self as u32
    }
}
98
/// Shared-memory swizzle pattern (`CUtensorMapSwizzle`).
///
/// Swizzling rearranges 16-byte chunks within each swizzle span as the tile
/// lands in shared memory, so that threads reading down a column of the tile
/// hit different banks instead of conflicting. Per the CUDA driver docs, with
/// no interleave the inner box extent in bytes (`box_dim[0]` × element size)
/// must not exceed the swizzle span.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorMapSwizzle {
    /// No swizzle.
    None = 0,
    /// 32-byte swizzle span.
    Bytes32 = 1,
    /// 64-byte swizzle span.
    Bytes64 = 2,
    /// 128-byte swizzle span. The most common choice for wgmma operand tiles.
    Bytes128 = 3,
}

impl TensorMapSwizzle {
    /// Numeric `CU_TENSOR_MAP_SWIZZLE_*` code for this pattern.
    pub fn as_u32(self) -> u32 {
        self as u32
    }
}
122
/// L2 promotion hint (`CUtensorMapL2promotion`).
///
/// Selects the L2 fetch size — the byte granularity at which L2 requests
/// issued by the TMA are filled from DRAM.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorMapL2Promotion {
    /// No promotion.
    None = 0,
    /// Promote to 64-byte requests.
    Bytes64 = 1,
    /// Promote to 128-byte requests.
    Bytes128 = 2,
    /// Promote to 256-byte requests.
    Bytes256 = 3,
}

impl TensorMapL2Promotion {
    /// Numeric `CU_TENSOR_MAP_L2_PROMOTION_*` code for this hint.
    pub fn as_u32(self) -> u32 {
        self as u32
    }
}
143
/// Out-of-bounds fill mode for partial tiles (`CUtensorMapFloatOOBfill`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorMapOobFill {
    /// Out-of-bounds elements read as zero
    /// (`CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE`). Most common for matmul tail
    /// tiles.
    NaZero = 0,
    /// Out-of-bounds elements read as NaN
    /// (`CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA`). Per the CUDA
    /// docs, this is only valid for floating-point element types.
    NanRequest = 1,
}

impl TensorMapOobFill {
    /// Numeric `CU_TENSOR_MAP_FLOAT_OOB_FILL_*` code for this mode.
    pub fn as_u32(self) -> u32 {
        self as u32
    }
}
161
162/// Host-side builder for the tiled flavour of `cuTensorMapEncodeTiled`.
163///
164/// Cap at 5 dimensions (the CUDA-driver hard limit). Validation:
165///
166/// * `rank` must equal `global_dim.len()` and `box_dim.len()` and
167///   `element_strides.len()`; `global_strides.len() == rank - 1`.
168/// * `global_address` must be 16-byte aligned (TMA requires it).
169/// * Every entry of `global_dim` and `box_dim` must be non-zero.
170#[derive(Debug, Clone)]
171pub struct TensorMapDescriptor {
172    pub data_type: TensorMapDataType,
173    pub global_address: usize,
174    pub global_dim: Vec<u64>,
175    pub global_strides: Vec<u64>,
176    pub box_dim: Vec<u32>,
177    pub element_strides: Vec<u32>,
178    pub interleave: TensorMapInterleave,
179    pub swizzle: TensorMapSwizzle,
180    pub l2_promotion: TensorMapL2Promotion,
181    pub oob_fill: TensorMapOobFill,
182}
183
184impl TensorMapDescriptor {
185    /// Construct a defaulted-shape descriptor. Caller must populate
186    /// `global_dim`, `global_strides`, `box_dim`, `element_strides`
187    /// before calling [`TensorMapDescriptor::validate`].
188    pub fn new(data_type: TensorMapDataType, global_address: usize) -> Self {
189        Self {
190            data_type,
191            global_address,
192            global_dim: Vec::new(),
193            global_strides: Vec::new(),
194            box_dim: Vec::new(),
195            element_strides: Vec::new(),
196            interleave: TensorMapInterleave::None,
197            swizzle: TensorMapSwizzle::None,
198            l2_promotion: TensorMapL2Promotion::None,
199            oob_fill: TensorMapOobFill::NaZero,
200        }
201    }
202
203    /// Tensor rank (1..=5).
204    pub fn rank(&self) -> usize {
205        self.global_dim.len()
206    }
207
208    /// Validate that all sizes line up and the global address is 16B
209    /// aligned. Returns `Err(TmaEncodeError::*)` on mismatch. Pure host
210    /// validation — does not call into the driver.
211    pub fn validate(&self) -> Result<(), TmaEncodeError> {
212        let r = self.rank();
213        if r == 0 || r > 5 {
214            return Err(TmaEncodeError::BadRank(r));
215        }
216        if self.box_dim.len() != r {
217            return Err(TmaEncodeError::Mismatch {
218                what: "box_dim",
219                expected: r,
220                got: self.box_dim.len(),
221            });
222        }
223        if self.element_strides.len() != r {
224            return Err(TmaEncodeError::Mismatch {
225                what: "element_strides",
226                expected: r,
227                got: self.element_strides.len(),
228            });
229        }
230        // CUDA's API takes `rank - 1` global strides because the
231        // innermost stride is implicit (= sizeof(elem)).
232        if !self.global_strides.is_empty() && self.global_strides.len() != r - 1 {
233            return Err(TmaEncodeError::Mismatch {
234                what: "global_strides",
235                expected: r - 1,
236                got: self.global_strides.len(),
237            });
238        }
239        if self.global_address % 16 != 0 {
240            return Err(TmaEncodeError::UnalignedAddress(self.global_address));
241        }
242        if self.global_dim.contains(&0) {
243            return Err(TmaEncodeError::ZeroDim("global_dim"));
244        }
245        if self.box_dim.contains(&0) {
246            return Err(TmaEncodeError::ZeroDim("box_dim"));
247        }
248        Ok(())
249    }
250
251    /// Validate, then call `cuTensorMapEncodeTiled` to populate a
252    /// 128-byte `CUtensorMap`. Available only with the `hopper` cargo
253    /// feature (otherwise the call would have nothing to encode into).
254    ///
255    /// # Safety
256    ///
257    /// The caller must ensure `global_address` points at a
258    /// driver-mapped device allocation that lives at least as long as
259    /// the kernels that consume the resulting tensor map. The function
260    /// itself is safe to call (it does not dereference the address) but
261    /// kernels launched against the encoded map perform device reads.
262    #[cfg(feature = "hopper")]
263    pub fn encode(&self) -> Result<TensorMap, TmaEncodeError> {
264        use cudarc::driver::sys as cu;
265
266        self.validate()?;
267        let mut tm: cu::CUtensorMap = unsafe { std::mem::zeroed() };
268        // SAFETY: every pointer is into our own Vecs which live for
269        // the duration of this call. cuTensorMapEncodeTiled copies
270        // the values into the opaque map; no aliasing required.
271        let res = unsafe {
272            cu::cuTensorMapEncodeTiled(
273                &mut tm,
274                std::mem::transmute::<u32, cu::CUtensorMapDataType>(self.data_type.as_u32()),
275                self.rank() as cu::cuuint32_t,
276                self.global_address as *mut _,
277                self.global_dim.as_ptr(),
278                self.global_strides.as_ptr(),
279                self.box_dim.as_ptr(),
280                self.element_strides.as_ptr(),
281                std::mem::transmute::<u32, cu::CUtensorMapInterleave>(self.interleave.as_u32()),
282                std::mem::transmute::<u32, cu::CUtensorMapSwizzle>(self.swizzle.as_u32()),
283                std::mem::transmute::<u32, cu::CUtensorMapL2promotion>(self.l2_promotion.as_u32()),
284                std::mem::transmute::<u32, cu::CUtensorMapFloatOOBfill>(self.oob_fill.as_u32()),
285            )
286        };
287        if res != cu::CUresult::CUDA_SUCCESS {
288            return Err(TmaEncodeError::DriverError(res as i32));
289        }
290        Ok(TensorMap(tm))
291    }
292}
293
/// Encoded tensor map produced by [`TensorMapDescriptor::encode`].
///
/// Wraps the opaque 128-byte `CUtensorMap`. Pass it into NVRTC kernels as a
/// `const __grid_constant__ CUtensorMap` parameter, or use
/// [`TensorMap::as_ptr`] where the kernel signature wants a
/// `const CUtensorMap*`.
#[cfg(feature = "hopper")]
pub struct TensorMap(pub cudarc::driver::sys::CUtensorMap);

#[cfg(feature = "hopper")]
impl TensorMap {
    /// Pointer to the wrapped `CUtensorMap` (not a byte pointer).
    ///
    /// The pointer is only valid while this `TensorMap` is alive and has
    /// not been moved.
    pub fn as_ptr(&self) -> *const cudarc::driver::sys::CUtensorMap {
        std::ptr::addr_of!(self.0)
    }
}
307
/// Reasons a [`TensorMapDescriptor`] can fail to validate or encode.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TmaEncodeError {
    /// The rank (`global_dim.len()`) is 0 or greater than 5.
    BadRank(usize),
    /// A per-dimension vector has the wrong length.
    Mismatch {
        /// Name of the offending field.
        what: &'static str,
        /// Length the field should have.
        expected: usize,
        /// Length the field actually has.
        got: usize,
    },
    /// `global_address` is not 16-byte aligned.
    UnalignedAddress(usize),
    /// The named field contains a zero entry.
    ZeroDim(&'static str),
    /// Driver returned a non-zero `CUresult`.
    DriverError(i32),
}

impl fmt::Display for TmaEncodeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every payload is `Copy`, so match on the value and bind by copy.
        match *self {
            Self::BadRank(r) => write!(f, "TMA rank {r} out of [1,5]"),
            Self::Mismatch { what, expected, got } => {
                write!(f, "TMA descriptor: {what}.len() = {got}, expected {expected}")
            }
            Self::UnalignedAddress(a) => {
                write!(f, "TMA global_address 0x{a:x} is not 16-byte aligned")
            }
            Self::ZeroDim(field) => write!(f, "TMA descriptor: {field} contains a zero"),
            Self::DriverError(c) => write!(f, "cuTensorMapEncodeTiled returned {c}"),
        }
    }
}

impl std::error::Error for TmaEncodeError {}
346
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// A valid rank-2 fp16 descriptor: 1024x1024 elements, 64x64 box,
    /// 128B swizzle, 128B L2 promotion, zero OOB fill.
    fn sample_2d_descriptor() -> TensorMapDescriptor {
        // 0x1_0000 is 16B aligned.
        let mut desc = TensorMapDescriptor::new(TensorMapDataType::Float16, 0x1_0000);
        desc.global_dim = vec![1024, 1024];
        desc.global_strides = vec![1024 * 2]; // row stride in bytes for fp16
        desc.box_dim = vec![64, 64];
        desc.element_strides = vec![1, 1];
        desc.swizzle = TensorMapSwizzle::Bytes128;
        desc.l2_promotion = TensorMapL2Promotion::Bytes128;
        desc
    }

    /// Phase 5 test: round-trip a tiled-TMA descriptor through the
    /// builder + validate path. Host-side only — no GPU required.
    #[test]
    fn tensor_map_encode_descriptor_round_trip() {
        let good = sample_2d_descriptor();
        good.validate().expect("sample descriptor must validate");

        // Field round-trip.
        assert_eq!(good.rank(), 2);
        assert_eq!(good.data_type.as_u32(), TensorMapDataType::Float16.as_u32());
        assert_eq!(good.swizzle.as_u32(), TensorMapSwizzle::Bytes128.as_u32());
        assert_eq!(
            good.l2_promotion.as_u32(),
            TensorMapL2Promotion::Bytes128.as_u32()
        );
        assert_eq!(good.oob_fill.as_u32(), TensorMapOobFill::NaZero.as_u32());

        // Misaligned address — must reject.
        let misaligned = TensorMapDescriptor {
            global_address: 0x1_0001,
            ..good.clone()
        };
        assert!(matches!(
            misaligned.validate().unwrap_err(),
            TmaEncodeError::UnalignedAddress(_)
        ));

        // box_dim one entry too long — must reject.
        let mut long_box = good.clone();
        long_box.box_dim.push(32);
        assert!(matches!(
            long_box.validate().unwrap_err(),
            TmaEncodeError::Mismatch {
                what: "box_dim",
                ..
            }
        ));

        // Rank 0 — reject.
        let empty = TensorMapDescriptor::new(TensorMapDataType::Float32, 0x10);
        assert!(matches!(
            empty.validate().unwrap_err(),
            TmaEncodeError::BadRank(0)
        ));

        // Rank 6 — reject. Layout knobs come from `new`'s defaults.
        let six_d = TensorMapDescriptor {
            global_dim: vec![1; 6],
            global_strides: vec![4; 5],
            box_dim: vec![1; 6],
            element_strides: vec![1; 6],
            ..TensorMapDescriptor::new(TensorMapDataType::Float32, 0x10)
        };
        assert!(matches!(
            six_d.validate().unwrap_err(),
            TmaEncodeError::BadRank(6)
        ));
    }

    /// Every dtype maps through `as_u32` to a distinct value.
    #[test]
    fn enum_discriminants_are_unique() {
        use TensorMapDataType as D;
        let all_dtypes = [
            D::UInt8,
            D::UInt16,
            D::UInt32,
            D::Int32,
            D::UInt64,
            D::Int64,
            D::Float16,
            D::Float32,
            D::Float64,
            D::BFloat16,
            D::Float32Ftz,
            D::TFloat32,
            D::TFloat32Ftz,
            D::Float8E4m3,
            D::Float8E5m2,
            D::Float4E2m1,
            D::Float6E2m3,
            D::Float6E3m2,
        ];
        let mut seen = HashSet::new();
        for d in all_dtypes {
            assert!(
                seen.insert(d.as_u32()),
                "duplicate dtype discriminant for {d:?}"
            );
        }
    }
}