atomr_accel_cuda/kernel/blas_lt/scaling.rs
1//! fp8 scale-pointer helpers for cuBLASLt matmul.
2//!
//! cuBLASLt's fp8 path scales the operands and the output by `f32`
//! factors, computing
//! `D = scale_d * (alpha * scale_a * scale_b * op(A) op(B) + beta * scale_c * C)`.
//! The scales are passed as **device pointers** stored on the
//! `cublasLtMatmulDesc_t` via the `A/B/C/D_SCALE_POINTER` attributes.
7//!
8//! [`ScaleSet`] bundles the four pointers and exposes
9//! [`ScaleSet::apply`] which writes them onto a descriptor. We keep
10//! the wrapper small — actual fp8 conversion (e4m3 / e5m2 packing)
11//! lives on the GPU in cuBLASLt itself.
12
13use std::ffi::c_void;
14use std::ptr;
15
16use cudarc::cublaslt::sys::{cublasLtMatmulDescAttributes_t, cublasLtMatmulDesc_t};
17
18use crate::sys::cublaslt::set_desc_pointer_attr;
19
/// Bundle of optional scale pointers for cuBLASLt fp8 matmul.
///
/// Each field corresponds to one `CUBLASLT_MATMUL_DESC_*_SCALE_POINTER`
/// attribute and is either:
/// - `None`: [`ScaleSet::apply`] skips the attribute entirely. On a
///   freshly created descriptor the attribute is unset and cuBLASLt
///   uses a scale of `1.0`. On a *reused* descriptor any pointer
///   written earlier stays in place.
/// - `Some(ptr)`: `ptr` is a device pointer to the `f32` scale value(s)
///   that cuBLASLt reads when the matmul runs.
///
/// NOTE(review): this type only sets the scale *pointers*. Per-row or
/// vector scaling additionally depends on the descriptor's scale-mode
/// configuration, which is not set here. Confirm that a caller
/// configures it before relying on more than one `f32` per scale.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ScaleSet {
    /// Scale for operand A (`CUBLASLT_MATMUL_DESC_A_SCALE_POINTER`).
    pub a: Option<*const f32>,
    /// Scale for operand B (`CUBLASLT_MATMUL_DESC_B_SCALE_POINTER`).
    pub b: Option<*const f32>,
    /// Scale for input C (`CUBLASLT_MATMUL_DESC_C_SCALE_POINTER`).
    pub c: Option<*const f32>,
    /// Scale for output D (`CUBLASLT_MATMUL_DESC_D_SCALE_POINTER`).
    pub d: Option<*const f32>,
}

