Skip to main content

atomr_accel_cuda/hopper/
wgmma.rs

//! WGMMA (warpgroup matrix multiply-accumulate) intrinsic shim.
//!
//! Hopper's `wgmma.mma_async.sync` instruction is issued from a
//! 128-thread warpgroup; the host side has nothing to call, but NVRTC
//! kernels embed the intrinsics through PTX inline assembly. This
//! module accompanies a small set of macro shims (in `atomr_hopper.cuh`)
//! that pin the asm constraints to the supported `(M, N, K, dtype-A,
//! dtype-B, dtype-D)` shapes, and it gives Rust callers symbolic names for
//! those shapes and for the descriptors they have to build host-side.
//!
//! Only the most common matmul variants are wrapped. Adding a new
//! variant means adding a new `ATOMR_WGMMA_*` macro to
//! `atomr_hopper.cuh` and a matching variant in [`WgmmaShape`]
//! (with arms in `WgmmaShape::dims` and `WgmmaShape::macro_name`).
14
/// Subset of WGMMA matmul shapes commonly exercised by attention /
/// matmul kernels. The numeric tuple is `(M, N, K)` (row × col × inner).
///
/// Each variant maps to exactly one shim macro in `atomr_hopper.cuh`
/// (see [`WgmmaShape::macro_name`]): the `K = 16` variants use the
/// `ATOMR_WGMMA_F16_*` shims and the `K = 32` variants use the
/// `ATOMR_WGMMA_F8_*` shims.
///
/// `Hash` is derived so shapes can key maps/sets (e.g. a per-shape
/// kernel cache).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WgmmaShape {
    /// `m64n64k16` — most-common fp16 variant.
    M64N64K16,
    /// `m64n128k16` — wider tile, same fp16.
    M64N128K16,
    /// `m64n256k16` — full warpgroup output tile.
    M64N256K16,
    /// `m64n64k32` — fp8 (e4m3/e5m2) variant.
    M64N64K32,
    /// `m64n128k32` — fp8, wider tile.
    M64N128K32,
    /// `m64n256k32` — fp8, full warpgroup output tile.
    M64N256K32,
}

impl WgmmaShape {
    /// `(M, N, K)` decomposition of the shape.
    ///
    /// This is a `const fn`, so tile sizes can be fixed at compile time, e.g.
    /// `const TILE: (u32, u32, u32) = WgmmaShape::M64N128K16.dims();`.
    #[must_use]
    pub const fn dims(self) -> (u32, u32, u32) {
        match self {
            WgmmaShape::M64N64K16 => (64, 64, 16),
            WgmmaShape::M64N128K16 => (64, 128, 16),
            WgmmaShape::M64N256K16 => (64, 256, 16),
            WgmmaShape::M64N64K32 => (64, 64, 32),
            WgmmaShape::M64N128K32 => (64, 128, 32),
            WgmmaShape::M64N256K32 => (64, 256, 32),
        }
    }

    /// Name of the shim macro in `atomr_hopper.cuh` that issues this shape.
    ///
    /// The string must match the header byte-for-byte, because callers
    /// splice it into NVRTC kernel source. This is a `const fn`, so it can
    /// seed `const` strings.
    #[must_use]
    pub const fn macro_name(self) -> &'static str {
        match self {
            WgmmaShape::M64N64K16 => "ATOMR_WGMMA_F16_M64N64K16",
            WgmmaShape::M64N128K16 => "ATOMR_WGMMA_F16_M64N128K16",
            WgmmaShape::M64N256K16 => "ATOMR_WGMMA_F16_M64N256K16",
            WgmmaShape::M64N64K32 => "ATOMR_WGMMA_F8_M64N64K32",
            WgmmaShape::M64N128K32 => "ATOMR_WGMMA_F8_M64N128K32",
            WgmmaShape::M64N256K32 => "ATOMR_WGMMA_F8_M64N256K32",
        }
    }
}
58
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant paired with its expected `(M, N, K)` and the shim
    /// macro name it must resolve to in `atomr_hopper.cuh`. Keep this
    /// table in sync with the enum when a variant is added.
    const CASES: [(WgmmaShape, (u32, u32, u32), &str); 6] = [
        (WgmmaShape::M64N64K16, (64, 64, 16), "ATOMR_WGMMA_F16_M64N64K16"),
        (WgmmaShape::M64N128K16, (64, 128, 16), "ATOMR_WGMMA_F16_M64N128K16"),
        (WgmmaShape::M64N256K16, (64, 256, 16), "ATOMR_WGMMA_F16_M64N256K16"),
        (WgmmaShape::M64N64K32, (64, 64, 32), "ATOMR_WGMMA_F8_M64N64K32"),
        (WgmmaShape::M64N128K32, (64, 128, 32), "ATOMR_WGMMA_F8_M64N128K32"),
        (WgmmaShape::M64N256K32, (64, 256, 32), "ATOMR_WGMMA_F8_M64N256K32"),
    ];

    #[test]
    fn dims_round_trip() {
        for &(shape, dims, _) in CASES.iter() {
            assert_eq!(shape.dims(), dims, "dims mismatch for {:?}", shape);
        }
    }

    #[test]
    fn macro_names_match_header() {
        for &(shape, _, name) in CASES.iter() {
            assert_eq!(shape.macro_name(), name, "macro mismatch for {:?}", shape);
        }
    }

    /// The macro name spells out the shape, so it must agree with `dims()`.
    /// This catches a copy-paste slip where one match is edited and the
    /// other is not.
    #[test]
    fn macro_name_encodes_dims() {
        for &(shape, _, _) in CASES.iter() {
            let (m, n, k) = shape.dims();
            // K = 16 shapes use the F16 shims, K = 32 shapes the F8 shims.
            let dtype = match k {
                16 => "F16",
                32 => "F8",
                other => panic!("no shim dtype prefix known for K = {} ({:?})", other, shape),
            };
            let expected = format!("ATOMR_WGMMA_{}_M{}N{}K{}", dtype, m, n, k);
            assert_eq!(shape.macro_name(), expected, "{:?}", shape);
        }
    }
}