1use std::fmt;
18
/// Element data type tag carried by a `TensorMapDescriptor`.
///
/// Every variant has a fixed numeric code (its discriminant), returned by
/// [`TensorMapDataType::as_u32`]. `TensorMapDescriptor::encode` reinterprets that
/// code as the driver's `CUtensorMapDataType`.
///
/// NOTE(review): codes `13..=17` (the FP8/FP6/FP4 variants) are assigned by this
/// file; confirm that they correspond to enumerators of the driver's
/// `CUtensorMapDataType` in the CUDA/cudarc version being linked.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorMapDataType {
    UInt8 = 0,
    UInt16 = 1,
    UInt32 = 2,
    Int32 = 3,
    UInt64 = 4,
    Int64 = 5,
    Float16 = 6,
    Float32 = 7,
    Float64 = 8,
    BFloat16 = 9,
    Float32Ftz = 10,
    TFloat32 = 11,
    TFloat32Ftz = 12,
    Float8E4m3 = 13,
    Float8E5m2 = 14,
    Float4E2m1 = 15,
    Float6E2m3 = 16,
    Float6E3m2 = 17,
}

impl TensorMapDataType {
    /// Returns the numeric code of this data type.
    pub fn as_u32(self) -> u32 {
        // The explicit discriminants above are the codes, so a cast is enough.
        self as u32
    }
}
76
/// Interleave setting of a `TensorMapDescriptor`.
///
/// `TensorMapDescriptor::encode` reinterprets the code returned by
/// [`TensorMapInterleave::as_u32`] as the driver's `CUtensorMapInterleave`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorMapInterleave {
    None = 0,
    Bytes16 = 1,
    Bytes32 = 2,
}

impl TensorMapInterleave {
    /// Returns the numeric code of this interleave setting.
    pub fn as_u32(self) -> u32 {
        // The explicit discriminants above are the codes.
        self as u32
    }
}
98
/// Swizzle setting of a `TensorMapDescriptor`.
///
/// `TensorMapDescriptor::encode` reinterprets the code returned by
/// [`TensorMapSwizzle::as_u32`] as the driver's `CUtensorMapSwizzle`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorMapSwizzle {
    None = 0,
    Bytes32 = 1,
    Bytes64 = 2,
    Bytes128 = 3,
}

impl TensorMapSwizzle {
    /// Returns the numeric code of this swizzle setting.
    pub fn as_u32(self) -> u32 {
        // The explicit discriminants above are the codes.
        self as u32
    }
}
122
/// L2 promotion setting of a `TensorMapDescriptor`.
///
/// `TensorMapDescriptor::encode` reinterprets the code returned by
/// [`TensorMapL2Promotion::as_u32`] as the driver's `CUtensorMapL2promotion`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorMapL2Promotion {
    None = 0,
    Bytes64 = 1,
    Bytes128 = 2,
    Bytes256 = 3,
}

impl TensorMapL2Promotion {
    /// Returns the numeric code of this L2 promotion setting.
    pub fn as_u32(self) -> u32 {
        // The explicit discriminants above are the codes.
        self as u32
    }
}
143
/// Out-of-bounds fill setting of a `TensorMapDescriptor`.
///
/// `TensorMapDescriptor::encode` reinterprets the code returned by
/// [`TensorMapOobFill::as_u32`] as the driver's `CUtensorMapFloatOOBfill`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorMapOobFill {
    NaZero = 0,
    NanRequest = 1,
}

