Skip to main content

atomr_accel_cuda/sys/
cublaslt.rs

1//! Thin local wrappers over `cudarc::cublaslt::sys` for the entry
2//! points cudarc 0.19.4's safe layer doesn't expose. Each helper is
3//! `unsafe fn` if it dereferences raw handles; the safe shape is just
4//! "create / set-attribute / destroy" RAII over the opaque handle.
5//!
6//! Used by:
7//! - [`crate::kernel::blas_lt::heuristic`] for
8//!   `cublasLtMatmulPreferenceCreate` + `cublasLtMatmulAlgoGetHeuristic`,
9//! - [`crate::kernel::blas_lt::scaling`] for the per-tensor fp8 scale
10//!   pointer descriptor attributes,
11//! - [`crate::kernel::blas_lt::epilogue`] for setting the
12//!   `CUBLASLT_MATMUL_DESC_EPILOGUE` attribute.
13
14use std::ffi::c_void;
15use std::ptr;
16
17#[cfg(test)]
18use cudarc::cublaslt::sys::cublasLtMatmulDescOpaque_t;
19use cudarc::cublaslt::sys::{
20    cublasLtMatmulDescAttributes_t, cublasLtMatmulDesc_t, cublasLtMatmulPreferenceAttributes_t,
21    cublasLtMatmulPreferenceOpaque_t, cublasLtMatmulPreference_t, cublasStatus_t,
22};
23
24/// Map a `cublasStatus_t` into an `Err(String)` for the cuBLASLt
25/// status codes we care about. We don't go through `CublasError` here
26/// because cudarc's `result` module doesn't expose every variant we
27/// touch — the string form is sufficient since these errors funnel
28/// straight into `GpuError::LibraryError`.
29pub fn check(status: cublasStatus_t, op: &str) -> Result<(), String> {
30    match status {
31        cublasStatus_t::CUBLAS_STATUS_SUCCESS => Ok(()),
32        other => Err(format!("{op}: {other:?}")),
33    }
34}
35
36/// RAII handle for a `cublasLtMatmulPreference_t`. Created via
37/// [`Preference::new`], destroyed on Drop.
38pub struct Preference {
39    pub raw: cublasLtMatmulPreference_t,
40}
41
42// SAFETY: the underlying preference object is plain CPU-side state
43// that cuBLASLt synchronizes internally. We only ever touch it through
44// the cuBLASLt API.
45unsafe impl Send for Preference {}
46unsafe impl Sync for Preference {}
47
48impl Preference {
49    pub fn new() -> Result<Self, String> {
50        let mut raw: cublasLtMatmulPreference_t =
51            ptr::null_mut::<cublasLtMatmulPreferenceOpaque_t>();
52        let status = unsafe { cudarc::cublaslt::sys::cublasLtMatmulPreferenceCreate(&mut raw) };
53        check(status, "cublasLtMatmulPreferenceCreate")?;
54        Ok(Self { raw })
55    }
56
57    /// Set a u64-valued preference attribute (the most common case —
58    /// `MAX_WORKSPACE_BYTES`, `IMPL_MASK`, `REDUCTION_SCHEME_MASK` all
59    /// take a u64).
60    pub fn set_u64(
61        &self,
62        attr: cublasLtMatmulPreferenceAttributes_t,
63        value: u64,
64    ) -> Result<(), String> {
65        let status = unsafe {
66            cudarc::cublaslt::sys::cublasLtMatmulPreferenceSetAttribute(
67                self.raw,
68                attr,
69                &value as *const u64 as *const c_void,
70                std::mem::size_of::<u64>(),
71            )
72        };
73        check(status, "cublasLtMatmulPreferenceSetAttribute")
74    }
75}
76
77impl Drop for Preference {
78    fn drop(&mut self) {
79        if !self.raw.is_null() {
80            unsafe {
81                let _ = cudarc::cublaslt::sys::cublasLtMatmulPreferenceDestroy(self.raw);
82            }
83            self.raw = ptr::null_mut::<cublasLtMatmulPreferenceOpaque_t>();
84        }
85    }
86}
87
88/// Set a pointer-valued attribute on a matmul descriptor (used for
89/// `BIAS_POINTER`, `EPILOGUE_AUX_POINTER`, fp8 scale pointers).
90///
91/// # Safety
92///
93/// `desc` must be a live `cublasLtMatmulDesc_t`; `ptr` must remain
94/// valid for the entire lifetime of any matmul call that uses `desc`.
95pub unsafe fn set_desc_pointer_attr(
96    desc: cublasLtMatmulDesc_t,
97    attr: cublasLtMatmulDescAttributes_t,
98    ptr: *const c_void,
99) -> Result<(), String> {
100    let status = unsafe {
101        cudarc::cublaslt::sys::cublasLtMatmulDescSetAttribute(
102            desc,
103            attr,
104            &ptr as *const *const c_void as *const c_void,
105            std::mem::size_of::<*const c_void>(),
106        )
107    };
108    check(status, "cublasLtMatmulDescSetAttribute(pointer)")
109}
110
111/// Set an i32-valued attribute on a matmul descriptor (used for
112/// `EPILOGUE`, the cudarc-bindgen `cublasLtEpilogue_t` is repr(u32)
113/// but the cuBLASLt API expects `int`).
114///
115/// # Safety
116///
117/// `desc` must be a live `cublasLtMatmulDesc_t`.
118pub unsafe fn set_desc_i32_attr(
119    desc: cublasLtMatmulDesc_t,
120    attr: cublasLtMatmulDescAttributes_t,
121    value: i32,
122) -> Result<(), String> {
123    let status = unsafe {
124        cudarc::cublaslt::sys::cublasLtMatmulDescSetAttribute(
125            desc,
126            attr,
127            &value as *const i32 as *const c_void,
128            std::mem::size_of::<i32>(),
129        )
130    };
131    check(status, "cublasLtMatmulDescSetAttribute(i32)")
132}
133
/// Allocate a zero-filled, heap-backed stand-in for a matmul descriptor.
///
/// Tests use it to mock attribute writes without touching the real CUDA
/// driver. Release it with [`drop_mock_desc`].
#[cfg(test)]
pub fn mock_desc_handle() -> cublasLtMatmulDesc_t {
    // SAFETY: NOTE(review): assumes the bindgen opaque descriptor is plain
    // data for which the all-zero bit pattern is a valid value.
    let zeroed: cublasLtMatmulDescOpaque_t = unsafe { std::mem::zeroed() };
    Box::into_raw(Box::new(zeroed))
}
141
/// Free a mock descriptor previously returned by [`mock_desc_handle`].
/// A null `desc` is ignored.
///
/// # Safety
///
/// `desc` must be null or a pointer obtained from [`mock_desc_handle`] that
/// has not already been freed. It must not be used after this call.
#[cfg(test)]
pub unsafe fn drop_mock_desc(desc: cublasLtMatmulDesc_t) {
    if desc.is_null() {
        return;
    }
    // SAFETY: per the contract above, `desc` came from `Box::into_raw` in
    // `mock_desc_handle` and is reclaimed exactly once here.
    drop(unsafe { Box::from_raw(desc) });
}