// SAFETY: `ScaleSet` only stores pointer *values* and never dereferences
// them on the host, so moving or sharing those addresses across threads
// is harmless. Keeping the pointed-to device memory valid for as long
// as a matmul reads it is the obligation of whoever calls the `unsafe`
// `ScaleSet::apply`.
unsafe impl Send for ScaleSet {}
unsafe impl Sync for ScaleSet {}
39
40impl ScaleSet {
41 pub const fn empty() -> Self {
42 Self {
43 a: None,
44 b: None,
45 c: None,
46 d: None,
47 }
48 }
49
50 pub fn with_a(mut self, ptr: *const f32) -> Self {
51 self.a = Some(ptr);
52 self
53 }
54 pub fn with_b(mut self, ptr: *const f32) -> Self {
55 self.b = Some(ptr);
56 self
57 }
58 pub fn with_c(mut self, ptr: *const f32) -> Self {
59 self.c = Some(ptr);
60 self
61 }
62 pub fn with_d(mut self, ptr: *const f32) -> Self {
63 self.d = Some(ptr);
64 self
65 }
66
67 pub fn is_empty(&self) -> bool {
68 self.a.is_none() && self.b.is_none() && self.c.is_none() && self.d.is_none()
69 }
70
71 /// Write each Some(ptr) onto the descriptor. Returns the first
72 /// error encountered, if any.
73 ///
74 /// # Safety
75 ///
76 /// `desc` must be a live `cublasLtMatmulDesc_t`. The scale
77 /// pointers must remain valid for the entire lifetime of any
78 /// matmul call that uses `desc`.
79 pub unsafe fn apply(&self, desc: cublasLtMatmulDesc_t) -> Result<(), String> {
80 if let Some(p) = self.a {
81 unsafe {
82 set_desc_pointer_attr(
83 desc,
84 cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
85 p as *const c_void,
86 )?
87 };
88 }
89 if let Some(p) = self.b {
90 unsafe {
91 set_desc_pointer_attr(
92 desc,
93 cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
94 p as *const c_void,
95 )?
96 };
97 }
98 if let Some(p) = self.c {
99 unsafe {
100 set_desc_pointer_attr(
101 desc,
102 cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_C_SCALE_POINTER,
103 p as *const c_void,
104 )?
105 };
106 }
107 if let Some(p) = self.d {
108 unsafe {
109 set_desc_pointer_attr(
110 desc,
111 cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_D_SCALE_POINTER,
112 p as *const c_void,
113 )?
114 };
115 }
116 Ok(())
117 }
118}
119
/// Null `f32` pointer for occupying a scale slot without a real device
/// buffer.
///
/// Mostly useful for tests. `ScaleSet::empty().with_a(null_scale_ptr())`
/// counts as non-empty, and `ScaleSet::apply` would write this null
/// pointer onto the descriptor. A real fp8 path always supplies device
/// pointers minted by the calling DeviceActor.
///
/// This is a `const fn`, so it can initialise `const`/`static` items.
#[must_use]
pub const fn null_scale_ptr() -> *const f32 {
    ptr::null()
}
127
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn scale_set_empty_default() {
        let s = ScaleSet::default();
        assert!(s.is_empty());
        assert!(s.a.is_none());
    }

    #[test]
    fn scale_set_builders() {
        let a: f32 = 1.5;
        let s = ScaleSet::empty()
            .with_a(&a as *const f32)
            .with_d(&a as *const f32);
        assert!(!s.is_empty());
        assert!(s.a.is_some());
        assert!(s.b.is_none());
        assert!(s.c.is_none());
        assert!(s.d.is_some());
    }

    /// Checks that `ScaleSet` is a value type and that an explicit null
    /// pointer still occupies a slot.
    #[test]
    fn builders_return_copies_and_null_counts_as_present() {
        let base = ScaleSet::empty();
        let with_null = base.with_b(null_scale_ptr());
        // `ScaleSet` is `Copy`: the builder returns a new value and
        // leaves `base` untouched.
        assert!(base.is_empty());
        // `Some(null)` is still "present" as far as `is_empty` and
        // `apply` are concerned.
        assert!(!with_null.is_empty());
        assert_eq!(with_null.b, Some(null_scale_ptr()));
    }

    /// Pins the attribute enum values that `ScaleSet::apply` routes
    /// through, and checks that each builder records its pointer in the
    /// matching field.
    ///
    /// `apply` itself is not called, because the dynamic loader's no-GPU
    /// stub panics on `cublasLtMatmulDescSetAttribute`. As a result, the
    /// field-to-attribute routing inside `apply` is not exercised here.
    #[test]
    fn scale_pointer_attribute_setting() {
        use cudarc::cublaslt::sys::cublasLtMatmulDescAttributes_t as Attr;

        // The four attributes we touch must exist and have the numeric
        // values from the cuBLASLt headers cudarc binds (17–20).
        assert_eq!(Attr::CUBLASLT_MATMUL_DESC_A_SCALE_POINTER as u32, 17);
        assert_eq!(Attr::CUBLASLT_MATMUL_DESC_B_SCALE_POINTER as u32, 18);
        assert_eq!(Attr::CUBLASLT_MATMUL_DESC_C_SCALE_POINTER as u32, 19);
        assert_eq!(Attr::CUBLASLT_MATMUL_DESC_D_SCALE_POINTER as u32, 20);

        // Build a ScaleSet with all four scales and verify each pointer
        // lands in its own field. These fields are what `apply` reads.
        let a_scale: f32 = 1.0;
        let b_scale: f32 = 2.0;
        let c_scale: f32 = 3.0;
        let d_scale: f32 = 4.0;
        let s = ScaleSet::empty()
            .with_a(&a_scale as *const f32)
            .with_b(&b_scale as *const f32)
            .with_c(&c_scale as *const f32)
            .with_d(&d_scale as *const f32);
        assert_eq!(s.a, Some(&a_scale as *const f32));
        assert_eq!(s.b, Some(&b_scale as *const f32));
        assert_eq!(s.c, Some(&c_scale as *const f32));
        assert_eq!(s.d, Some(&d_scale as *const f32));

        // An empty ScaleSet has no slots for `apply` to write.
        let empty = ScaleSet::empty();
        assert!(empty.is_empty());
        // `apply` is deliberately not called: the no-GPU stub panics on
        // the first `cublasLtMatmulDescSetAttribute` call. Only the
        // enum values and field capture are verified here.
    }

    #[test]
    fn null_scale_ptr_is_null() {
        let p = null_scale_ptr();
        assert!(p.is_null());
    }
}
201}