Skip to main content

atomr_accel_cuda/kernel/blas_lt/
mod.rs

1//! `BlasLtActor` — wraps [`cudarc::cublaslt::CudaBlasLT`] for
2//! transformer-shaped fused matmul (matmul + bias + activation +
3//! aux-store + bias-grad reduction) across the full dtype matrix
4//! cuBLASLt accepts.
5//!
6//! See [`epilogue`] for the curated `Epilogue` enum, [`heuristic`]
7//! for the algorithm cache, [`workspace`] for the workspace pool,
8//! [`scaling`] for the fp8 scale-pointer wiring, and [`matmul`] for
9//! the typed `MatmulRequest<T>` plus its `BlasLtDispatch` impl.
10
11use std::sync::Arc;
12
13use async_trait::async_trait;
14use atomr_core::actor::{Actor, Context, Props};
15pub use cudarc::cublaslt::Activation;
16use cudarc::cublaslt::{CudaBlasLT, MatmulConfig};
17use tokio::sync::oneshot;
18
19use crate::completion::CompletionStrategy;
20use crate::device::DeviceState;
21use crate::error::GpuError;
22use crate::gpu_ref::GpuRef;
23use crate::kernel::dispatch::{BlasLtDispatch, BlasLtDispatchCtx};
24use crate::stream::StreamAllocator;
25
26pub mod epilogue;
27pub mod heuristic;
28pub mod matmul;
29pub mod scaling;
30pub mod workspace;
31
32pub use epilogue::Epilogue;
33pub use heuristic::{HeuristicCacheRef, HeuristicEntry, HeuristicKey, DEFAULT_HEURISTIC_CAPACITY};
34pub use matmul::MatmulRequest;
35pub use scaling::ScaleSet;
36pub use workspace::{WorkspaceLease, WorkspacePool};
37
/// Library tag for this module: handed to `envelope::run_kernel` and
/// stored in `GpuError::LibraryError::lib` by the legacy f32 path.
const LIB: &str = "cublaslt";
39
/// Public message surface.
///
/// This is the mailbox type of [`BlasLtActor`] (its `Actor::Msg`).
pub enum BlasLtMsg {
    /// Generic matmul over any dtype that implements
    /// [`crate::dtype::GemmSupported`]. Boxed-erased so `BlasLtActor`
    /// has a single mailbox type.
    Matmul(Box<dyn BlasLtDispatch>),

