atomr_accel_cuda/kernel/blas_lt/
matmul.rs1use std::sync::Arc;
21
22use cudarc::cublaslt::{Activation, Matmul, MatmulConfig};
23use tokio::sync::oneshot;
24
25use crate::dtype::{DTypeKind, GemmSupported};
26use crate::error::GpuError;
27use crate::gpu_ref::GpuRef;
28use crate::kernel::blas_lt::epilogue::Epilogue;
29use crate::kernel::blas_lt::scaling::ScaleSet;
30use crate::kernel::dispatch::{BlasLtDispatch, BlasLtDispatchCtx};
31use crate::kernel::envelope;
32
/// Library tag passed to `envelope::run_kernel` and reported in `GpuError::LibraryError`.
const LIB: &str = "cublaslt";
34
/// One cuBLASLt matmul request, consumed by the `BlasLtDispatch` impls in this module.
///
/// The outcome is delivered on `reply`. The shape fields are forwarded unchanged into
/// cudarc's `MatmulConfig` by `dispatch_safe_path`, so they follow cuBLAS conventions:
/// `op(A)` is `m x k`, `op(B)` is `k x n`, and C is `m x n`.
pub struct MatmulRequest<T: GemmSupported> {
    /// Operand A (used transposed when `transa` is set).
    pub a: GpuRef<T>,
    /// Operand B (used transposed when `transb` is set).
    pub b: GpuRef<T>,
    /// Input C; `dispatch_safe_path` also writes the result into this buffer in place.
    pub c: GpuRef<T>,
    /// Optional separate output D.
    /// NOTE(review): ignored by `dispatch_safe_path`, which always writes into `c`.
    pub d: Option<GpuRef<T>>,
    /// Rows of `op(A)` and of C.
    pub m: i32,
    /// Columns of `op(B)` and of C.
    pub n: i32,
    /// Inner dimension shared by `op(A)` and `op(B)`.
    pub k: i32,
    /// Scale applied to the `op(A) * op(B)` product; converted to `f32` on the safe path.
    pub alpha: T::Scalar,
    /// Scale applied to the existing contents of C; converted to `f32` on the safe path.
    pub beta: T::Scalar,
    /// Whether A is transposed (forwarded to `MatmulConfig::transa`).
    pub transa: bool,
    /// Whether B is transposed (forwarded to `MatmulConfig::transb`).
    pub transb: bool,
    /// Leading dimension of the stored A (forwarded to `MatmulConfig::lda`).
    pub lda: i64,
    /// Leading dimension of the stored B (forwarded to `MatmulConfig::ldb`).
    pub ldb: i64,
    /// Leading dimension of C (forwarded to `MatmulConfig::ldc`).
    pub ldc: i64,
    /// Leading dimension of D.
    /// NOTE(review): ignored by `dispatch_safe_path`.
    pub ldd: i64,
    /// Fused epilogue. On the safe path it selects the ReLU/GELU activation and is part of
    /// the heuristic lookup key.
    pub epilogue: Epilogue,
    /// Optional bias buffer, passed to cuBLASLt when present.
    pub bias: Option<GpuRef<T>>,
    /// Auxiliary buffer for the `*Aux` epilogues.
    /// NOTE(review): ignored by `dispatch_safe_path`; no aux output is written.
    pub gelu_aux: Option<GpuRef<T>>,
    /// Scale factors.
    /// NOTE(review): ignored by `dispatch_safe_path`; presumably meant for the FP8 dtypes —
    /// confirm.
    pub scales: ScaleSet,
    /// Requested workspace size (presumably bytes — confirm).
    /// NOTE(review): ignored by `dispatch_safe_path`.
    pub workspace_size: usize,
    /// Receives `Ok(())` on success or the `GpuError` that stopped the request.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
}
66
67impl<T: GemmSupported> std::fmt::Debug for MatmulRequest<T> {
68 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 f.debug_struct("MatmulRequest")
70 .field("dtype", &T::NAME)
71 .field("m", &self.m)
72 .field("n", &self.n)
73 .field("k", &self.k)
74 .field("transa", &self.transa)
75 .field("transb", &self.transb)
76 .field("epilogue", &self.epilogue)
77 .field("workspace_size", &self.workspace_size)
78 .finish()
79 }
80}
81
/// Dtypes whose matmul runs through cudarc's safe cuBLASLt wrapper (`dispatch_safe_path`).
trait CudarcMatmulPath: GemmSupported {
    /// Runs `req` on the cudarc-backed path; the outcome is reported on `req.reply`.
    fn dispatch_safe(req: Box<MatmulRequest<Self>>, ctx: &BlasLtDispatchCtx<'_>);
}
87
88impl CudarcMatmulPath for f32 {
89 fn dispatch_safe(req: Box<MatmulRequest<f32>>, ctx: &BlasLtDispatchCtx<'_>) {
90 dispatch_safe_path::<f32>(req, ctx);
91 }
92}
93
#[cfg(feature = "f16")]
impl CudarcMatmulPath for half::f16 {
    /// Instantiates the shared cudarc-backed path for `half::f16`.
    fn dispatch_safe(req: Box<MatmulRequest<Self>>, ctx: &BlasLtDispatchCtx<'_>) {
        dispatch_safe_path(req, ctx)
    }
}
100
#[cfg(feature = "f16")]
impl CudarcMatmulPath for half::bf16 {
    /// Instantiates the shared cudarc-backed path for `half::bf16`.
    fn dispatch_safe(req: Box<MatmulRequest<Self>>, ctx: &BlasLtDispatchCtx<'_>) {
        dispatch_safe_path(req, ctx)
    }
}
107
/// Fallback for dtypes whose matmul is not wired up yet: the request is answered with an
/// error instead of launching anything.
trait UnsupportedMatmulPath {
    /// Sends a `GpuError::Unrecoverable` naming `dtype` on `reply`.
    fn dispatch_unsupported(reply: oneshot::Sender<Result<(), GpuError>>, dtype: &'static str);
}
114
115impl<T> UnsupportedMatmulPath for T {
116 fn dispatch_unsupported(reply: oneshot::Sender<Result<(), GpuError>>, dtype: &'static str) {
117 let _ = reply.send(Err(GpuError::Unrecoverable(format!(
118 "BlasLtActor: matmul<{dtype}> not yet implemented (Phase 1 sys-level wiring pending)"
119 ))));
120 }
121}
122
123impl BlasLtDispatch for MatmulRequest<f64> {
125 fn dtype_kind(&self) -> DTypeKind {
126 DTypeKind::F64
127 }
128 fn dispatch(self: Box<Self>, _ctx: &BlasLtDispatchCtx<'_>) {
129 <f64 as UnsupportedMatmulPath>::dispatch_unsupported(self.reply, "f64");
130 }
131}
132
133impl BlasLtDispatch for MatmulRequest<f32> {
134 fn dtype_kind(&self) -> DTypeKind {
135 DTypeKind::F32
136 }
137 fn dispatch(self: Box<Self>, ctx: &BlasLtDispatchCtx<'_>) {
138 <f32 as CudarcMatmulPath>::dispatch_safe(self, ctx);
139 }
140}
141
#[cfg(feature = "f16")]
impl BlasLtDispatch for MatmulRequest<half::f16> {
    /// Routes the request to the cudarc-backed safe path.
    fn dispatch(self: Box<Self>, ctx: &BlasLtDispatchCtx<'_>) {
        CudarcMatmulPath::dispatch_safe(self, ctx)
    }

