atomr_accel_cuda/kernel/blas/
gemm.rs1use std::sync::Arc;
11
12use cudarc::cublas::sys::cublasOperation_t;
13use cudarc::cublas::{Gemm, GemmConfig};
14use tokio::sync::oneshot;
15
16use crate::dtype::GemmSupported;
17use crate::error::GpuError;
18use crate::gpu_ref::GpuRef;
19use crate::kernel::dispatch::{BlasDispatchCtx, GemmDispatch};
20use crate::kernel::envelope;
21
22const LIB: &str = "cublas";
23
24pub struct GemmRequest<T: GemmSupported> {
56 pub a: GpuRef<T>,
57 pub b: GpuRef<T>,
58 pub c: GpuRef<T>,
59 pub m: i32,
60 pub n: i32,
61 pub k: i32,
62 pub alpha: T,
63 pub beta: T,
64 pub trans_a: cublasOperation_t,
65 pub trans_b: cublasOperation_t,
66 pub lda: i32,
67 pub ldb: i32,
68 pub ldc: i32,
69 pub reply: oneshot::Sender<Result<(), GpuError>>,
70}
71
72impl<T> GemmRequest<T>
73where
74 T: GemmSupported,
75 GemmRequest<T>: GemmDispatch,
76{
77 pub fn into_msg(self) -> crate::kernel::BlasMsg {
79 crate::kernel::BlasMsg::Gemm(Box::new(self))
80 }
81}
82
83fn dispatch_gemm<T>(req: GemmRequest<T>, ctx: &BlasDispatchCtx<'_>)
90where
91 T: GemmSupported + Copy,
92 cudarc::cublas::CudaBlas: Gemm<T>,
93{
94 let GemmRequest {
95 a,
96 b,
97 c,
98 m,
99 n,
100 k,
101 alpha,
102 beta,
103 trans_a,
104 trans_b,
105 lda,
106 ldb,
107 ldc,
108 reply,
109 } = req;
110
111 let (a_slice, b_slice, c_slice) = match envelope::access_all_3(&a, &b, &c) {
112 Ok(t) => t,
113 Err(e) => {
114 let _ = reply.send(Err(e));
115 return;
116 }
117 };
118
119 let cfg = GemmConfig::<T> {
120 transa: trans_a,
121 transb: trans_b,
122 m,
123 n,
124 k,
125 alpha,
126 lda,
127 ldb,
128 beta,
129 ldc,
130 };
131
132 let mut c_owned = match Arc::try_unwrap(c_slice) {
137 Ok(s) => s,
138 Err(_arc) => {
139 let _ = reply.send(Err(GpuError::Unrecoverable(
140 "GEMM target buffer C has more than one live reference; \
141 caller must hold the unique GpuRef to write to it"
142 .into(),
143 )));
144 return;
145 }
146 };
147
148 c.record_write(ctx.stream);
149
150 let cublas = ctx.cublas.clone();
151 envelope::run_kernel(LIB, ctx.stream, ctx.completion, (), reply, move || {
152 let res = unsafe { cublas.gemm(cfg, &*a_slice, &*b_slice, &mut c_owned) };
156 match res {
157 Ok(()) => Ok((cublas, a_slice, b_slice, c_owned)),
158 Err(e) => Err(GpuError::LibraryError {
159 lib: LIB,
160 msg: format!("gemm enqueue: {e}"),
161 }),
162 }
163 });
164}
165
166impl GemmDispatch for GemmRequest<f32> {
169 fn dtype_name(&self) -> &'static str {
170 <f32 as atomr_accel::AccelDtype>::NAME
171 }
172 fn op_name(&self) -> &'static str {
173 "gemm"
174 }
175 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
176 dispatch_gemm::<f32>(*self, ctx);
177 }
178}
179
180impl GemmDispatch for GemmRequest<f64> {
181 fn dtype_name(&self) -> &'static str {
182 <f64 as atomr_accel::AccelDtype>::NAME
183 }
184 fn op_name(&self) -> &'static str {
185 "gemm"
186 }
187 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
188 dispatch_gemm::<f64>(*self, ctx);
189 }
190}
191
/// `half::f16` GEMM requests; compiled only when the `f16` feature is enabled.
#[cfg(feature = "f16")]
impl GemmDispatch for GemmRequest<half::f16> {
    /// Element type name, taken from `AccelDtype::NAME` for `half::f16`.
    fn dtype_name(&self) -> &'static str {
        <half::f16 as atomr_accel::AccelDtype>::NAME
    }

    /// Operation label; always `"gemm"`.
    fn op_name(&self) -> &'static str {
        "gemm"
    }

    /// Unboxes the request and forwards it to the generic `dispatch_gemm`.
    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
        let request: GemmRequest<half::f16> = *self;
        dispatch_gemm(request, ctx);
    }
}
204
/// `half::bf16` GEMM requests; compiled only when the `f16` feature is enabled.
#[cfg(feature = "f16")]
impl GemmDispatch for GemmRequest<half::bf16> {
    /// Element type name, taken from `AccelDtype::NAME` for `half::bf16`.
    fn dtype_name(&self) -> &'static str {
        <half::bf16 as atomr_accel::AccelDtype>::NAME
    }

    /// Operation label; always `"gemm"`.
    fn op_name(&self) -> &'static str {
        "gemm"
    }

    /// Unboxes the request and forwards it to the generic `dispatch_gemm`.
    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
        let request: GemmRequest<half::bf16> = *self;
        dispatch_gemm(request, ctx);
    }
}
217
/// Test-only helpers shared by the unit tests in this module.
#[cfg(test)]
pub(crate) mod tests_helpers {
    /// Returns a placeholder `GpuRef<T>` built by
    /// `GpuRef::for_test_no_gpu_leaked`.
    ///
    /// NOTE(review): judging by the constructor name, the value needs no GPU
    /// and is meant to be leaked rather than dropped — confirm against
    /// `gpu_ref.rs` before dropping one in a test.
    pub fn gpu_ref_stub<T>() -> crate::gpu_ref::GpuRef<T> {
        crate::gpu_ref::GpuRef::for_test_no_gpu_leaked()
    }
}
236
#[cfg(test)]
mod tests {
    use super::tests_helpers::gpu_ref_stub;
    use super::*;
    use tokio::sync::oneshot;

