Skip to main content

atomr_accel_cuda/kernel/blas/
gemm.rs

1//! Typed `GemmRequest<T>` + `GemmDispatch` impls.
2//!
3//! cudarc 0.19 exposes `cudarc::cublas::Gemm<T>` for f32, f64, and
4//! (under feature `f16`) `half::f16` and `half::bf16`. The dispatcher
5//! re-uses that safe trait so we don't have to touch
6//! `cublasGemmEx` directly for the common dtypes — fp8 is the future
7//! follow-up that lights up `crate::sys::cublas::gemm_ex` once the
8//! `cublas-fp8` feature is wired (see [`super::scaling`]).
9
10use std::sync::Arc;
11
12use cudarc::cublas::sys::cublasOperation_t;
13use cudarc::cublas::{Gemm, GemmConfig};
14use tokio::sync::oneshot;
15
16use crate::dtype::GemmSupported;
17use crate::error::GpuError;
18use crate::gpu_ref::GpuRef;
19use crate::kernel::dispatch::{BlasDispatchCtx, GemmDispatch};
20use crate::kernel::envelope;
21
/// Library tag this module passes to `envelope::run_kernel` and reports in
/// the `lib` field of [`GpuError::LibraryError`].
const LIB: &str = "cublas";
23
24/// Typed cuBLAS gemm request: `C = α·op(A)·op(B) + β·C`.
25///
26/// `lda`/`ldb`/`ldc` follow cuBLAS's column-major convention (see
27/// cuBLAS docs). For the no-transpose case, `lda = m`, `ldb = k`,
28/// `ldc = m`.
29///
30/// # Capability marker compile-fail
31///
32/// `T: GemmSupported` gates the dtype matrix. cuBLAS does **not**
33/// support i64 gemm, so building a `GemmRequest::<i64>` is rejected
34/// at compile time:
35///
36/// ```compile_fail
37/// # use atomr_accel_cuda::kernel::GemmRequest;
38/// # use atomr_accel_cuda::gpu_ref::GpuRef;
39/// # use cudarc::cublas::sys::cublasOperation_t;
40/// # let (tx, _rx) = tokio::sync::oneshot::channel();
41/// # let a: GpuRef<i64> = unimplemented!();
42/// # let b: GpuRef<i64> = unimplemented!();
43/// # let c: GpuRef<i64> = unimplemented!();
44/// // Fails: i64 does not implement `GemmSupported`.
45/// let _req = GemmRequest::<i64> {
46///     a, b, c,
47///     m: 1, n: 1, k: 1,
48///     alpha: 1, beta: 0,
49///     trans_a: cublasOperation_t::CUBLAS_OP_N,
50///     trans_b: cublasOperation_t::CUBLAS_OP_N,
51///     lda: 1, ldb: 1, ldc: 1,
52///     reply: tx,
53/// };
54/// ```
55pub struct GemmRequest<T: GemmSupported> {
56    pub a: GpuRef<T>,
57    pub b: GpuRef<T>,
58    pub c: GpuRef<T>,
59    pub m: i32,
60    pub n: i32,
61    pub k: i32,
62    pub alpha: T,
63    pub beta: T,
64    pub trans_a: cublasOperation_t,
65    pub trans_b: cublasOperation_t,
66    pub lda: i32,
67    pub ldb: i32,
68    pub ldc: i32,
69    pub reply: oneshot::Sender<Result<(), GpuError>>,
70}
71
72impl<T> GemmRequest<T>
73where
74    T: GemmSupported,
75    GemmRequest<T>: GemmDispatch,
76{
77    /// Box-and-wrap into a [`crate::kernel::BlasMsg::Gemm`] variant.
78    pub fn into_msg(self) -> crate::kernel::BlasMsg {
79        crate::kernel::BlasMsg::Gemm(Box::new(self))
80    }
81}
82
83/// Generic dispatch body shared across every `Gemm<T>` cudarc impl.
84///
85/// We split it into a function so the trait impl stays tiny. The
86/// `T: Gemm<...> for CudaBlas` bound forces every call site to pick a
87/// dtype cudarc actually implements; calling `gemm::<i64>(...)` would
88/// fail to compile.
89fn dispatch_gemm<T>(req: GemmRequest<T>, ctx: &BlasDispatchCtx<'_>)
90where
91    T: GemmSupported + Copy,
92    cudarc::cublas::CudaBlas: Gemm<T>,
93{
94    let GemmRequest {
95        a,
96        b,
97        c,
98        m,
99        n,
100        k,
101        alpha,
102        beta,
103        trans_a,
104        trans_b,
105        lda,
106        ldb,
107        ldc,
108        reply,
109    } = req;
110
111    let (a_slice, b_slice, c_slice) = match envelope::access_all_3(&a, &b, &c) {
112        Ok(t) => t,
113        Err(e) => {
114            let _ = reply.send(Err(e));
115            return;
116        }
117    };
118
119    let cfg = GemmConfig::<T> {
120        transa: trans_a,
121        transb: trans_b,
122        m,
123        n,
124        k,
125        alpha,
126        lda,
127        ldb,
128        beta,
129        ldc,
130    };
131
132    // cudarc's `gemm` requires `&mut C: DevicePtrMut<T>`. An
133    // `Arc<CudaSlice<T>>` doesn't satisfy that: we have to unwrap the
134    // Arc. The caller must hold the unique `GpuRef` to the output
135    // buffer or the unwrap fails — single-writer enforcement.
136    let mut c_owned = match Arc::try_unwrap(c_slice) {
137        Ok(s) => s,
138        Err(_arc) => {
139            let _ = reply.send(Err(GpuError::Unrecoverable(
140                "GEMM target buffer C has more than one live reference; \
141                 caller must hold the unique GpuRef to write to it"
142                    .into(),
143            )));
144            return;
145        }
146    };
147
148    c.record_write(ctx.stream);
149
150    let cublas = ctx.cublas.clone();
151    envelope::run_kernel(LIB, ctx.stream, ctx.completion, (), reply, move || {
152        // SAFETY: cudarc's `gemm` is unsafe because invalid
153        // m/n/k/lda/ldb/ldc can read out of bounds. The caller is
154        // responsible for valid dims.
155        let res = unsafe { cublas.gemm(cfg, &*a_slice, &*b_slice, &mut c_owned) };
156        match res {
157            Ok(()) => Ok((cublas, a_slice, b_slice, c_owned)),
158            Err(e) => Err(GpuError::LibraryError {
159                lib: LIB,
160                msg: format!("gemm enqueue: {e}"),
161            }),
162        }
163    });
164}
165
166// ─────────────── concrete `GemmDispatch` impls ───────────────
167
168impl GemmDispatch for GemmRequest<f32> {
169    fn dtype_name(&self) -> &'static str {
170        <f32 as atomr_accel::AccelDtype>::NAME
171    }
172    fn op_name(&self) -> &'static str {
173        "gemm"
174    }
175    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
176        dispatch_gemm::<f32>(*self, ctx);
177    }
178}
179
180impl GemmDispatch for GemmRequest<f64> {
181    fn dtype_name(&self) -> &'static str {
182        <f64 as atomr_accel::AccelDtype>::NAME
183    }
184    fn op_name(&self) -> &'static str {
185        "gemm"
186    }
187    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
188        dispatch_gemm::<f64>(*self, ctx);
189    }
190}
191
192#[cfg(feature = "f16")]
193impl GemmDispatch for GemmRequest<half::f16> {
194    fn dtype_name(&self) -> &'static str {
195        <half::f16 as atomr_accel::AccelDtype>::NAME
196    }
197    fn op_name(&self) -> &'static str {
198        "gemm"
199    }
200    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
201        dispatch_gemm::<half::f16>(*self, ctx);
202    }
203}
204
205#[cfg(feature = "f16")]
206impl GemmDispatch for GemmRequest<half::bf16> {
207    fn dtype_name(&self) -> &'static str {
208        <half::bf16 as atomr_accel::AccelDtype>::NAME
209    }
210    fn op_name(&self) -> &'static str {
211        "gemm"
212    }
213    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
214        dispatch_gemm::<half::bf16>(*self, ctx);
215    }
216}
217
#[cfg(test)]
pub(crate) mod tests_helpers {
    use crate::gpu_ref::GpuRef;

