atomr_accel_cuda/hopper/
cluster.rs1use std::fmt;
21
/// Shape of a thread-block cluster, measured in blocks along each axis.
///
/// The all-ones shape ([`ClusterDim::unit`]) is what a [`LaunchSpec`] treats
/// as "no cluster requested".
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ClusterDim {
    /// Number of blocks in the cluster along X.
    pub x: u32,
    /// Number of blocks in the cluster along Y.
    pub y: u32,
    /// Number of blocks in the cluster along Z.
    pub z: u32,
}
30
31impl ClusterDim {
32 pub const fn new(x: u32, y: u32, z: u32) -> Self {
33 Self { x, y, z }
34 }
35
36 pub const fn unit() -> Self {
37 Self { x: 1, y: 1, z: 1 }
38 }
39
40 pub const fn block_count(self) -> u32 {
42 self.x * self.y * self.z
43 }
44
45 pub fn validate(self, allow_non_portable: bool) -> Result<(), ClusterError> {
49 if self.x == 0 || self.y == 0 || self.z == 0 {
50 return Err(ClusterError::ZeroDim);
51 }
52 let n = self.block_count();
53 if n > 8 && !allow_non_portable {
54 return Err(ClusterError::PortableLimit(n));
55 }
56 if n > 16 {
57 return Err(ClusterError::HardLimit(n));
58 }
59 Ok(())
60 }
61}
62
/// Why a cluster shape or cluster launch was rejected.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClusterError {
    /// One of the cluster's `x`, `y`, `z` extents is zero.
    ZeroDim,
    /// The cluster exceeds 8 blocks and non-portable sizes were not enabled.
    /// Holds the offending block count.
    PortableLimit(u32),
    /// The cluster exceeds 16 blocks. Holds the offending block count.
    HardLimit(u32),
    /// Status code reported by `cudaLaunchKernelExC` (as rendered by the
    /// `Display` impl); constructed outside this view.
    Driver(i32),
}
73
74impl fmt::Display for ClusterError {
75 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76 match self {
77 ClusterError::ZeroDim => write!(f, "cluster dim contains a zero"),
78 ClusterError::PortableLimit(n) => write!(
79 f,
80 "cluster size {n} > 8 (portable limit); set allow_non_portable=true to opt in"
81 ),
82 ClusterError::HardLimit(n) => write!(f, "cluster size {n} > 16 (hard limit)"),
83 ClusterError::Driver(c) => write!(f, "cudaLaunchKernelExC returned {c}"),
84 }
85 }
86}
87
88impl std::error::Error for ClusterError {}
89
90#[derive(Debug, Clone)]
96pub struct LaunchSpec {
97 pub grid_dim: (u32, u32, u32),
98 pub block_dim: (u32, u32, u32),
99 pub cluster_dim: ClusterDim,
100 pub shared_bytes: u32,
101 pub allow_non_portable_cluster: bool,
103}
104
105impl LaunchSpec {
106 pub fn new(grid_dim: (u32, u32, u32), block_dim: (u32, u32, u32)) -> Self {
109 Self {
110 grid_dim,
111 block_dim,
112 cluster_dim: ClusterDim::unit(),
113 shared_bytes: 0,
114 allow_non_portable_cluster: false,
115 }
116 }
117
118 pub fn with_cluster(mut self, cluster: ClusterDim) -> Result<Self, ClusterError> {
121 cluster.validate(self.allow_non_portable_cluster)?;
122 self.cluster_dim = cluster;
123 Ok(self)
124 }
125
126 pub fn allow_non_portable(mut self) -> Self {
129 self.allow_non_portable_cluster = true;
130 self
131 }
132
133 pub fn with_shared_bytes(mut self, bytes: u32) -> Self {
135 self.shared_bytes = bytes;
136 self
137 }
138
139 pub fn has_cluster(&self) -> bool {
142 self.cluster_dim != ClusterDim::unit()
143 }
144}
145
146pub const fn dsm_total_bytes(cluster: ClusterDim, per_block: u32) -> u64 {
149 (cluster.block_count() as u64) * (per_block as u64)
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// A 2x2x1 cluster fits the portable limit and is recorded on the spec.
    #[test]
    fn launch_spec_with_cluster_dim_constructs() -> Result<(), ClusterError> {
        const SMEM_BYTES: u32 = 48 * 1024;
        let spec = LaunchSpec::new((128, 1, 1), (256, 1, 1))
            .with_cluster(ClusterDim::new(2, 2, 1))?
            .with_shared_bytes(SMEM_BYTES);

        assert_eq!(spec.cluster_dim.block_count(), 4);
        assert!(spec.has_cluster());
        assert_eq!(spec.shared_bytes, SMEM_BYTES);
        assert_eq!(spec.cluster_dim.validate(false), Ok(()));
        Ok(())
    }

    /// Nine blocks need the non-portable opt-in.
    #[test]
    fn portable_limit_rejects_cluster_of_nine() {
        let nine = ClusterDim::new(3, 3, 1);
        assert_eq!(nine.validate(false), Err(ClusterError::PortableLimit(9)));
        assert_eq!(nine.validate(true), Ok(()));
    }

    /// Seventeen blocks are rejected even with the opt-in.
    #[test]
    fn hard_limit_rejects_cluster_of_seventeen() {
        let seventeen = ClusterDim::new(17, 1, 1);
        assert_eq!(seventeen.validate(true), Err(ClusterError::HardLimit(17)));
    }

    /// Any zero extent is invalid.
    #[test]
    fn zero_dim_rejected() {
        let degenerate = ClusterDim::new(0, 1, 1);
        assert_eq!(degenerate.validate(true), Err(ClusterError::ZeroDim));
    }

    /// Total bytes grow with the number of blocks in the cluster.
    #[test]
    fn dsm_total_bytes_scales_linearly() {
        const PER_BLOCK: u32 = 4096;
        assert_eq!(dsm_total_bytes(ClusterDim::new(2, 2, 2), PER_BLOCK), 8 * 4096);
        assert_eq!(dsm_total_bytes(ClusterDim::unit(), PER_BLOCK), 4096);
    }

    /// A freshly constructed spec has no cluster.
    #[test]
    fn unit_spec_has_no_cluster() {
        assert!(!LaunchSpec::new((1, 1, 1), (32, 1, 1)).has_cluster());
    }
}