impl TensorMapOobFill {
    /// Returns the numeric code of this fill setting.
    pub fn as_u32(self) -> u32 {
        // The explicit discriminants above are the codes.
        self as u32
    }
}
161
162#[derive(Debug, Clone)]
171pub struct TensorMapDescriptor {
172 pub data_type: TensorMapDataType,
173 pub global_address: usize,
174 pub global_dim: Vec<u64>,
175 pub global_strides: Vec<u64>,
176 pub box_dim: Vec<u32>,
177 pub element_strides: Vec<u32>,
178 pub interleave: TensorMapInterleave,
179 pub swizzle: TensorMapSwizzle,
180 pub l2_promotion: TensorMapL2Promotion,
181 pub oob_fill: TensorMapOobFill,
182}
183
184impl TensorMapDescriptor {
185 pub fn new(data_type: TensorMapDataType, global_address: usize) -> Self {
189 Self {
190 data_type,
191 global_address,
192 global_dim: Vec::new(),
193 global_strides: Vec::new(),
194 box_dim: Vec::new(),
195 element_strides: Vec::new(),
196 interleave: TensorMapInterleave::None,
197 swizzle: TensorMapSwizzle::None,
198 l2_promotion: TensorMapL2Promotion::None,
199 oob_fill: TensorMapOobFill::NaZero,
200 }
201 }
202
203 pub fn rank(&self) -> usize {
205 self.global_dim.len()
206 }
207
208 pub fn validate(&self) -> Result<(), TmaEncodeError> {
212 let r = self.rank();
213 if r == 0 || r > 5 {
214 return Err(TmaEncodeError::BadRank(r));
215 }
216 if self.box_dim.len() != r {
217 return Err(TmaEncodeError::Mismatch {
218 what: "box_dim",
219 expected: r,
220 got: self.box_dim.len(),
221 });
222 }
223 if self.element_strides.len() != r {
224 return Err(TmaEncodeError::Mismatch {
225 what: "element_strides",
226 expected: r,
227 got: self.element_strides.len(),
228 });
229 }
230 if !self.global_strides.is_empty() && self.global_strides.len() != r - 1 {
233 return Err(TmaEncodeError::Mismatch {
234 what: "global_strides",
235 expected: r - 1,
236 got: self.global_strides.len(),
237 });
238 }
239 if self.global_address % 16 != 0 {
240 return Err(TmaEncodeError::UnalignedAddress(self.global_address));
241 }
242 if self.global_dim.contains(&0) {
243 return Err(TmaEncodeError::ZeroDim("global_dim"));
244 }
245 if self.box_dim.contains(&0) {
246 return Err(TmaEncodeError::ZeroDim("box_dim"));
247 }
248 Ok(())
249 }
250
251 #[cfg(feature = "hopper")]
263 pub fn encode(&self) -> Result<TensorMap, TmaEncodeError> {
264 use cudarc::driver::sys as cu;
265
266 self.validate()?;
267 let mut tm: cu::CUtensorMap = unsafe { std::mem::zeroed() };
268 let res = unsafe {
272 cu::cuTensorMapEncodeTiled(
273 &mut tm,
274 std::mem::transmute::<u32, cu::CUtensorMapDataType>(self.data_type.as_u32()),
275 self.rank() as cu::cuuint32_t,
276 self.global_address as *mut _,
277 self.global_dim.as_ptr(),
278 self.global_strides.as_ptr(),
279 self.box_dim.as_ptr(),
280 self.element_strides.as_ptr(),
281 std::mem::transmute::<u32, cu::CUtensorMapInterleave>(self.interleave.as_u32()),
282 std::mem::transmute::<u32, cu::CUtensorMapSwizzle>(self.swizzle.as_u32()),
283 std::mem::transmute::<u32, cu::CUtensorMapL2promotion>(self.l2_promotion.as_u32()),
284 std::mem::transmute::<u32, cu::CUtensorMapFloatOOBfill>(self.oob_fill.as_u32()),
285 )
286 };
287 if res != cu::CUresult::CUDA_SUCCESS {
288 return Err(TmaEncodeError::DriverError(res as i32));
289 }
290 Ok(TensorMap(tm))
291 }
292}
293
/// Wrapper around the driver's `CUtensorMap`, as returned by
/// `TensorMapDescriptor::encode`.
#[cfg(feature = "hopper")]
pub struct TensorMap(pub cudarc::driver::sys::CUtensorMap);

#[cfg(feature = "hopper")]
impl TensorMap {
    /// Returns a raw pointer to the wrapped `CUtensorMap`.
    ///
    /// The pointer is derived from a borrow of `self`, so it must not be used
    /// after this `TensorMap` has been moved or dropped.
    pub fn as_ptr(&self) -> *const cudarc::driver::sys::CUtensorMap {
        let Self(raw) = self;
        raw
    }
}
307
/// Errors reported while validating or encoding a `TensorMapDescriptor`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TmaEncodeError {
    /// The rank (payload) is outside `1..=5`.
    BadRank(usize),
    /// The vector named `what` holds `got` entries instead of `expected`.
    Mismatch {
        what: &'static str,
        expected: usize,
        got: usize,
    },
    /// The global address (payload) is not a multiple of 16.
    UnalignedAddress(usize),
    /// The vector named by the payload contains a zero entry.
    ZeroDim(&'static str),
    /// `cuTensorMapEncodeTiled` failed; the payload is its result code.
    DriverError(i32),
}