    /// Legacy f32-only constructor preserved for back-compat.
    /// New callers should use [`BlasLtMsg::matmul`] /
    /// [`BlasLtMsg::Matmul`] with a typed [`MatmulRequest<f32>`].
    #[deprecated(
        since = "0.2.0",
        note = "use BlasLtMsg::Matmul(Box::new(MatmulRequest::<f32> { … }))"
    )]
    MatmulF32 {
        /// Matmul configuration, forwarded unchanged to
        /// `CudaBlasLT::matmul`.
        cfg: MatmulConfig,
        /// First input operand (only read by the matmul call).
        a: GpuRef<f32>,
        /// Second input operand (only read by the matmul call).
        b: GpuRef<f32>,
        /// Output buffer. The legacy path needs exclusive ownership of
        /// the underlying slice and replies with
        /// `GpuError::Unrecoverable` when other references are alive.
        c: GpuRef<f32>,
        /// Optional bias buffer passed through to the matmul call.
        bias: Option<GpuRef<f32>>,
        /// Optional activation passed through to the matmul call.
        activation: Option<Activation>,
        /// Channel on which the outcome of the request is reported.
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
}
64
65impl BlasLtMsg {
66    /// Convenience constructor — `BlasLtMsg::matmul::<f32>(req)` is a
67    /// drop-in for callers migrating off the deprecated `MatmulF32`.
68    pub fn matmul<T>(req: MatmulRequest<T>) -> Self
69    where
70        T: crate::dtype::GemmSupported,
71        MatmulRequest<T>: BlasLtDispatch,
72    {
73        Self::Matmul(Box::new(req))
74    }
75}
76
/// Actor that serves [`BlasLtMsg`] requests.
///
/// Built through [`BlasLtActor::props`] (CUDA-backed) or
/// [`BlasLtActor::mock_props`] (no CUDA handle).
pub struct BlasLtActor {
    /// Backend chosen at construction time.
    inner: BlasLtInner,
}
80
/// Backend state of a [`BlasLtActor`].
enum BlasLtInner {
    /// CUDA-backed mode: everything a matmul request needs to run.
    Real {
        /// Shared cuBLASLt handle created from `stream`.
        blas_lt: Arc<CudaBlasLT>,
        /// Stream the handle was created on; also passed to every dispatch.
        stream: Arc<cudarc::driver::CudaStream>,
        /// Completion strategy passed to every dispatch.
        completion: Arc<dyn CompletionStrategy>,
        // Never read in this file, hence the `dead_code` allowance.
        // NOTE(review): presumably held to keep the device state alive
        // for the actor's lifetime — confirm.
        #[allow(dead_code)]
        state: Arc<DeviceState>,
        /// Workspace pool lent to typed matmul dispatches.
        workspace_pool: WorkspacePool,
        /// Algorithm-heuristic cache handle, cloned into each dispatch context.
        heuristic_cache: HeuristicCacheRef,
        /// SM architecture number detected at construction, `0` when
        /// detection failed.
        sm_arch: u32,
    },
    /// Mock mode: no CUDA handle exists, so requests are never executed.
    Mock,
}
94
95impl BlasLtActor {
96    pub fn props(
97        stream: Arc<cudarc::driver::CudaStream>,
98        _allocator: Arc<dyn StreamAllocator>,
99        completion: Arc<dyn CompletionStrategy>,
100        state: Arc<DeviceState>,
101    ) -> Props<Self> {
102        Props::create(move || {
103            let blas_lt = match CudaBlasLT::new(stream.clone()) {
104                Ok(b) => b,
105                Err(e) => panic!("ContextPoisoned: CudaBlasLT::new failed: {e}"),
106            };
107            // SM arch detection — best-effort. cudarc exposes the
108            // device-attribute query through the stream's context, but
109            // it's a fallible runtime call. Default to 0 if we can't
110            // resolve it; the heuristic cache simply won't share
111            // entries across arches in that case (correct, just
112            // slightly less hit-rate).
113            let sm_arch = stream
114                .context()
115                .attribute(
116                    cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
117                )
118                .ok()
119                .map(|m| m as u32 * 10)
120                .unwrap_or(0);
121            BlasLtActor {
122                inner: BlasLtInner::Real {
123                    blas_lt: Arc::new(blas_lt),
124                    stream: stream.clone(),
125                    completion: completion.clone(),
126                    state: state.clone(),
127                    workspace_pool: WorkspacePool::new(),
128                    heuristic_cache: HeuristicCacheRef::default_size(),
129                    sm_arch,
130                },
131            }
132        })
133    }
134
135    pub fn mock_props() -> Props<Self> {
136        Props::create(|| BlasLtActor {
137            inner: BlasLtInner::Mock,
138        })
139    }
140}
141
142#[async_trait]
143impl Actor for BlasLtActor {
144    type Msg = BlasLtMsg;
145
146    async fn handle(&mut self, _ctx: &mut Context<Self>, msg: BlasLtMsg) {
147        match &self.inner {
148            BlasLtInner::Mock => match msg {
149                BlasLtMsg::Matmul(req) => {
150                    // Synthesize a minimal context that lets the
151                    // dispatch's mock-mode reply path fire. We don't
152                    // touch a CUDA handle, so the typed request must
153                    // reply with `Unrecoverable("mock mode")` itself
154                    // — every `BlasLtDispatch` impl owns its reply.
155                    drop(req);
156                }
157                #[allow(deprecated)]
158                BlasLtMsg::MatmulF32 { reply, .. } => {
159                    let _ = reply.send(Err(GpuError::Unrecoverable(
160                        "BlasLtActor in mock mode".into(),
161                    )));
162                }
163            },
164            BlasLtInner::Real {
165                blas_lt,
166                stream,
167                completion,
168                workspace_pool,
169                heuristic_cache,
170                sm_arch,
171                ..
172            } => match msg {
173                BlasLtMsg::Matmul(req) => {
174                    let dctx = BlasLtDispatchCtx {
175                        blas_lt: blas_lt.clone(),
176                        stream,
177                        completion,
178                        workspace: workspace_pool,
179                        heuristic: heuristic_cache.clone(),
180                        sm_arch: *sm_arch,
181                    };
182                    req.dispatch(&dctx);
183                }
184                #[allow(deprecated)]
185                BlasLtMsg::MatmulF32 {
186                    cfg,
187                    a,
188                    b,
189                    c,
190                    bias,
191                    activation,
192                    reply,
193                } => {
194                    enqueue_matmul_f32_legacy(
195                        blas_lt.clone(),
196                        stream,
197                        completion,
198                        cfg,
199                        a,
200                        b,
201                        c,
202                        bias,
203                        activation,
204                        reply,
205                    );
206                }
207            },
208        }
209    }
210}
211
212/// Legacy-path enqueue identical to the pre-Phase-1 implementation
213/// (kept for back-compat with the deprecated `BlasLtMsg::MatmulF32`).
214fn enqueue_matmul_f32_legacy(
215    blas_lt: Arc<CudaBlasLT>,
216    stream: &Arc<cudarc::driver::CudaStream>,
217    completion: &Arc<dyn CompletionStrategy>,
218    cfg: MatmulConfig,
219    a: GpuRef<f32>,
220    b: GpuRef<f32>,
221    c: GpuRef<f32>,
222    bias: Option<GpuRef<f32>>,
223    activation: Option<Activation>,
224    reply: oneshot::Sender<Result<(), GpuError>>,
225) {
226    use crate::kernel::envelope;
227    use cudarc::cublaslt::Matmul;
228
229    let (a_slice, b_slice, c_slice) = match envelope::access_all_3(&a, &b, &c) {
230        Ok(t) => t,
231        Err(e) => {
232            let _ = reply.send(Err(e));
233            return;
234        }
235    };
236    let bias_slice = match bias.as_ref() {
237        None => None,
238        Some(g) => match g.access() {
239            Ok(s) => Some(s.clone()),
240            Err(e) => {
241                let _ = reply.send(Err(e));
242                return;
243            }
244        },
245    };
246    let mut c_owned = match Arc::try_unwrap(c_slice) {
247        Ok(s) => s,
248        Err(_) => {
249            let _ = reply.send(Err(GpuError::Unrecoverable(
250                "BlasLt C has multiple live references".into(),
251            )));
252            return;
253        }
254    };
255    c.record_write(stream);
256    envelope::run_kernel(LIB, stream, completion, (), reply, move || {
257        let bias_ref = bias_slice.as_ref().map(|s| &**s);
258        let act_ref = activation.as_ref();
259        // SAFETY: matmul is unsafe due to dim-validity contract.
260        let res =
261            unsafe { blas_lt.matmul(cfg, &*a_slice, &*b_slice, &mut c_owned, bias_ref, act_ref) };
262        match res {
263            Ok(()) => Ok((a_slice, b_slice, c_owned, bias_slice, blas_lt)),
264            Err(e) => Err(GpuError::LibraryError {
265                lib: LIB,
266                msg: format!("matmul: {e}"),
267            }),
268        }
269    });
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that `BlasLtMsg::matmul::<f32>` resolves and
    /// has the signature `fn(MatmulRequest<f32>) -> BlasLtMsg`.
    ///
    /// Building a real `MatmulRequest` needs a `GpuRef`, which needs a
    /// device, so only the constructor's type is verified here.
    #[test]
    fn blas_lt_msg_matmul_constructor() {
        // Coercing to a concrete fn pointer forces the generic bounds
        // (`GemmSupported`, `BlasLtDispatch`) to be satisfied for f32.
        let _f: fn(MatmulRequest<f32>) -> BlasLtMsg = BlasLtMsg::matmul::<f32>;
    }
}