    use self::num_one_zero::NumOneZero;

    /// Compile-time probe: instantiating this requires `T: Send`.
    fn require_send<T: Send>() {}

    /// Builds a stub request for `T` and erases it to the dispatch trait
    /// object. Callers must not let the box drop (see `gpu_ref_stub`).
    fn boxed_stub<T>() -> Box<dyn GemmDispatch>
    where
        T: GemmSupported + NumOneZero + 'static,
        GemmRequest<T>: GemmDispatch,
    {
        Box::new(stub_request::<T>())
    }

    /// Every supported element type yields a `Send` request that can be
    /// type-erased and reports the expected operation / dtype names.
    #[test]
    fn gemm_request_dispatches_for_f32_f64_f16_bf16() {
        require_send::<GemmRequest<f32>>();
        require_send::<GemmRequest<f64>>();
        #[cfg(feature = "f16")]
        {
            require_send::<GemmRequest<half::f16>>();
            require_send::<GemmRequest<half::bf16>>();
        }

        // Each erased request is forgotten instead of dropped so the stub
        // `GpuRef`s inside it never run their destructors.
        let single = boxed_stub::<f32>();
        assert_eq!(single.op_name(), "gemm");
        assert_eq!(single.dtype_name(), "f32");
        std::mem::forget(single);

        let double = boxed_stub::<f64>();
        assert_eq!(double.dtype_name(), "f64");
        std::mem::forget(double);

        #[cfg(feature = "f16")]
        {
            let half_prec = boxed_stub::<half::f16>();
            assert_eq!(half_prec.dtype_name(), "f16");
            std::mem::forget(half_prec);

            let brain = boxed_stub::<half::bf16>();
            assert_eq!(brain.dtype_name(), "bf16");
            std::mem::forget(brain);
        }
    }

    /// The legacy `SgemmRequest` / `BlasMsg::Sgemm` pair must keep compiling.
    // The legacy type is referenced on purpose, so the deprecation lint is
    // silenced for this test only.
    #[test]
    #[allow(deprecated)]
    fn deprecated_sgemm_alias_still_constructs() {
        let (reply, _receiver) = oneshot::channel();
        let legacy = crate::device::SgemmRequest {
            a: gpu_ref_stub::<f32>(),
            b: gpu_ref_stub::<f32>(),
            c: gpu_ref_stub::<f32>(),
            m: 1,
            n: 1,
            k: 1,
            alpha: 1.0,
            beta: 0.0,
            reply,
        };
        let msg = crate::kernel::BlasMsg::Sgemm(Box::new(legacy));
        // Never dropped, for the same reason as the boxed requests above.
        std::mem::forget(msg);
    }

    /// Builds a 1x1x1, non-transposed request over stub buffers with
    /// `alpha = 1` and `beta = 0`. The receiving half of the reply channel is
    /// dropped when this function returns.
    fn stub_request<T>() -> GemmRequest<T>
    where
        T: GemmSupported + num_one_zero::NumOneZero,
        GemmRequest<T>: GemmDispatch,
    {
        let (reply, _receiver) = oneshot::channel();
        GemmRequest {
            a: gpu_ref_stub(),
            b: gpu_ref_stub(),
            c: gpu_ref_stub(),
            m: 1,
            n: 1,
            k: 1,
            alpha: <T as NumOneZero>::one(),
            beta: <T as NumOneZero>::zero(),
            trans_a: cublasOperation_t::CUBLAS_OP_N,
            trans_b: cublasOperation_t::CUBLAS_OP_N,
            lda: 1,
            ldb: 1,
            ldc: 1,
            reply,
        }
    }

    mod num_one_zero {
        /// Supplies the `1` and `0` values of an element type so that
        /// `stub_request` can stay generic.
        pub trait NumOneZero: Copy {
            fn one() -> Self;
            fn zero() -> Self;
        }

        /// Generates a `NumOneZero` impl from the two given expressions.
        macro_rules! identities {
            ($elem:ty => $one:expr, $zero:expr) => {
                impl NumOneZero for $elem {
                    fn one() -> Self {
                        $one
                    }
                    fn zero() -> Self {
                        $zero
                    }
                }
            };
        }

        identities!(f32 => 1.0, 0.0);
        identities!(f64 => 1.0, 0.0);
        #[cfg(feature = "f16")]
        identities!(half::f16 => half::f16::ONE, half::f16::ZERO);
        #[cfg(feature = "f16")]
        identities!(half::bf16 => half::bf16::ONE, half::bf16::ZERO);
    }
}