Skip to main content

atomr_accel_cuda/kernel/blas_lt/
matmul.rs

1//! Typed `MatmulRequest<T: GemmSupported>` plus the `BlasLtDispatch`
2//! impl that routes it through the kernel envelope.
3//!
//! The pre-Phase-1 actor accepted only `MatmulConfig + GpuRef<f32>`.
//! `MatmulRequest<T>` widens the request surface to carry:
//! - any element type `T: GemmSupported` (f32 / f64 / f16 / bf16 / fp8),
//! - an explicit `D` output buffer (for out-of-place cases such as fp8
//!   split-k),
//! - the curated [`Epilogue`] enum,
//! - optional `bias` and `gelu_aux` buffers,
//! - per-tensor / per-row fp8 scale pointers via [`ScaleSet`],
//! - a `workspace_size` hint intended for the heuristic search.
//!
//! Not every field is honoured by the current dispatch paths yet.
13//!
//! cudarc 0.19.4's safe `Matmul` trait is implemented for `f32` and
//! (under feature `f16`) `half::f16` / `half::bf16`; those dtypes share the
//! [`dispatch_safe_path`] body. For dtypes cudarc does not wrap yet (f64,
//! fp8) dispatch replies with a typed `Err(GpuError::Unrecoverable)` until
//! the sys-level path lands.
19
20use std::sync::Arc;
21
22use cudarc::cublaslt::{Activation, Matmul, MatmulConfig};
23use tokio::sync::oneshot;
24
25use crate::dtype::{DTypeKind, GemmSupported};
26use crate::error::GpuError;
27use crate::gpu_ref::GpuRef;
28use crate::kernel::blas_lt::epilogue::Epilogue;
29use crate::kernel::blas_lt::scaling::ScaleSet;
30use crate::kernel::dispatch::{BlasLtDispatch, BlasLtDispatchCtx};
31use crate::kernel::envelope;
32
/// Library tag this module passes to `envelope::run_kernel` and attaches to
/// `GpuError::LibraryError` values.
const LIB: &str = "cublaslt";
34
35/// Typed matmul request. Public surface; instantiated by callers.
36pub struct MatmulRequest<T: GemmSupported> {
37    pub a: GpuRef<T>,
38    pub b: GpuRef<T>,
39    pub c: GpuRef<T>,
40    /// Optional explicit `D` output buffer. cuBLASLt allows
41    /// out-of-place matmul where the result lands in `D` rather than
42    /// in-place into `C`. Required for fp8 (the scale-back step
43    /// produces a different dtype than the accumulator).
44    pub d: Option<GpuRef<T>>,
45    pub m: i32,
46    pub n: i32,
47    pub k: i32,
48    pub alpha: T::Scalar,
49    pub beta: T::Scalar,
50    pub transa: bool,
51    pub transb: bool,
52    pub lda: i64,
53    pub ldb: i64,
54    pub ldc: i64,
55    pub ldd: i64,
56    pub epilogue: Epilogue,
57    pub bias: Option<GpuRef<T>>,
58    pub gelu_aux: Option<GpuRef<T>>,
59    pub scales: ScaleSet,
60    /// Hint for the heuristic: maximum workspace bytes the algorithm
61    /// search may use. A reasonable default is `4 * 1024 * 1024`
62    /// (cuBLASLt's standard 4 MiB minimum).
63    pub workspace_size: usize,
64    pub reply: oneshot::Sender<Result<(), GpuError>>,
65}
66
67impl<T: GemmSupported> std::fmt::Debug for MatmulRequest<T> {
68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69        f.debug_struct("MatmulRequest")
70            .field("dtype", &T::NAME)
71            .field("m", &self.m)
72            .field("n", &self.n)
73            .field("k", &self.k)
74            .field("transa", &self.transa)
75            .field("transb", &self.transb)
76            .field("epilogue", &self.epilogue)
77            .field("workspace_size", &self.workspace_size)
78            .finish()
79    }
80}
81
82/// Internal sealing trait — bridges `T: GemmSupported` to the cudarc
83/// `Matmul<T>` impls. Concrete impls below; one per safe-dtype.
84trait CudarcMatmulPath: GemmSupported {
85    fn dispatch_safe(req: Box<MatmulRequest<Self>>, ctx: &BlasLtDispatchCtx<'_>);
86}
87
88impl CudarcMatmulPath for f32 {
89    fn dispatch_safe(req: Box<MatmulRequest<f32>>, ctx: &BlasLtDispatchCtx<'_>) {
90        dispatch_safe_path::<f32>(req, ctx);
91    }
92}
93
94#[cfg(feature = "f16")]
95impl CudarcMatmulPath for half::f16 {
96    fn dispatch_safe(req: Box<MatmulRequest<half::f16>>, ctx: &BlasLtDispatchCtx<'_>) {
97        dispatch_safe_path::<half::f16>(req, ctx);
98    }
99}
100
101#[cfg(feature = "f16")]
102impl CudarcMatmulPath for half::bf16 {
103    fn dispatch_safe(req: Box<MatmulRequest<half::bf16>>, ctx: &BlasLtDispatchCtx<'_>) {
104        dispatch_safe_path::<half::bf16>(req, ctx);
105    }
106}
107
108/// Bridge for dtypes cudarc 0.19.4 doesn't wrap with a `Matmul<T>`
109/// impl yet (f64, fp8). Reply with `Unrecoverable("dtype …")` until
110/// the sys-level path lands.
111trait UnsupportedMatmulPath {
112    fn dispatch_unsupported(reply: oneshot::Sender<Result<(), GpuError>>, dtype: &'static str);
113}
114
115impl<T> UnsupportedMatmulPath for T {
116    fn dispatch_unsupported(reply: oneshot::Sender<Result<(), GpuError>>, dtype: &'static str) {
117        let _ = reply.send(Err(GpuError::Unrecoverable(format!(
118            "BlasLtActor: matmul<{dtype}> not yet implemented (Phase 1 sys-level wiring pending)"
119        ))));
120    }
121}
122
123/// f64 (cudarc 0.19.4 has no `Matmul<f64>` impl).
124impl BlasLtDispatch for MatmulRequest<f64> {
125    fn dtype_kind(&self) -> DTypeKind {
126        DTypeKind::F64
127    }
128    fn dispatch(self: Box<Self>, _ctx: &BlasLtDispatchCtx<'_>) {
129        <f64 as UnsupportedMatmulPath>::dispatch_unsupported(self.reply, "f64");
130    }
131}
132
133impl BlasLtDispatch for MatmulRequest<f32> {
134    fn dtype_kind(&self) -> DTypeKind {
135        DTypeKind::F32
136    }
137    fn dispatch(self: Box<Self>, ctx: &BlasLtDispatchCtx<'_>) {
138        <f32 as CudarcMatmulPath>::dispatch_safe(self, ctx);
139    }
140}
141
142#[cfg(feature = "f16")]
143impl BlasLtDispatch for MatmulRequest<half::f16> {
144    fn dtype_kind(&self) -> DTypeKind {
145        DTypeKind::F16
146    }
147    fn dispatch(self: Box<Self>, ctx: &BlasLtDispatchCtx<'_>) {
148        <half::f16 as CudarcMatmulPath>::dispatch_safe(self, ctx);
149    }
150}
151
152#[cfg(feature = "f16")]
153impl BlasLtDispatch for MatmulRequest<half::bf16> {
154    fn dtype_kind(&self) -> DTypeKind {
155        DTypeKind::Bf16
156    }
157    fn dispatch(self: Box<Self>, ctx: &BlasLtDispatchCtx<'_>) {
158        <half::bf16 as CudarcMatmulPath>::dispatch_safe(self, ctx);
159    }
160}
161
162#[cfg(feature = "cublas-fp8")]
163impl BlasLtDispatch for MatmulRequest<crate::dtype::F8E4m3> {
164    fn dtype_kind(&self) -> DTypeKind {
165        DTypeKind::F8E4m3
166    }
167    fn dispatch(self: Box<Self>, _ctx: &BlasLtDispatchCtx<'_>) {
168        <crate::dtype::F8E4m3 as UnsupportedMatmulPath>::dispatch_unsupported(
169            self.reply, "fp8e4m3",
170        );
171    }
172}
173
174#[cfg(feature = "cublas-fp8")]
175impl BlasLtDispatch for MatmulRequest<crate::dtype::F8E5m2> {
176    fn dtype_kind(&self) -> DTypeKind {
177        DTypeKind::F8E5m2
178    }
179    fn dispatch(self: Box<Self>, _ctx: &BlasLtDispatchCtx<'_>) {
180        <crate::dtype::F8E5m2 as UnsupportedMatmulPath>::dispatch_unsupported(
181            self.reply, "fp8e5m2",
182        );
183    }
184}
185
186/// Body of the safe-cudarc dispatch path. The function is generic
187/// over `T` so all three (f32, f16, bf16) share one body. The
188/// scale-pointer / heuristic / workspace-pool integration is wired
189/// through the `MatmulConfig`'s alpha/beta (currently f32-only at the
190/// cudarc surface) — once cudarc lands `Matmul` for fp8 we'll
191/// promote this helper to call into the sys-level descriptor path.
192fn dispatch_safe_path<T>(req: Box<MatmulRequest<T>>, ctx: &BlasLtDispatchCtx<'_>)
193where
194    T: GemmSupported + cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits,
195    cudarc::cublaslt::CudaBlasLT: Matmul<T>,
196    T::Scalar: Into<f32> + Copy,
197{
198    let MatmulRequest {
199        a,
200        b,
201        c,
202        d: _d,
203        m,
204        n,
205        k,
206        alpha,
207        beta,
208        transa,
209        transb,
210        lda,
211        ldb,
212        ldc,
213        ldd: _ldd,
214        epilogue,
215        bias,
216        gelu_aux: _gelu_aux,
217        scales: _scales,
218        workspace_size: _workspace_size,
219        reply,
220    } = *req;
221
222    // Touch the heuristic cache so we record at least an empty entry
223    // for this shape. (Real heuristic search lands when we move off
224    // cudarc's safe `Matmul` API onto the sys-level descriptor path.)
225    let _entry = ctx
226        .heuristic
227        .get(&crate::kernel::blas_lt::heuristic::HeuristicKey::new(
228            m,
229            n,
230            k,
231            T::KIND,
232            transa,
233            transb,
234            epilogue,
235            ctx.sm_arch,
236        ));
237
238    // Map the curated Epilogue back to cudarc's safe `Activation`
239    // (cudarc's safe API only exposes Relu/Gelu for now). Other
240    // variants degrade to "no activation" under the safe path; the
241    // forthcoming sys-level path consumes the full enum.
242    let activation = match epilogue {
243        Epilogue::Relu | Epilogue::ReluBias | Epilogue::ReluAux | Epilogue::ReluAuxBias => {
244            Some(Activation::Relu)
245        }
246        Epilogue::Gelu | Epilogue::GeluBias | Epilogue::GeluAux | Epilogue::GeluAuxBias => {
247            Some(Activation::Gelu)
248        }
249        _ => None,
250    };
251
252    let cfg = MatmulConfig {
253        transa,
254        transb,
255        transc: false,
256        m: m as u64,
257        n: n as u64,
258        k: k as u64,
259        alpha: alpha.into(),
260        lda,
261        ldb,
262        beta: beta.into(),
263        ldc,
264        stride_a: None,
265        stride_b: None,
266        stride_c: None,
267        stride_bias: None,
268        batch_size: None,
269    };
270
271    let (a_slice, b_slice, c_slice) = match envelope::access_all_3(&a, &b, &c) {
272        Ok(t) => t,
273        Err(e) => {
274            let _ = reply.send(Err(e));
275            return;
276        }
277    };
278    let bias_slice = match bias.as_ref() {
279        None => None,
280        Some(g) => match g.access() {
281            Ok(s) => Some(s.clone()),
282            Err(e) => {
283                let _ = reply.send(Err(e));
284                return;
285            }
286        },
287    };
288    let mut c_owned = match Arc::try_unwrap(c_slice) {
289        Ok(s) => s,
290        Err(_) => {
291            let _ = reply.send(Err(GpuError::Unrecoverable(
292                "BlasLt C has multiple live references".into(),
293            )));
294            return;
295        }
296    };
297    c.record_write(ctx.stream);
298
299    let blas_lt = ctx.blas_lt.clone();
300    let stream = ctx.stream;
301    let completion = ctx.completion;
302
303    envelope::run_kernel(LIB, stream, completion, (), reply, move || {
304        let bias_ref = bias_slice.as_ref().map(|s| &**s);
305        let act_ref = activation.as_ref();
306        // SAFETY: matmul is unsafe due to dim-validity contract.
307        let res =
308            unsafe { blas_lt.matmul(cfg, &*a_slice, &*b_slice, &mut c_owned, bias_ref, act_ref) };
309        match res {
310            Ok(()) => Ok((a_slice, b_slice, c_owned, bias_slice, blas_lt)),
311            Err(e) => Err(GpuError::LibraryError {
312                lib: LIB,
313                msg: format!("matmul: {e}"),
314            }),
315        }
316    });
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time coverage check: every dtype with a `BlasLtDispatch` impl
    /// under the active feature set exposes `dtype_kind` and `dispatch`.
    /// Binding both to function pointers exercises the impls' trait bounds
    /// without constructing real `GpuRef`s, which need a live `DeviceState`.
    #[test]
    fn matmul_request_dispatches_for_f32_f16_bf16() {
        let _f32_kind: fn(&MatmulRequest<f32>) -> DTypeKind = MatmulRequest::<f32>::dtype_kind;
        let _f32_dispatch: fn(Box<MatmulRequest<f32>>, &BlasLtDispatchCtx<'_>) =
            MatmulRequest::<f32>::dispatch;
        let _f64_kind: fn(&MatmulRequest<f64>) -> DTypeKind = MatmulRequest::<f64>::dtype_kind;
        let _f64_dispatch: fn(Box<MatmulRequest<f64>>, &BlasLtDispatchCtx<'_>) =
            MatmulRequest::<f64>::dispatch;
        #[cfg(feature = "f16")]
        {
            let _f16_kind: fn(&MatmulRequest<half::f16>) -> DTypeKind =
                MatmulRequest::<half::f16>::dtype_kind;
            let _f16_dispatch: fn(Box<MatmulRequest<half::f16>>, &BlasLtDispatchCtx<'_>) =
                MatmulRequest::<half::f16>::dispatch;
            let _bf16_kind: fn(&MatmulRequest<half::bf16>) -> DTypeKind =
                MatmulRequest::<half::bf16>::dtype_kind;
            let _bf16_dispatch: fn(Box<MatmulRequest<half::bf16>>, &BlasLtDispatchCtx<'_>) =
                MatmulRequest::<half::bf16>::dispatch;
        }
        #[cfg(feature = "cublas-fp8")]
        {
            use crate::dtype::{F8E4m3, F8E5m2};
            let _e4m3_kind: fn(&MatmulRequest<F8E4m3>) -> DTypeKind =
                MatmulRequest::<F8E4m3>::dtype_kind;
            let _e4m3_dispatch: fn(Box<MatmulRequest<F8E4m3>>, &BlasLtDispatchCtx<'_>) =
                MatmulRequest::<F8E4m3>::dispatch;
            let _e5m2_kind: fn(&MatmulRequest<F8E5m2>) -> DTypeKind =
                MatmulRequest::<F8E5m2>::dtype_kind;
            let _e5m2_dispatch: fn(Box<MatmulRequest<F8E5m2>>, &BlasLtDispatchCtx<'_>) =
                MatmulRequest::<F8E5m2>::dispatch;
        }

        // The dispatch impls hard-code their `DTypeKind`; confirm the const
        // dtype tags exposed through `AccelDtype::KIND` line up with them.
        assert_eq!(<f32 as atomr_accel::AccelDtype>::KIND, DTypeKind::F32);
        assert_eq!(<f64 as atomr_accel::AccelDtype>::KIND, DTypeKind::F64);
        #[cfg(feature = "f16")]
        {
            assert_eq!(<half::f16 as atomr_accel::AccelDtype>::KIND, DTypeKind::F16);
            assert_eq!(
                <half::bf16 as atomr_accel::AccelDtype>::KIND,
                DTypeKind::Bf16
            );
        }
    }
}