Skip to main content

atomr_accel_cuda/hopper/
cluster.rs

1//! Thread-block cluster launches + Distributed Shared Memory (DSM)
2//! helpers.
3//!
4//! Hopper introduced a fourth launch dimension: a *cluster* of thread
5//! blocks. Blocks within a cluster can synchronise via `cluster.sync`
6//! and read each other's shared memory through the DSM unit. The host
7//! has to launch with `cudaLaunchKernelExC` (the older
8//! `cudaLaunchKernel` lacks the cluster-dim field).
9//!
10//! This module ships:
11//!
//! * [`ClusterDim`] — an `(x, y, z)` cluster size, validated against
//!   the 8-block portable limit (Hopper) / 16-block limit (Blackwell,
//!   with the `cudaFuncAttributeNonPortableClusterSizeAllowed` opt-in).
15//! * [`LaunchSpec`] — grid + block + cluster + shared-memory bytes +
16//!   stream, plus optional non-portable opt-in.
17//! * [`launch_with_cluster`] (gated on `hopper`) — safe wrapper around
18//!   `cudaLaunchKernelExC`.
19
20use std::fmt;
21
/// Thread-block cluster shape, expressed as a block count along each of
/// three axes.
///
/// Construction never fails; call [`ClusterDim::validate`] to check the
/// shape against the 8-block portable cap and the 16-block hard cap.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ClusterDim {
    /// Extent along the first axis.
    pub x: u32,
    /// Extent along the second axis.
    pub y: u32,
    /// Extent along the third axis.
    pub z: u32,
}

impl ClusterDim {
    /// Largest block count accepted without the non-portable opt-in.
    const PORTABLE_MAX_BLOCKS: u32 = 8;
    /// Largest block count accepted even with the non-portable opt-in.
    const HARD_MAX_BLOCKS: u32 = 16;

    /// Builds a cluster shape from its three extents. No validation is
    /// performed here.
    pub const fn new(x: u32, y: u32, z: u32) -> Self {
        Self { x, y, z }
    }

    /// The trivial 1×1×1 cluster.
    pub const fn unit() -> Self {
        Self { x: 1, y: 1, z: 1 }
    }

    /// Block count = x * y * z, saturating at `u32::MAX`.
    ///
    /// A plain `x * y * z` panics in debug builds and wraps in release
    /// builds once the product exceeds `u32::MAX`; a wrapped product can
    /// look like a small, valid cluster. Saturating keeps the result
    /// "too large" so limit checks still reject it.
    pub const fn block_count(self) -> u32 {
        self.x.saturating_mul(self.y).saturating_mul(self.z)
    }

    /// Checks the shape against the cluster-size limits.
    ///
    /// # Errors
    ///
    /// * [`ClusterError::ZeroDim`] if any extent is zero.
    /// * [`ClusterError::PortableLimit`] if the block count exceeds 8
    ///   and `allow_non_portable` is `false`.
    /// * [`ClusterError::HardLimit`] if the block count exceeds 16.
    ///
    /// The carried count is the saturated [`ClusterDim::block_count`], so
    /// shapes whose true product does not fit in `u32` report `u32::MAX`.
    pub fn validate(self, allow_non_portable: bool) -> Result<(), ClusterError> {
        if self.x == 0 || self.y == 0 || self.z == 0 {
            return Err(ClusterError::ZeroDim);
        }
        // Order matters: without the opt-in, anything above the portable
        // cap is reported as a portable-limit violation, even if it also
        // exceeds the hard cap.
        match self.block_count() {
            n if n > Self::PORTABLE_MAX_BLOCKS && !allow_non_portable => {
                Err(ClusterError::PortableLimit(n))
            }
            n if n > Self::HARD_MAX_BLOCKS => Err(ClusterError::HardLimit(n)),
            _ => Ok(()),
        }
    }
}

/// Failures reported by cluster validation and cluster launches.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClusterError {
    /// At least one cluster extent is zero.
    ZeroDim,
    /// Cluster exceeds the 8-block portable cap; carries the block count.
    PortableLimit(u32),
    /// Cluster exceeds the 16-block hard cap; carries the block count.
    HardLimit(u32),
    /// Raw non-zero `cudaError_t` code from a failed launch.
    Driver(i32),
}

impl fmt::Display for ClusterError {
    /// Renders a human-readable, single-line description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::ZeroDim => f.write_str("cluster dim contains a zero"),
            Self::PortableLimit(n) => write!(
                f,
                "cluster size {n} > 8 (portable limit); set allow_non_portable=true to opt in"
            ),
            Self::HardLimit(n) => write!(f, "cluster size {n} > 16 (hard limit)"),
            Self::Driver(code) => write!(f, "cudaLaunchKernelExC returned {code}"),
        }
    }
}

