atomr_accel_cuda/kernel/blas/mod.rs
1//! `BlasActor` — full cuBLAS surface (Phase 1 cuBLAS slice).
2//!
3//! Wraps a [`cudarc::cublas::CudaBlas`] handle, performs cuBLAS L1/L2/
4//! L3 ops on its assigned stream, and returns completion via the
5//! configured [`CompletionStrategy`] (§3.2 stateless-handle archetype +
6//! §5.10 callback wiring).
7//!
8//! Sub-modules:
9//! - [`gemm`] — typed `Gemm<T>` for f32/f64/f16/bf16 (cudarc safe layer)
10//! and the legacy `SgemmRequest` adapter that routes through
11//! `Gemm<f32>` for back-compat.
12//! - [`gemm_strided_batched`] — strided-batched gemm via cudarc's safe
13//! layer for f32/f64/f16/bf16; can drop to
14//! [`crate::sys::cublas::gemm_strided_batched_ex`] if more dtypes are
15//! needed in a follow-up.
16//! - [`l1`] — axpy / dot / nrm2 / scal / asum / iamax / iamin / copy /
17//! swap / rot via the cuBLAS ex-suffix entry points.
18//! - [`l2`] — gemv / ger via cudarc's `Gemv<T>` and the local
19//! `cublasGemv_v2` / `cublasGer_v2` wrappers.
20//! - [`l3`] — geam / syrk / trsm via the local `cublasSgeam` /
21//! `cublasSsyrk_v2` / `cublasStrsm_v2` wrappers (and dgeam/dsyrk/
22//! dtrsm).
23//! - [`scaling`] — fp8 scaling-factor helpers (per-tensor / per-row),
24//! stubbed under the `cublas-fp8` feature for use by `cublasGemmEx`
25//! on Hopper+.
26//!
27//! The mailbox is freed immediately after the kernel is enqueued — the
28//! actor never blocks on the GPU (§5.2). Reply delivery happens on the
29//! Tokio task spawned by [`crate::kernel::envelope::run_kernel`].
30
31use std::sync::Arc;
32
33use atomr_core::actor::{Context, Props};
34use atomr_macros::Actor;
35use cudarc::cublas::CudaBlas;
36
37use crate::completion::{CompletionStrategy, HostFnCompletion};
38use crate::device::{DeviceState, SgemmRequest};
39use crate::error::GpuError;
40use crate::kernel::dispatch::{
41 BlasDispatchCtx, BlasL1Dispatch, BlasL2Dispatch, BlasL3Dispatch, GemmDispatch,
42 GemmStridedBatchedDispatch,
43};
44use crate::stream::{ActorHints, StreamAllocator};
45
46pub mod gemm;
47pub mod gemm_strided_batched;
48pub mod l1;
49pub mod l2;
50pub mod l3;
51pub mod scaling;
52
53pub use gemm::GemmRequest;
54pub use gemm_strided_batched::GemmStridedBatchedRequest;
55pub use l1::{
56 AsumRequest, AxpyRequest, CopyRequest, DotRequest, IamaxRequest, IaminRequest, Nrm2Request,
57 RotRequest, ScalRequest, SwapRequest,
58};
59pub use l2::{GemvRequest, GerRequest};
60pub use l3::{GeamRequest, SyrkRequest, TrsmRequest};
61
62/// Public messages for `BlasActor`. Each variant boxes a typed
63/// dispatcher trait object so the dtype dimension travels through the
64/// box without forcing an N-fold mailbox explosion.
65pub enum BlasMsg {
66 /// Generic typed gemm (canonical form). Construct via
67 /// [`BlasMsg::gemm::<T>`] or
68 /// [`gemm::GemmRequest::<T>::into_msg`].
69 Gemm(Box<dyn GemmDispatch>),
70 /// L1 ops boxed in [`BlasL1Dispatch`].
71 L1(Box<dyn BlasL1Dispatch>),
72 /// L2 ops boxed in [`BlasL2Dispatch`].
73 L2(Box<dyn BlasL2Dispatch>),
74 /// L3 ops other than gemm (geam / syrk / trsm).
75 L3(Box<dyn BlasL3Dispatch>),
76 /// Strided-batched gemm.
77 GemmStridedBatched(Box<dyn GemmStridedBatchedDispatch>),
78 /// Legacy alias kept for back-compat — routes through `Gemm<f32>`
79 /// internally.
80 #[deprecated(note = "use BlasMsg::gemm::<f32>(GemmRequest::<f32> { ... })")]
81 Sgemm(Box<crate::device::SgemmRequest>),
82}
83
84impl BlasMsg {
85 /// Construct a `BlasMsg::Gemm` from a typed [`GemmRequest<T>`].
86 /// Convenience wrapper so callers don't have to box manually.
87 pub fn gemm<T: crate::dtype::GemmSupported>(req: GemmRequest<T>) -> Self
88 where
89 GemmRequest<T>: GemmDispatch,
90 {
91 Self::Gemm(Box::new(req))
92 }
93
94 /// Construct a `BlasMsg::GemmStridedBatched` from a typed request.
95 pub fn gemm_strided_batched<T: crate::dtype::GemmSupported>(
96 req: GemmStridedBatchedRequest<T>,
97 ) -> Self
98 where
99 GemmStridedBatchedRequest<T>: GemmStridedBatchedDispatch,
100 {
101 Self::GemmStridedBatched(Box::new(req))
102 }
103}
104
105/// Two-track construction: a real cuBLAS-backed actor (`props`), and a
106/// mock variant used by `examples/echo_no_gpu` and unit tests where no
107/// GPU is present.
108#[derive(Actor)]
109#[msg(BlasMsg)]
110pub struct BlasActor {
111 inner: BlasInner,
112}
113
114pub(crate) enum BlasInner {
115 Real {
116 cublas: Arc<CudaBlas>,
117 stream: Arc<cudarc::driver::CudaStream>,
118 completion: Arc<dyn CompletionStrategy>,
119 state: Arc<DeviceState>,
120 },
121 Mock,
122}
123
124impl BlasActor {
125 /// Build a [`Props<BlasActor>`] from a stream+allocator+completion
126 /// triple. Panics from inside the factory closure with
127 /// `"ContextPoisoned: CudaBlas::new failed: …"` so the supervisor
128 /// can restart the actor on handle-creation failure.
129 pub fn props(
130 stream: Arc<cudarc::driver::CudaStream>,
131 allocator: Arc<dyn StreamAllocator>,
132 completion: Arc<dyn CompletionStrategy>,
133 state: Arc<DeviceState>,
134 ) -> Props<Self> {
135 let actor_stream = allocator.acquire(ActorHints::default());
136 debug_assert!(Arc::ptr_eq(&actor_stream, &stream));
137 Props::create(move || {
138 let cublas = match CudaBlas::new(stream.clone()) {
139 Ok(b) => b,
140 Err(e) => panic!("ContextPoisoned: CudaBlas::new failed: {e}"),
141 };
142 BlasActor {
143 inner: BlasInner::Real {
144 cublas: Arc::new(cublas),
145 stream: stream.clone(),
146 completion: completion.clone(),
147 state: state.clone(),
148 },
149 }
150 })
151 }
152
153 /// Back-compat shim for callers using the F1 constructor signature.
154 /// Wraps the legacy `(stream, PerActorAllocator, HostFnCompletion)`
155 /// into the F2 form. New code should call [`BlasActor::props`].
156 pub fn props_legacy(
157 stream: Arc<cudarc::driver::CudaStream>,
158 allocator: crate::stream::PerActorAllocator,
159 completion: HostFnCompletion,
160 state: Arc<DeviceState>,
161 ) -> Props<Self> {
162 let alloc: Arc<dyn StreamAllocator> = Arc::new(allocator);
163 let comp: Arc<dyn CompletionStrategy> = Arc::new(completion);
164 Self::props(stream, alloc, comp, state)
165 }
166
167 pub fn mock_props() -> Props<Self> {
168 Props::create(|| BlasActor {
169 inner: BlasInner::Mock,
170 })
171 }
172}
173
174impl BlasActor {
175 async fn handle_msg(&mut self, _ctx: &mut Context<Self>, msg: BlasMsg) {
176 // Decompose the message into a known shape, then run the
177 // dispatcher with a borrowed `BlasDispatchCtx` so each variant
178 // shares the same enqueue/completion plumbing.
179 match &self.inner {
180 BlasInner::Mock => match msg {
181 BlasMsg::Gemm(d) => mock_reply(d.op_name()),
182 BlasMsg::L1(d) => mock_reply(d.op_name()),
183 BlasMsg::L2(d) => mock_reply(d.op_name()),
184 BlasMsg::L3(d) => mock_reply(d.op_name()),
185 BlasMsg::GemmStridedBatched(d) => mock_reply(d.op_name()),
186 #[allow(deprecated)]
187 BlasMsg::Sgemm(req) => {
188 let _ = req.reply.send(Err(GpuError::Unrecoverable(
189 "Sgemm not supported in mock mode".into(),
190 )));
191 }
192 },
193 BlasInner::Real {
194 cublas,
195 stream,
196 completion,
197 state,
198 } => {
199 let ctx = BlasDispatchCtx {
200 cublas,
201 stream,
202 completion,
203 state,
204 };
205 match msg {
206 BlasMsg::Gemm(d) => d.dispatch(&ctx),
207 BlasMsg::L1(d) => d.dispatch(&ctx),
208 BlasMsg::L2(d) => d.dispatch(&ctx),
209 BlasMsg::L3(d) => d.dispatch(&ctx),
210 BlasMsg::GemmStridedBatched(d) => d.dispatch(&ctx),
211 #[allow(deprecated)]
212 BlasMsg::Sgemm(req) => {
213 // Route through Gemm<f32> internally so all
214 // back-compat callers benefit from the same
215 // dispatch path.
216 let SgemmRequest {
217 a,
218 b,
219 c,
220 m,
221 n,
222 k,
223 alpha,
224 beta,
225 reply,
226 } = *req;
227 let typed = GemmRequest::<f32> {
228 a,
229 b,
230 c,
231 m,
232 n,
233 k,
234 alpha,
235 beta,
236 trans_a: cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_N,
237 trans_b: cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_N,
238 lda: m,
239 ldb: k,
240 ldc: m,
241 reply,
242 };
243 let boxed: Box<dyn GemmDispatch> = Box::new(typed);
244 boxed.dispatch(&ctx);
245 }
246 }
247 }
248 }
249 }
250}
251
252/// Mock-mode reply helper. We don't have access to the request's
253/// `oneshot::Sender` (it lives inside the boxed dispatcher), so the
254/// only thing we can do is drop the dispatcher — the receiver
255/// observes `Err(RecvError)` which surfaces as a typed error at the
256/// caller. Tracing logs the dropped op name so tests can spot it.
257fn mock_reply(op: &'static str) {
258 tracing::debug!(op, "BlasActor (mock): dropping op without reply");
259}