Skip to main content

atomr_accel_cuda/kernel/blas/
mod.rs

1//! `BlasActor` — full cuBLAS surface (Phase 1 cuBLAS slice).
2//!
3//! Wraps a [`cudarc::cublas::CudaBlas`] handle, performs cuBLAS L1/L2/
4//! L3 ops on its assigned stream, and returns completion via the
5//! configured [`CompletionStrategy`] (§3.2 stateless-handle archetype +
6//! §5.10 callback wiring).
7//!
8//! Sub-modules:
9//! - [`gemm`] — typed `Gemm<T>` for f32/f64/f16/bf16 (cudarc safe layer)
10//!   and the legacy `SgemmRequest` adapter that routes through
11//!   `Gemm<f32>` for back-compat.
12//! - [`gemm_strided_batched`] — strided-batched gemm via cudarc's safe
13//!   layer for f32/f64/f16/bf16; can drop to
14//!   [`crate::sys::cublas::gemm_strided_batched_ex`] if more dtypes are
15//!   needed in a follow-up.
16//! - [`l1`] — axpy / dot / nrm2 / scal / asum / iamax / iamin / copy /
17//!   swap / rot via the cuBLAS ex-suffix entry points.
18//! - [`l2`] — gemv / ger via cudarc's `Gemv<T>` and the local
19//!   `cublasGemv_v2` / `cublasGer_v2` wrappers.
20//! - [`l3`] — geam / syrk / trsm via the local `cublasSgeam` /
21//!   `cublasSsyrk_v2` / `cublasStrsm_v2` wrappers (and dgeam/dsyrk/
22//!   dtrsm).
23//! - [`scaling`] — fp8 scaling-factor helpers (per-tensor / per-row),
24//!   stubbed under the `cublas-fp8` feature for use by `cublasGemmEx`
25//!   on Hopper+.
26//!
27//! The mailbox is freed immediately after the kernel is enqueued — the
28//! actor never blocks on the GPU (§5.2). Reply delivery happens on the
29//! Tokio task spawned by [`crate::kernel::envelope::run_kernel`].
30
31use std::sync::Arc;
32
33use atomr_core::actor::{Context, Props};
34use atomr_macros::Actor;
35use cudarc::cublas::CudaBlas;
36
37use crate::completion::{CompletionStrategy, HostFnCompletion};
38use crate::device::{DeviceState, SgemmRequest};
39use crate::error::GpuError;
40use crate::kernel::dispatch::{
41    BlasDispatchCtx, BlasL1Dispatch, BlasL2Dispatch, BlasL3Dispatch, GemmDispatch,
42    GemmStridedBatchedDispatch,
43};
44use crate::stream::{ActorHints, StreamAllocator};
45
46pub mod gemm;
47pub mod gemm_strided_batched;
48pub mod l1;
49pub mod l2;
50pub mod l3;
51pub mod scaling;
52
53pub use gemm::GemmRequest;
54pub use gemm_strided_batched::GemmStridedBatchedRequest;
55pub use l1::{
56    AsumRequest, AxpyRequest, CopyRequest, DotRequest, IamaxRequest, IaminRequest, Nrm2Request,
57    RotRequest, ScalRequest, SwapRequest,
58};
59pub use l2::{GemvRequest, GerRequest};
60pub use l3::{GeamRequest, SyrkRequest, TrsmRequest};
61
/// Public messages for `BlasActor`.
///
/// Every variant except the legacy `Sgemm` boxes a typed dispatcher
/// trait object, so the dtype dimension travels through the box without
/// forcing an N-fold mailbox explosion. In real mode `BlasActor` runs each
/// dispatcher against one shared `BlasDispatchCtx`. In mock mode the boxed
/// dispatchers are dropped without a reply (see `mock_reply`).
pub enum BlasMsg {
    /// Generic typed gemm (canonical form). Construct via
    /// [`BlasMsg::gemm::<T>`] or
    /// [`gemm::GemmRequest::<T>::into_msg`].
    Gemm(Box<dyn GemmDispatch>),
    /// L1 ops (see the [`l1`] module) boxed in [`BlasL1Dispatch`].
    L1(Box<dyn BlasL1Dispatch>),
    /// L2 ops (see the [`l2`] module) boxed in [`BlasL2Dispatch`].
    L2(Box<dyn BlasL2Dispatch>),
    /// L3 ops other than gemm (geam / syrk / trsm), boxed in
    /// [`BlasL3Dispatch`].
    L3(Box<dyn BlasL3Dispatch>),
    /// Strided-batched gemm. Construct via
    /// [`BlasMsg::gemm_strided_batched`].
    GemmStridedBatched(Box<dyn GemmStridedBatchedDispatch>),
    /// Legacy alias kept for back-compat. It carries a concrete
    /// `SgemmRequest` rather than a dispatcher.
    ///
    /// In real mode the actor lowers it to a `GemmRequest<f32>` with both
    /// operands non-transposed (`CUBLAS_OP_N`) and `lda = m`, `ldb = k`,
    /// `ldc = m`, then dispatches it on the `Gemm` path. In mock mode it is
    /// answered with `GpuError::Unrecoverable`.
    #[deprecated(note = "use BlasMsg::gemm::<f32>(GemmRequest::<f32> { ... })")]
    Sgemm(Box<crate::device::SgemmRequest>),
}
83
84impl BlasMsg {
85    /// Construct a `BlasMsg::Gemm` from a typed [`GemmRequest<T>`].
86    /// Convenience wrapper so callers don't have to box manually.
87    pub fn gemm<T: crate::dtype::GemmSupported>(req: GemmRequest<T>) -> Self
88    where
89        GemmRequest<T>: GemmDispatch,
90    {
91        Self::Gemm(Box::new(req))
92    }
93
94    /// Construct a `BlasMsg::GemmStridedBatched` from a typed request.
95    pub fn gemm_strided_batched<T: crate::dtype::GemmSupported>(
96        req: GemmStridedBatchedRequest<T>,
97    ) -> Self
98    where
99        GemmStridedBatchedRequest<T>: GemmStridedBatchedDispatch,
100    {
101        Self::GemmStridedBatched(Box::new(req))
102    }
103}
104
/// Actor that executes [`BlasMsg`] requests, either against a real cuBLAS
/// handle or as a GPU-less mock.
///
/// Two-track construction: a real cuBLAS-backed actor (`props`, or the
/// back-compat `props_legacy`), and a mock variant (`mock_props`) used by
/// `examples/echo_no_gpu` and unit tests where no GPU is present.
#[derive(Actor)]
#[msg(BlasMsg)]
pub struct BlasActor {
    /// Real-vs-mock backing state. It is fixed by the `Props` factory that
    /// built the actor and is matched on for every incoming message.
    inner: BlasInner,
}
113
/// Backing state of a [`BlasActor`]: either live cuBLAS resources or the
/// GPU-less mock.
pub(crate) enum BlasInner {
    /// cuBLAS-backed mode, built by `BlasActor::props`.
    Real {
        /// cuBLAS handle, created with `CudaBlas::new` on `stream`.
        cublas: Arc<CudaBlas>,
        /// Stream the handle was created on; lent to every dispatcher via
        /// `BlasDispatchCtx`.
        stream: Arc<cudarc::driver::CudaStream>,
        /// Completion strategy lent to every dispatcher via
        /// `BlasDispatchCtx`.
        completion: Arc<dyn CompletionStrategy>,
        /// Device state lent to every dispatcher via `BlasDispatchCtx`.
        state: Arc<DeviceState>,
    },
    /// GPU-less mode, built by `BlasActor::mock_props`. Nothing is
    /// dispatched; see `mock_reply`.
    Mock,
}
123
124impl BlasActor {
125    /// Build a [`Props<BlasActor>`] from a stream+allocator+completion
126    /// triple. Panics from inside the factory closure with
127    /// `"ContextPoisoned: CudaBlas::new failed: …"` so the supervisor
128    /// can restart the actor on handle-creation failure.
129    pub fn props(
130        stream: Arc<cudarc::driver::CudaStream>,
131        allocator: Arc<dyn StreamAllocator>,
132        completion: Arc<dyn CompletionStrategy>,
133        state: Arc<DeviceState>,
134    ) -> Props<Self> {
135        let actor_stream = allocator.acquire(ActorHints::default());
136        debug_assert!(Arc::ptr_eq(&actor_stream, &stream));
137        Props::create(move || {
138            let cublas = match CudaBlas::new(stream.clone()) {
139                Ok(b) => b,
140                Err(e) => panic!("ContextPoisoned: CudaBlas::new failed: {e}"),
141            };
142            BlasActor {
143                inner: BlasInner::Real {
144                    cublas: Arc::new(cublas),
145                    stream: stream.clone(),
146                    completion: completion.clone(),
147                    state: state.clone(),
148                },
149            }
150        })
151    }
152
153    /// Back-compat shim for callers using the F1 constructor signature.
154    /// Wraps the legacy `(stream, PerActorAllocator, HostFnCompletion)`
155    /// into the F2 form. New code should call [`BlasActor::props`].
156    pub fn props_legacy(
157        stream: Arc<cudarc::driver::CudaStream>,
158        allocator: crate::stream::PerActorAllocator,
159        completion: HostFnCompletion,
160        state: Arc<DeviceState>,
161    ) -> Props<Self> {
162        let alloc: Arc<dyn StreamAllocator> = Arc::new(allocator);
163        let comp: Arc<dyn CompletionStrategy> = Arc::new(completion);
164        Self::props(stream, alloc, comp, state)
165    }
166
167    pub fn mock_props() -> Props<Self> {
168        Props::create(|| BlasActor {
169            inner: BlasInner::Mock,
170        })
171    }
172}
173
174impl BlasActor {
175    async fn handle_msg(&mut self, _ctx: &mut Context<Self>, msg: BlasMsg) {
176        // Decompose the message into a known shape, then run the
177        // dispatcher with a borrowed `BlasDispatchCtx` so each variant
178        // shares the same enqueue/completion plumbing.
179        match &self.inner {
180            BlasInner::Mock => match msg {
181                BlasMsg::Gemm(d) => mock_reply(d.op_name()),
182                BlasMsg::L1(d) => mock_reply(d.op_name()),
183                BlasMsg::L2(d) => mock_reply(d.op_name()),
184                BlasMsg::L3(d) => mock_reply(d.op_name()),
185                BlasMsg::GemmStridedBatched(d) => mock_reply(d.op_name()),
186                #[allow(deprecated)]
187                BlasMsg::Sgemm(req) => {
188                    let _ = req.reply.send(Err(GpuError::Unrecoverable(
189                        "Sgemm not supported in mock mode".into(),
190                    )));
191                }
192            },
193            BlasInner::Real {
194                cublas,
195                stream,
196                completion,
197                state,
198            } => {
199                let ctx = BlasDispatchCtx {
200                    cublas,
201                    stream,
202                    completion,
203                    state,
204                };
205                match msg {
206                    BlasMsg::Gemm(d) => d.dispatch(&ctx),
207                    BlasMsg::L1(d) => d.dispatch(&ctx),
208                    BlasMsg::L2(d) => d.dispatch(&ctx),
209                    BlasMsg::L3(d) => d.dispatch(&ctx),
210                    BlasMsg::GemmStridedBatched(d) => d.dispatch(&ctx),
211                    #[allow(deprecated)]
212                    BlasMsg::Sgemm(req) => {
213                        // Route through Gemm<f32> internally so all
214                        // back-compat callers benefit from the same
215                        // dispatch path.
216                        let SgemmRequest {
217                            a,
218                            b,
219                            c,
220                            m,
221                            n,
222                            k,
223                            alpha,
224                            beta,
225                            reply,
226                        } = *req;
227                        let typed = GemmRequest::<f32> {
228                            a,
229                            b,
230                            c,
231                            m,
232                            n,
233                            k,
234                            alpha,
235                            beta,
236                            trans_a: cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_N,
237                            trans_b: cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_N,
238                            lda: m,
239                            ldb: k,
240                            ldc: m,
241                            reply,
242                        };
243                        let boxed: Box<dyn GemmDispatch> = Box::new(typed);
244                        boxed.dispatch(&ctx);
245                    }
246                }
247            }
248        }
249    }
250}
251
252/// Mock-mode reply helper. We don't have access to the request's
253/// `oneshot::Sender` (it lives inside the boxed dispatcher), so the
254/// only thing we can do is drop the dispatcher — the receiver
255/// observes `Err(RecvError)` which surfaces as a typed error at the
256/// caller. Tracing logs the dropped op name so tests can spot it.
257fn mock_reply(op: &'static str) {
258    tracing::debug!(op, "BlasActor (mock): dropping op without reply");
259}