Skip to main content

atomr_accel_cuda/kernel/blas_lt/
scaling.rs

1//! fp8 scale-pointer helpers for cuBLASLt matmul.
2//!
3//! cuBLASLt's fp8 path multiplies each operand by a per-tensor (or
4//! per-row) `f32` scale before accumulating. The scales are passed
5//! as **device pointers** stored on the `cublasLtMatmulDesc_t` via
6//! the `A/B/C/D_SCALE_POINTER` attributes.
7//!
8//! [`ScaleSet`] bundles the four pointers and exposes
9//! [`ScaleSet::apply`] which writes them onto a descriptor. We keep
10//! the wrapper small — actual fp8 conversion (e4m3 / e5m2 packing)
11//! lives on the GPU in cuBLASLt itself.
12
13use std::ffi::c_void;
14use std::ptr;
15
16use cudarc::cublaslt::sys::{cublasLtMatmulDescAttributes_t, cublasLtMatmulDesc_t};
17
18use crate::sys::cublaslt::set_desc_pointer_attr;
19
/// Optional scale pointers for a cuBLASLt fp8 matmul, one per operand slot.
///
/// Every field follows the same convention:
/// - `None` means the matching scale attribute is not written by
///   [`ScaleSet::apply`] (cuBLASLt then assumes a scale of `1.0`).
/// - `Some(ptr)` carries a device pointer to one or more `f32` scale
///   values: a single `f32` for a per-tensor scale, or `m` (or `n`)
///   `f32`s in row-major layout for a per-row scale.
///
/// The type is a plain `Copy` bundle of pointers; it neither owns nor
/// frees the memory the pointers refer to.
#[derive(Clone, Copy, Debug, Default)]
pub struct ScaleSet {
    /// Scale pointer for the `A` slot, if any.
    pub a: Option<*const f32>,
    /// Scale pointer for the `B` slot, if any.
    pub b: Option<*const f32>,
    /// Scale pointer for the `C` slot, if any.
    pub c: Option<*const f32>,
    /// Scale pointer for the `D` slot, if any.
    pub d: Option<*const f32>,
}

// SAFETY: the fields are device pointers whose pointees live on the GPU
// and are only ever read by cuBLASLt; this type never dereferences them
// on the host. The pointer values themselves are plain-old-data, so
// moving or sharing a `ScaleSet` across threads is sound.
unsafe impl Sync for ScaleSet {}
unsafe impl Send for ScaleSet {}
39
40impl ScaleSet {
41    pub const fn empty() -> Self {
42        Self {
43            a: None,
44            b: None,
45            c: None,
46            d: None,
47        }
48    }
49
50    pub fn with_a(mut self, ptr: *const f32) -> Self {
51        self.a = Some(ptr);
52        self
53    }
54    pub fn with_b(mut self, ptr: *const f32) -> Self {
55        self.b = Some(ptr);
56        self
57    }
58    pub fn with_c(mut self, ptr: *const f32) -> Self {
59        self.c = Some(ptr);
60        self
61    }
62    pub fn with_d(mut self, ptr: *const f32) -> Self {
63        self.d = Some(ptr);
64        self
65    }
66
67    pub fn is_empty(&self) -> bool {
68        self.a.is_none() && self.b.is_none() && self.c.is_none() && self.d.is_none()
69    }
70
71    /// Write each Some(ptr) onto the descriptor. Returns the first
72    /// error encountered, if any.
73    ///
74    /// # Safety
75    ///
76    /// `desc` must be a live `cublasLtMatmulDesc_t`. The scale
77    /// pointers must remain valid for the entire lifetime of any
78    /// matmul call that uses `desc`.
79    pub unsafe fn apply(&self, desc: cublasLtMatmulDesc_t) -> Result<(), String> {
80        if let Some(p) = self.a {
81            unsafe {
82                set_desc_pointer_attr(
83                    desc,
84                    cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
85                    p as *const c_void,
86                )?
87            };
88        }
89        if let Some(p) = self.b {
90            unsafe {
91                set_desc_pointer_attr(
92                    desc,
93                    cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
94                    p as *const c_void,
95                )?
96            };
97        }
98        if let Some(p) = self.c {
99            unsafe {
100                set_desc_pointer_attr(
101                    desc,
102                    cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_C_SCALE_POINTER,
103                    p as *const c_void,
104                )?
105            };
106        }
107        if let Some(p) = self.d {
108            unsafe {
109                set_desc_pointer_attr(
110                    desc,
111                    cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_D_SCALE_POINTER,
112                    p as *const c_void,
113                )?
114            };
115        }
116        Ok(())
117    }
118}
119
/// Returns a null `*const f32` to use as a placeholder scale pointer.
///
/// Intended as a best-effort sentinel for callers that want a scale
/// slot occupied without having a device buffer to point at. This is
/// mostly useful in tests; a real fp8 path is expected to supply
/// device pointers minted by the calling DeviceActor.
pub fn null_scale_ptr() -> *const f32 {
    ptr::null::<f32>()
}
127
#[cfg(test)]
mod tests {
    use super::*;