    /// Returns a placeholder `GpuRef<T>` from
    /// `GpuRef::for_test_no_gpu_leaked`, for tests that only query a boxed
    /// request's `op_name` / `dtype_name`.
    ///
    /// SAFETY: the backing `CudaSlice` is uninitialized. Never read from it
    /// and never dispatch a request that holds it. The caller must leak the
    /// surrounding boxed request (`Box::leak`) so that cudarc's
    /// `Drop for CudaSlice<T>` is never run on it.
    pub fn gpu_ref_stub<T>() -> GpuRef<T> {
        GpuRef::for_test_no_gpu_leaked()
    }
}
236
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::oneshot;

    /// Compile-time check that `T` can travel through the boxed-dispatcher
    /// mailbox: it must be `Send` and must not borrow anything (`'static`).
    fn assert_send_static<T: Send + 'static>() {}

    /// Boxes a stub `GemmRequest<T>` as `dyn GemmDispatch`, leaks it, and then
    /// checks that it reports `"gemm"` and `expected_dtype`.
    ///
    /// The leak happens *before* the assertions. If an assertion fails, the
    /// panic unwinds without running the fabricated `CudaSlice`'s `Drop` (see
    /// `tests_helpers::gpu_ref_stub`).
    fn assert_dispatch_names<T>(expected_dtype: &str)
    where
        T: GemmSupported + num_one_zero::NumOneZero + 'static,
        GemmRequest<T>: GemmDispatch,
    {
        let boxed: Box<dyn GemmDispatch> = Box::new(stub_request::<T>());
        let dispatcher: &dyn GemmDispatch = Box::leak(boxed);
        assert_eq!(dispatcher.op_name(), "gemm");
        assert_eq!(dispatcher.dtype_name(), expected_dtype);
    }

    #[test]
    fn gemm_request_dispatches_for_f32_f64_f16_bf16() {
        // Compile-time assertion: every `GemmRequest<T>` for the
        // dtypes we ship is `Send + 'static` so it can travel through
        // the boxed-dispatcher mailbox.
        assert_send_static::<GemmRequest<f32>>();
        assert_send_static::<GemmRequest<f64>>();
        #[cfg(feature = "f16")]
        {
            assert_send_static::<GemmRequest<half::f16>>();
            assert_send_static::<GemmRequest<half::bf16>>();
        }

        // Runtime assertion: each dtype's boxed dispatcher reports
        // its op + dtype correctly.
        assert_dispatch_names::<f32>("f32");
        assert_dispatch_names::<f64>("f64");
        #[cfg(feature = "f16")]
        {
            assert_dispatch_names::<half::f16>("f16");
            assert_dispatch_names::<half::bf16>("bf16");
        }
    }

    #[test]
    fn deprecated_sgemm_alias_still_constructs() {
        // Build a legacy `BlasMsg::Sgemm(Box<SgemmRequest>)`. The
        // explicit `#[allow(deprecated)]` exercises the back-compat
        // path that constructs the variant; routing through
        // `Gemm<f32>` is exercised at the actor handler level (live
        // GPU integration tests).
        #[allow(deprecated)]
        {
            let (tx, _rx) = oneshot::channel();
            let req = crate::device::SgemmRequest {
                a: gpu_ref_stub::<f32>(),
                b: gpu_ref_stub::<f32>(),
                c: gpu_ref_stub::<f32>(),
                m: 1,
                n: 1,
                k: 1,
                alpha: 1.0,
                beta: 0.0,
                reply: tx,
            };
            let msg = crate::kernel::BlasMsg::Sgemm(Box::new(req));
            // Leak: same rationale as `gpu_ref_stub` — the
            // fabricated CudaSlice's Drop must not run.
            Box::leak(Box::new(msg));
        }
    }

    /// Build a fully-populated [`GemmRequest<T>`] backed by the
    /// fabricated [`gpu_ref_stub`] buffers. Used for op-name /
    /// dtype-name assertions only — never dispatched.
    fn stub_request<T>() -> GemmRequest<T>
    where
        T: GemmSupported + num_one_zero::NumOneZero,
        GemmRequest<T>: GemmDispatch,
    {
        let (tx, _rx) = oneshot::channel();
        // The Receiver drops at function exit; the Sender stays
        // inside the request and is leaked along with it via
        // `Box::leak` at the call site.
        GemmRequest::<T> {
            a: gpu_ref_stub::<T>(),
            b: gpu_ref_stub::<T>(),
            c: gpu_ref_stub::<T>(),
            m: 1,
            n: 1,
            k: 1,
            alpha: <T as num_one_zero::NumOneZero>::one(),
            beta: <T as num_one_zero::NumOneZero>::zero(),
            trans_a: cublasOperation_t::CUBLAS_OP_N,
            trans_b: cublasOperation_t::CUBLAS_OP_N,
            lda: 1,
            ldb: 1,
            ldc: 1,
            reply: tx,
        }
    }

    /// Local "one"/"zero" trait so the test stub can build alpha/beta
    /// without depending on `num-traits`. f32/f64 use literals; the
    /// half-precision types (`half::f16`, `half::bf16`) use their
    /// `ONE`/`ZERO` constants.
    mod num_one_zero {
        pub trait NumOneZero: Copy {
            fn one() -> Self;
            fn zero() -> Self;
        }
        impl NumOneZero for f32 {
            fn one() -> Self {
                1.0
            }
            fn zero() -> Self {
                0.0
            }
        }
        impl NumOneZero for f64 {
            fn one() -> Self {
                1.0
            }
            fn zero() -> Self {
                0.0
            }
        }
        #[cfg(feature = "f16")]
        impl NumOneZero for half::f16 {
            fn one() -> Self {
                half::f16::ONE
            }
            fn zero() -> Self {
                half::f16::ZERO
            }
        }
        #[cfg(feature = "f16")]
        impl NumOneZero for half::bf16 {
            fn one() -> Self {
                half::bf16::ONE
            }
            fn zero() -> Self {
                half::bf16::ZERO
            }
        }
    }

    use super::tests_helpers::gpu_ref_stub;
}