    fn dtype_kind(&self) -> DTypeKind {
        DTypeKind::F16
    }
}
151
#[cfg(feature = "f16")]
impl BlasLtDispatch for MatmulRequest<half::bf16> {
    /// Routes the request to the cudarc-backed safe path.
    fn dispatch(self: Box<Self>, ctx: &BlasLtDispatchCtx<'_>) {
        CudarcMatmulPath::dispatch_safe(self, ctx)
    }

    fn dtype_kind(&self) -> DTypeKind {
        DTypeKind::Bf16
    }
}
161
#[cfg(feature = "cublas-fp8")]
impl BlasLtDispatch for MatmulRequest<crate::dtype::F8E4m3> {
    /// This FP8 variant has no cuBLASLt path yet; the request is answered with an error.
    fn dispatch(self: Box<Self>, _ctx: &BlasLtDispatchCtx<'_>) {
        let MatmulRequest { reply, .. } = *self;
        <crate::dtype::F8E4m3 as UnsupportedMatmulPath>::dispatch_unsupported(reply, "fp8e4m3");
    }

    fn dtype_kind(&self) -> DTypeKind {
        DTypeKind::F8E4m3
    }
}
173
#[cfg(feature = "cublas-fp8")]
impl BlasLtDispatch for MatmulRequest<crate::dtype::F8E5m2> {
    /// This FP8 variant has no cuBLASLt path yet; the request is answered with an error.
    fn dispatch(self: Box<Self>, _ctx: &BlasLtDispatchCtx<'_>) {
        let MatmulRequest { reply, .. } = *self;
        <crate::dtype::F8E5m2 as UnsupportedMatmulPath>::dispatch_unsupported(reply, "fp8e5m2");
    }