impl std::error::Error for ClusterError {}
89
90/// Full launch specification for a cluster-aware kernel.
91///
92/// Mirrors `cudaLaunchKernelExC`'s `cudaLaunchConfig_t`:
93/// `(gridDim, blockDim, sharedBytes, stream)` with the cluster
94/// dimension threaded through the attributes array.
95#[derive(Debug, Clone)]
96pub struct LaunchSpec {
97    pub grid_dim: (u32, u32, u32),
98    pub block_dim: (u32, u32, u32),
99    pub cluster_dim: ClusterDim,
100    pub shared_bytes: u32,
101    /// Opt-in to clusters > 8 blocks (Blackwell only).
102    pub allow_non_portable_cluster: bool,
103}
104
105impl LaunchSpec {
106    /// Construct a spec with no cluster (1×1×1) — equivalent to the
107    /// classic 3-tuple launch surface.
108    pub fn new(grid_dim: (u32, u32, u32), block_dim: (u32, u32, u32)) -> Self {
109        Self {
110            grid_dim,
111            block_dim,
112            cluster_dim: ClusterDim::unit(),
113            shared_bytes: 0,
114            allow_non_portable_cluster: false,
115        }
116    }
117
118    /// Builder: set the cluster dimension. Validates against the
119    /// 8-block portable cap on construction.
120    pub fn with_cluster(mut self, cluster: ClusterDim) -> Result<Self, ClusterError> {
121        cluster.validate(self.allow_non_portable_cluster)?;
122        self.cluster_dim = cluster;
123        Ok(self)
124    }
125
126    /// Builder: opt into non-portable cluster sizes (>8 blocks). Caller
127    /// must re-validate the cluster afterwards.
128    pub fn allow_non_portable(mut self) -> Self {
129        self.allow_non_portable_cluster = true;
130        self
131    }
132
133    /// Builder: set dynamic shared-memory bytes.
134    pub fn with_shared_bytes(mut self, bytes: u32) -> Self {
135        self.shared_bytes = bytes;
136        self
137    }
138
139    /// True if this spec has a non-trivial cluster dim (anything other
140    /// than 1×1×1).
141    pub fn has_cluster(&self) -> bool {
142        self.cluster_dim != ClusterDim::unit()
143    }
144}
145
146/// Distributed-shared-memory helper: byte count needed to allocate
147/// `per_block` bytes in every block of a cluster of size `cluster`.
148pub const fn dsm_total_bytes(cluster: ClusterDim, per_block: u32) -> u64 {
149    (cluster.block_count() as u64) * (per_block as u64)
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a spec with a 2×2×1 cluster through the builder chain and
    /// checks every field the chain touched. Host-only; no GPU needed.
    #[test]
    fn launch_spec_with_cluster_dim_constructs() {
        let shared = 48 * 1024;
        let spec = LaunchSpec::new((128, 1, 1), (256, 1, 1))
            .with_cluster(ClusterDim::new(2, 2, 1))
            .expect("a 4-block cluster is within the portable limit")
            .with_shared_bytes(shared);

        assert!(spec.has_cluster());
        assert_eq!(spec.cluster_dim.block_count(), 4);
        assert_eq!(spec.shared_bytes, shared);
        // The stored cluster still validates without the opt-in.
        assert_eq!(spec.cluster_dim.validate(false), Ok(()));
    }

    /// Nine blocks fail the portable check but pass with the opt-in.
    #[test]
    fn portable_limit_rejects_cluster_of_nine() {
        let nine = ClusterDim::new(3, 3, 1);
        assert_eq!(nine.validate(false), Err(ClusterError::PortableLimit(9)));
        assert_eq!(nine.validate(true), Ok(()));
    }

    /// Seventeen blocks exceed the hard cap even with the opt-in.
    #[test]
    fn hard_limit_rejects_cluster_of_seventeen() {
        let seventeen = ClusterDim::new(17, 1, 1);
        assert_eq!(seventeen.validate(true), Err(ClusterError::HardLimit(17)));
    }

    /// A zero extent is rejected regardless of the opt-in.
    #[test]
    fn zero_dim_rejected() {
        let degenerate = ClusterDim::new(0, 1, 1);
        assert_eq!(degenerate.validate(true), Err(ClusterError::ZeroDim));
    }

    /// Total DSM bytes are the block count times the per-block size.
    #[test]
    fn dsm_total_bytes_scales_linearly() {
        let per_block = 4096;
        let eight_blocks = ClusterDim::new(2, 2, 2);
        assert_eq!(dsm_total_bytes(eight_blocks, per_block), 8 * 4096);
        assert_eq!(dsm_total_bytes(ClusterDim::unit(), per_block), 4096);
    }

    /// A freshly constructed spec carries the unit cluster.
    #[test]
    fn unit_spec_has_no_cluster() {
        let plain = LaunchSpec::new((1, 1, 1), (32, 1, 1));
        assert!(!plain.has_cluster());
    }
}