    /// `Default` must produce a set with no scale pointers.
    #[test]
    fn scale_set_empty_default() {
        let set = ScaleSet::default();
        assert!(set.a.is_none());
        assert!(set.is_empty());
    }

    /// Builder calls populate only the slots they name.
    #[test]
    fn scale_set_builders() {
        let value: f32 = 1.5;
        let host_ptr: *const f32 = &value;
        let set = ScaleSet::empty().with_a(host_ptr).with_d(host_ptr);

        assert!(!set.is_empty());
        assert!(set.a.is_some());
        assert!(set.d.is_some());
        assert!(set.b.is_none());
        assert!(set.c.is_none());
    }

    /// Checks the pieces `ScaleSet::apply` relies on without touching
    /// cuBLASLt: the numeric values of the four
    /// `CUBLASLT_MATMUL_DESC_*_SCALE_POINTER` attributes, and that the
    /// builders store each pointer in its own slot.
    ///
    /// `apply` itself is deliberately never called here: without a GPU
    /// the dynamic loader's stub is reported to panic on the first
    /// `cublasLtMatmulDescSetAttribute` call.
    #[test]
    fn scale_pointer_attribute_setting() {
        use cudarc::cublaslt::sys::cublasLtMatmulDescAttributes_t as Attr;

        // The four attributes must exist with the expected numeric
        // values (17–20 per cuBLASLt 12+).
        let expected: [(Attr, u32); 4] = [
            (Attr::CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, 17),
            (Attr::CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, 18),
            (Attr::CUBLASLT_MATMUL_DESC_C_SCALE_POINTER, 19),
            (Attr::CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, 20),
        ];
        for (attr, value) in expected {
            assert_eq!(attr as u32, value);
        }

        // Populate all four slots with distinct host addresses and
        // confirm each slot holds exactly the pointer it was given.
        let scales: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
        let pa: *const f32 = &scales[0];
        let pb: *const f32 = &scales[1];
        let pc: *const f32 = &scales[2];
        let pd: *const f32 = &scales[3];
        let full = ScaleSet::empty()
            .with_a(pa)
            .with_b(pb)
            .with_c(pc)
            .with_d(pd);
        assert_eq!(full.a, Some(pa));
        assert_eq!(full.b, Some(pb));
        assert_eq!(full.c, Some(pc));
        assert_eq!(full.d, Some(pd));

        // A set with no pointers gives `apply` nothing to write.
        let none = ScaleSet::empty();
        assert!(none.is_empty());
    }

    /// The sentinel helper must hand back a null pointer.
    #[test]
    fn null_scale_ptr_is_null() {
        assert!(null_scale_ptr().is_null());
    }
}