    fn dtype_kind(&self) -> DTypeKind {
        DTypeKind::F8E5m2
    }
}
185
186fn dispatch_safe_path<T>(req: Box<MatmulRequest<T>>, ctx: &BlasLtDispatchCtx<'_>)
193where
194 T: GemmSupported + cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits,
195 cudarc::cublaslt::CudaBlasLT: Matmul<T>,
196 T::Scalar: Into<f32> + Copy,
197{
198 let MatmulRequest {
199 a,
200 b,
201 c,
202 d: _d,
203 m,
204 n,
205 k,
206 alpha,
207 beta,
208 transa,
209 transb,
210 lda,
211 ldb,
212 ldc,
213 ldd: _ldd,
214 epilogue,
215 bias,
216 gelu_aux: _gelu_aux,
217 scales: _scales,
218 workspace_size: _workspace_size,
219 reply,
220 } = *req;
221
222 let _entry = ctx
226 .heuristic
227 .get(&crate::kernel::blas_lt::heuristic::HeuristicKey::new(
228 m,
229 n,
230 k,
231 T::KIND,
232 transa,
233 transb,
234 epilogue,
235 ctx.sm_arch,
236 ));
237
238 let activation = match epilogue {
243 Epilogue::Relu | Epilogue::ReluBias | Epilogue::ReluAux | Epilogue::ReluAuxBias => {
244 Some(Activation::Relu)
245 }
246 Epilogue::Gelu | Epilogue::GeluBias | Epilogue::GeluAux | Epilogue::GeluAuxBias => {
247 Some(Activation::Gelu)
248 }
249 _ => None,
250 };
251
252 let cfg = MatmulConfig {
253 transa,
254 transb,
255 transc: false,
256 m: m as u64,
257 n: n as u64,
258 k: k as u64,
259 alpha: alpha.into(),
260 lda,
261 ldb,
262 beta: beta.into(),
263 ldc,
264 stride_a: None,
265 stride_b: None,
266 stride_c: None,
267 stride_bias: None,
268 batch_size: None,
269 };
270
271 let (a_slice, b_slice, c_slice) = match envelope::access_all_3(&a, &b, &c) {
272 Ok(t) => t,
273 Err(e) => {
274 let _ = reply.send(Err(e));
275 return;
276 }
277 };
278 let bias_slice = match bias.as_ref() {
279 None => None,
280 Some(g) => match g.access() {
281 Ok(s) => Some(s.clone()),
282 Err(e) => {
283 let _ = reply.send(Err(e));
284 return;
285 }
286 },
287 };
288 let mut c_owned = match Arc::try_unwrap(c_slice) {
289 Ok(s) => s,
290 Err(_) => {
291 let _ = reply.send(Err(GpuError::Unrecoverable(
292 "BlasLt C has multiple live references".into(),
293 )));
294 return;
295 }
296 };
297 c.record_write(ctx.stream);
298
299 let blas_lt = ctx.blas_lt.clone();
300 let stream = ctx.stream;
301 let completion = ctx.completion;
302
303 envelope::run_kernel(LIB, stream, completion, (), reply, move || {
304 let bias_ref = bias_slice.as_ref().map(|s| &**s);
305 let act_ref = activation.as_ref();
306 let res =
308 unsafe { blas_lt.matmul(cfg, &*a_slice, &*b_slice, &mut c_owned, bias_ref, act_ref) };
309 match res {
310 Ok(()) => Ok((a_slice, b_slice, c_owned, bias_slice, blas_lt)),
311 Err(e) => Err(GpuError::LibraryError {
312 lib: LIB,
313 msg: format!("matmul: {e}"),
314 }),
315 }
316 });
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Names `MatmulRequest<T>` for a dtype without building one.
    ///
    /// Only ever coerced to a function pointer; calling it panics.
    fn make_request<T: GemmSupported>() -> MatmulRequest<T>
    where
        T::Scalar: Default,
    {
        let (reply_tx, _reply_rx) = oneshot::channel::<Result<(), GpuError>>();
        let _ = (T::NAME, reply_tx);
        unreachable!("type-instantiation-only helper")
    }

    #[test]
    fn matmul_request_dispatches_for_f32_f16_bf16() {
        // Compile-time: `BlasLtDispatch` must stay usable as a boxed trait object.
        fn _accepts_f32(boxed: Box<dyn BlasLtDispatch>) -> Box<dyn BlasLtDispatch> {
            boxed
        }

        // Compile-time: the always-on dtypes expose `dtype_kind` through the trait.
        let _f32_kind: fn(&MatmulRequest<f32>) -> DTypeKind = MatmulRequest::<f32>::dtype_kind;
        let _f64_kind: fn(&MatmulRequest<f64>) -> DTypeKind = MatmulRequest::<f64>::dtype_kind;
        let _make_f32 = make_request::<f32> as fn() -> MatmulRequest<f32>;

        assert_eq!(<f32 as atomr_accel::AccelDtype>::KIND, DTypeKind::F32);
        assert_eq!(<f64 as atomr_accel::AccelDtype>::KIND, DTypeKind::F64);

        #[cfg(feature = "f16")]
        {
            let _f16_kind: fn(&MatmulRequest<half::f16>) -> DTypeKind =
                MatmulRequest::<half::f16>::dtype_kind;
            let _bf16_kind: fn(&MatmulRequest<half::bf16>) -> DTypeKind =
                MatmulRequest::<half::bf16>::dtype_kind;
            assert_eq!(<half::f16 as atomr_accel::AccelDtype>::KIND, DTypeKind::F16);
            assert_eq!(
                <half::bf16 as atomr_accel::AccelDtype>::KIND,
                DTypeKind::Bf16
            );
        }
    }
}