impl fmt::Display for TmaEncodeError {
    /// Writes a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::DriverError(c) => write!(f, "cuTensorMapEncodeTiled returned {c}"),
            Self::ZeroDim(field) => write!(f, "TMA descriptor: {field} contains a zero"),
            Self::UnalignedAddress(a) => {
                write!(f, "TMA global_address 0x{a:x} is not 16-byte aligned")
            }
            Self::Mismatch {
                what,
                expected,
                got,
            } => write!(
                f,
                "TMA descriptor: {what}.len() = {got}, expected {expected}"
            ),
            Self::BadRank(r) => write!(f, "TMA rank {r} out of [1,5]"),
        }
    }
}

impl std::error::Error for TmaEncodeError {}
346
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a valid rank-2 `Float16` descriptor used as the baseline below.
    fn sample_2d_descriptor() -> TensorMapDescriptor {
        TensorMapDescriptor {
            data_type: TensorMapDataType::Float16,
            global_address: 0x1_0000,
            global_dim: vec![1024, 1024],
            global_strides: vec![1024 * 2],
            box_dim: vec![64, 64],
            element_strides: vec![1, 1],
            interleave: TensorMapInterleave::None,
            swizzle: TensorMapSwizzle::Bytes128,
            l2_promotion: TensorMapL2Promotion::Bytes128,
            oob_fill: TensorMapOobFill::NaZero,
        }
    }

    /// Exercises the host-side descriptor checks; no driver call is made.
    #[test]
    fn tensor_map_encode_descriptor_round_trip() {
        let d = sample_2d_descriptor();
        d.validate().expect("sample descriptor must validate");

        assert_eq!(d.rank(), 2);
        // Pin the literal codes: comparing a variant's code with itself could
        // never fail, so it would not catch a changed mapping.
        assert_eq!(d.data_type.as_u32(), 6);
        assert_eq!(d.interleave.as_u32(), 0);
        assert_eq!(d.swizzle.as_u32(), 3);
        assert_eq!(d.l2_promotion.as_u32(), 2);
        assert_eq!(d.oob_fill.as_u32(), 0);

        // Misaligned base address.
        let bad = TensorMapDescriptor {
            global_address: 0x1_0001,
            ..d.clone()
        };
        assert_eq!(
            bad.validate(),
            Err(TmaEncodeError::UnalignedAddress(0x1_0001))
        );

        // One box extent too many for a rank-2 tensor.
        let mut bad = d.clone();
        bad.box_dim.push(32);
        assert_eq!(
            bad.validate(),
            Err(TmaEncodeError::Mismatch {
                what: "box_dim",
                expected: 2,
                got: 3,
            })
        );

        // A freshly constructed descriptor has rank 0.
        let bad = TensorMapDescriptor::new(TensorMapDataType::Float32, 0x10);
        assert_eq!(bad.validate(), Err(TmaEncodeError::BadRank(0)));

        // Rank 6 is above the supported maximum even when all lengths agree.
        let bad = TensorMapDescriptor {
            global_dim: vec![1; 6],
            global_strides: vec![4; 5],
            box_dim: vec![1; 6],
            element_strides: vec![1; 6],
            ..TensorMapDescriptor::new(TensorMapDataType::Float32, 0x10)
        };
        assert_eq!(bad.validate(), Err(TmaEncodeError::BadRank(6)));
    }

    /// No two data types may share a numeric code.
    #[test]
    fn enum_discriminants_are_unique() {
        let dts = [
            TensorMapDataType::UInt8,
            TensorMapDataType::UInt16,
            TensorMapDataType::UInt32,
            TensorMapDataType::Int32,
            TensorMapDataType::UInt64,
            TensorMapDataType::Int64,
            TensorMapDataType::Float16,
            TensorMapDataType::Float32,
            TensorMapDataType::Float64,
            TensorMapDataType::BFloat16,
            TensorMapDataType::Float32Ftz,
            TensorMapDataType::TFloat32,
            TensorMapDataType::TFloat32Ftz,
            TensorMapDataType::Float8E4m3,
            TensorMapDataType::Float8E5m2,
            TensorMapDataType::Float4E2m1,
            TensorMapDataType::Float6E2m3,
            TensorMapDataType::Float6E3m2,
        ];
        let mut seen = std::collections::HashSet::new();
        for d in dts {
            assert!(
                seen.insert(d.as_u32()),
                "duplicate dtype discriminant for {d:?}"
            );
        }
        assert_eq!(seen.len(), dts.len());
    }
}