atomr_accel_cuda/hopper/wgmma.rs
1//! WGMMA (warp-group matrix multiply accumulate) intrinsic shim.
2//!
3//! Hopper's `wgmma.mma_async.sync` instruction is issued from a
4//! 128-thread warpgroup; the host side has nothing to call, but NVRTC
5//! kernels embed the intrinsics through PTX inline assembly. This
6//! module ships a small set of macro shims (in `atomr_hopper.cuh`) that
7//! pin the asm constraints to the supported `(M, N, K, dtype-A,
8//! dtype-B, dtype-D)` shapes and give Rust callers symbolic names for
9//! the descriptors they have to build host-side.
10//!
//! Only the most common matmul variants are wrapped. Adding a new
//! variant means adding a new `ATOMR_WGMMA_*` macro in
//! `atomr_hopper.cuh` and a matching variant in [`WgmmaShape`].
14
/// Subset of WGMMA matmul tile shapes commonly exercised by attention /
/// matmul kernels.
///
/// Each variant is named after its `(M, N, K)` tuple (rows × columns ×
/// inner dimension). Variants with `K = 16` pair with the fp16 macros and
/// variants with `K = 32` pair with the fp8 macros.
///
/// `Hash` is derived alongside `Eq` so a shape can be used directly as a
/// `HashMap` / `HashSet` key (for example in a per-shape cache).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WgmmaShape {
    /// `m64n64k16` — most-common fp16 variant.
    M64N64K16,
    /// `m64n128k16` — wider tile, same fp16.
    M64N128K16,
    /// `m64n256k16` — full warpgroup output tile.
    M64N256K16,
    /// `m64n64k32` — fp8 (e4m3/e5m2) variant.
    M64N64K32,
    /// `m64n128k32` — fp8 wider.
    M64N128K32,
    /// `m64n256k32` — fp8 full.
    M64N256K32,
}
32
33impl WgmmaShape {
34 /// `(M, N, K)` decomposition.
35 pub fn dims(self) -> (u32, u32, u32) {
36 match self {
37 WgmmaShape::M64N64K16 => (64, 64, 16),
38 WgmmaShape::M64N128K16 => (64, 128, 16),
39 WgmmaShape::M64N256K16 => (64, 256, 16),
40 WgmmaShape::M64N64K32 => (64, 64, 32),
41 WgmmaShape::M64N128K32 => (64, 128, 32),
42 WgmmaShape::M64N256K32 => (64, 256, 32),
43 }
44 }
45
46 /// Macro name (matches `atomr_hopper.cuh`).
47 pub fn macro_name(self) -> &'static str {
48 match self {
49 WgmmaShape::M64N64K16 => "ATOMR_WGMMA_F16_M64N64K16",
50 WgmmaShape::M64N128K16 => "ATOMR_WGMMA_F16_M64N128K16",
51 WgmmaShape::M64N256K16 => "ATOMR_WGMMA_F16_M64N256K16",
52 WgmmaShape::M64N64K32 => "ATOMR_WGMMA_F8_M64N64K32",
53 WgmmaShape::M64N128K32 => "ATOMR_WGMMA_F8_M64N128K32",
54 WgmmaShape::M64N256K32 => "ATOMR_WGMMA_F8_M64N256K32",
55 }
56 }
57}
58
#[cfg(test)]
mod tests {
    use super::*;

    /// One row per variant: `(shape, expected dims, expected macro name)`.
    ///
    /// Listing every variant here means a wrong arm in either accessor is
    /// caught, not only the ones that happen to be spot-checked. The macro
    /// names are hand-copied expectations; the header file itself is not
    /// read by these tests.
    const CASES: [(WgmmaShape, (u32, u32, u32), &str); 6] = [
        (WgmmaShape::M64N64K16, (64, 64, 16), "ATOMR_WGMMA_F16_M64N64K16"),
        (WgmmaShape::M64N128K16, (64, 128, 16), "ATOMR_WGMMA_F16_M64N128K16"),
        (WgmmaShape::M64N256K16, (64, 256, 16), "ATOMR_WGMMA_F16_M64N256K16"),
        (WgmmaShape::M64N64K32, (64, 64, 32), "ATOMR_WGMMA_F8_M64N64K32"),
        (WgmmaShape::M64N128K32, (64, 128, 32), "ATOMR_WGMMA_F8_M64N128K32"),
        (WgmmaShape::M64N256K32, (64, 256, 32), "ATOMR_WGMMA_F8_M64N256K32"),
    ];

    /// `dims()` returns the expected `(M, N, K)` for every variant.
    #[test]
    fn dims_round_trip() {
        for &(shape, expected, _) in CASES.iter() {
            assert_eq!(shape.dims(), expected, "dims mismatch for {:?}", shape);
        }
    }

    /// `macro_name()` returns the expected macro identifier for every variant.
    #[test]
    fn macro_names_match_header() {
        for &(shape, _, expected) in CASES.iter() {
            assert_eq!(
                shape.macro_name(),
                expected,
                "macro name mismatch for {:?}",
                shape
            );
        }
    }

    /// The `M..N..K..` suffix of each macro name agrees with `dims()`, which
    /// guards against copy-paste drift between the two match tables.
    #[test]
    fn macro_name_suffix_encodes_dims() {
        for &(shape, _, _) in CASES.iter() {
            let (m, n, k) = shape.dims();
            let suffix = format!("_M{}N{}K{}", m, n, k);
            assert!(
                shape.macro_name().ends_with(&suffix),
                "{:?}: macro name {:?} does not end with {:?}",
                shape,
                shape.macro_name(),
                suffix
            );
        }
    }
}