1use std::sync::Arc;
9
10use cudarc::cublas::sys::{
11 cublasDiagType_t, cublasFillMode_t, cublasOperation_t, cublasSideMode_t,
12};
13use cudarc::driver::{sys::CUdeviceptr, DevicePtr, DevicePtrMut};
14use tokio::sync::oneshot;
15
16use crate::dtype::{GeamSupported, SyrkSupported, TrsmSupported};
17use crate::error::GpuError;
18use crate::gpu_ref::GpuRef;
19use crate::kernel::dispatch::{BlasDispatchCtx, BlasL3Dispatch};
20use crate::kernel::envelope;
21use crate::sys::cublas as syscublas;
22
/// Library label handed to `envelope::run_kernel` by every dispatcher in this file.
const LIB: &str = "cublas";
24
25pub struct GeamRequest<T: GeamSupported> {
29 pub trans_a: cublasOperation_t,
30 pub trans_b: cublasOperation_t,
31 pub m: i32,
32 pub n: i32,
33 pub alpha: T,
34 pub a: GpuRef<T>,
35 pub lda: i32,
36 pub beta: T,
37 pub b: GpuRef<T>,
38 pub ldb: i32,
39 pub c: GpuRef<T>,
40 pub ldc: i32,
41 pub reply: oneshot::Sender<Result<(), GpuError>>,
42}
43
44trait GeamCall: GeamSupported {
45 unsafe fn call(
48 handle: cudarc::cublas::sys::cublasHandle_t,
49 transa: cublasOperation_t,
50 transb: cublasOperation_t,
51 m: i32,
52 n: i32,
53 alpha: *const Self,
54 a: CUdeviceptr,
55 lda: i32,
56 beta: *const Self,
57 b: CUdeviceptr,
58 ldb: i32,
59 c: CUdeviceptr,
60 ldc: i32,
61 ) -> Result<(), GpuError>;
62}
63
64impl GeamCall for f32 {
65 unsafe fn call(
66 handle: cudarc::cublas::sys::cublasHandle_t,
67 transa: cublasOperation_t,
68 transb: cublasOperation_t,
69 m: i32,
70 n: i32,
71 alpha: *const Self,
72 a: CUdeviceptr,
73 lda: i32,
74 beta: *const Self,
75 b: CUdeviceptr,
76 ldb: i32,
77 c: CUdeviceptr,
78 ldc: i32,
79 ) -> Result<(), GpuError> {
80 syscublas::sgeam(
81 handle, transa, transb, m, n, alpha, a, lda, beta, b, ldb, c, ldc,
82 )
83 }
84}
85
86impl GeamCall for f64 {
87 unsafe fn call(
88 handle: cudarc::cublas::sys::cublasHandle_t,
89 transa: cublasOperation_t,
90 transb: cublasOperation_t,
91 m: i32,
92 n: i32,
93 alpha: *const Self,
94 a: CUdeviceptr,
95 lda: i32,
96 beta: *const Self,
97 b: CUdeviceptr,
98 ldb: i32,
99 c: CUdeviceptr,
100 ldc: i32,
101 ) -> Result<(), GpuError> {
102 syscublas::dgeam(
103 handle, transa, transb, m, n, alpha, a, lda, beta, b, ldb, c, ldc,
104 )
105 }
106}
107
108fn dispatch_geam<T>(req: GeamRequest<T>, ctx: &BlasDispatchCtx<'_>)
109where
110 T: GeamSupported + GeamCall + Copy,
111{
112 let GeamRequest {
113 trans_a,
114 trans_b,
115 m,
116 n,
117 alpha,
118 a,
119 lda,
120 beta,
121 b,
122 ldb,
123 c,
124 ldc,
125 reply,
126 } = req;
127 let (a_slice, b_slice, c_slice) = match envelope::access_all_3(&a, &b, &c) {
128 Ok(t) => t,
129 Err(e) => {
130 let _ = reply.send(Err(e));
131 return;
132 }
133 };
134 let mut c_owned = match Arc::try_unwrap(c_slice) {
135 Ok(s) => s,
136 Err(_) => {
137 let _ = reply.send(Err(GpuError::Unrecoverable(
138 "GEAM target buffer C has more than one live reference".into(),
139 )));
140 return;
141 }
142 };
143 c.record_write(ctx.stream);
144 let cublas = ctx.cublas.clone();
145 let stream = ctx.stream.clone();
146 envelope::run_kernel(LIB, ctx.stream, ctx.completion, (), reply, move || {
147 let res = {
148 let (a_ptr, _a_rec) = (*a_slice).device_ptr(&stream);
149 let (b_ptr, _b_rec) = (*b_slice).device_ptr(&stream);
150 let (c_ptr, _c_rec) = c_owned.device_ptr_mut(&stream);
151 unsafe {
152 T::call(
153 *cublas.handle(),
154 trans_a,
155 trans_b,
156 m,
157 n,
158 (&alpha) as *const T,
159 a_ptr,
160 lda,
161 (&beta) as *const T,
162 b_ptr,
163 ldb,
164 c_ptr,
165 ldc,
166 )
167 }
168 };
169 match res {
170 Ok(()) => Ok((cublas, a_slice, b_slice, c_owned)),
171 Err(e) => Err(e),
172 }
173 });
174}
175
176impl BlasL3Dispatch for GeamRequest<f32> {
177 fn dtype_name(&self) -> &'static str {
178 <f32 as atomr_accel::AccelDtype>::NAME
179 }
180 fn op_name(&self) -> &'static str {
181 "geam"
182 }
183 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
184 dispatch_geam::<f32>(*self, ctx);
185 }
186}
187
188impl BlasL3Dispatch for GeamRequest<f64> {
189 fn dtype_name(&self) -> &'static str {
190 <f64 as atomr_accel::AccelDtype>::NAME
191 }
192 fn op_name(&self) -> &'static str {
193 "geam"
194 }
195 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
196 dispatch_geam::<f64>(*self, ctx);
197 }
198}
199
200pub struct SyrkRequest<T: SyrkSupported> {
205 pub uplo: cublasFillMode_t,
206 pub trans: cublasOperation_t,
207 pub n: i32,
208 pub k: i32,
209 pub alpha: T,
210 pub a: GpuRef<T>,
211 pub lda: i32,
212 pub beta: T,
213 pub c: GpuRef<T>,
214 pub ldc: i32,
215 pub reply: oneshot::Sender<Result<(), GpuError>>,
216}
217
218trait SyrkCall: SyrkSupported {
219 unsafe fn call(
222 handle: cudarc::cublas::sys::cublasHandle_t,
223 uplo: cublasFillMode_t,
224 trans: cublasOperation_t,
225 n: i32,
226 k: i32,
227 alpha: *const Self,
228 a: CUdeviceptr,
229 lda: i32,
230 beta: *const Self,
231 c: CUdeviceptr,
232 ldc: i32,
233 ) -> Result<(), GpuError>;
234}
235
236impl SyrkCall for f32 {
237 unsafe fn call(
238 handle: cudarc::cublas::sys::cublasHandle_t,
239 uplo: cublasFillMode_t,
240 trans: cublasOperation_t,
241 n: i32,
242 k: i32,
243 alpha: *const Self,
244 a: CUdeviceptr,
245 lda: i32,
246 beta: *const Self,
247 c: CUdeviceptr,
248 ldc: i32,
249 ) -> Result<(), GpuError> {
250 syscublas::ssyrk(handle, uplo, trans, n, k, alpha, a, lda, beta, c, ldc)
251 }
252}
253
254impl SyrkCall for f64 {
255 unsafe fn call(
256 handle: cudarc::cublas::sys::cublasHandle_t,
257 uplo: cublasFillMode_t,
258 trans: cublasOperation_t,
259 n: i32,
260 k: i32,
261 alpha: *const Self,
262 a: CUdeviceptr,
263 lda: i32,
264 beta: *const Self,
265 c: CUdeviceptr,
266 ldc: i32,
267 ) -> Result<(), GpuError> {
268 syscublas::dsyrk(handle, uplo, trans, n, k, alpha, a, lda, beta, c, ldc)
269 }
270}
271
272fn dispatch_syrk<T>(req: SyrkRequest<T>, ctx: &BlasDispatchCtx<'_>)
273where
274 T: SyrkSupported + SyrkCall + Copy,
275{
276 let SyrkRequest {
277 uplo,
278 trans,
279 n,
280 k,
281 alpha,
282 a,
283 lda,
284 beta,
285 c,
286 ldc,
287 reply,
288 } = req;
289 let (a_slice, c_slice) = match envelope::access_all_2(&a, &c) {
290 Ok(t) => t,
291 Err(e) => {
292 let _ = reply.send(Err(e));
293 return;
294 }
295 };
296 let mut c_owned = match Arc::try_unwrap(c_slice) {
297 Ok(s) => s,
298 Err(_) => {
299 let _ = reply.send(Err(GpuError::Unrecoverable(
300 "SYRK target buffer C has more than one live reference".into(),
301 )));
302 return;
303 }
304 };
305 c.record_write(ctx.stream);
306 let cublas = ctx.cublas.clone();
307 let stream = ctx.stream.clone();
308 envelope::run_kernel(LIB, ctx.stream, ctx.completion, (), reply, move || {
309 let res = {
310 let (a_ptr, _a_rec) = (*a_slice).device_ptr(&stream);
311 let (c_ptr, _c_rec) = c_owned.device_ptr_mut(&stream);
312 unsafe {
313 T::call(
314 *cublas.handle(),
315 uplo,
316 trans,
317 n,
318 k,
319 (&alpha) as *const T,
320 a_ptr,
321 lda,
322 (&beta) as *const T,
323 c_ptr,
324 ldc,
325 )
326 }
327 };
328 match res {
329 Ok(()) => Ok((cublas, a_slice, c_owned)),
330 Err(e) => Err(e),
331 }
332 });
333}
334
335impl BlasL3Dispatch for SyrkRequest<f32> {
336 fn dtype_name(&self) -> &'static str {
337 <f32 as atomr_accel::AccelDtype>::NAME
338 }
339 fn op_name(&self) -> &'static str {
340 "syrk"
341 }
342 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
343 dispatch_syrk::<f32>(*self, ctx);
344 }
345}
346
347impl BlasL3Dispatch for SyrkRequest<f64> {
348 fn dtype_name(&self) -> &'static str {
349 <f64 as atomr_accel::AccelDtype>::NAME
350 }
351 fn op_name(&self) -> &'static str {
352 "syrk"
353 }
354 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
355 dispatch_syrk::<f64>(*self, ctx);
356 }
357}
358
359pub struct TrsmRequest<T: TrsmSupported> {
364 pub side: cublasSideMode_t,
365 pub uplo: cublasFillMode_t,
366 pub trans: cublasOperation_t,
367 pub diag: cublasDiagType_t,
368 pub m: i32,
369 pub n: i32,
370 pub alpha: T,
371 pub a: GpuRef<T>,
372 pub lda: i32,
373 pub b: GpuRef<T>,
374 pub ldb: i32,
375 pub reply: oneshot::Sender<Result<(), GpuError>>,
376}
377
378trait TrsmCall: TrsmSupported {
379 unsafe fn call(
382 handle: cudarc::cublas::sys::cublasHandle_t,
383 side: cublasSideMode_t,
384 uplo: cublasFillMode_t,
385 trans: cublasOperation_t,
386 diag: cublasDiagType_t,
387 m: i32,
388 n: i32,
389 alpha: *const Self,
390 a: CUdeviceptr,
391 lda: i32,
392 b: CUdeviceptr,
393 ldb: i32,
394 ) -> Result<(), GpuError>;
395}
396
397impl TrsmCall for f32 {
398 unsafe fn call(
399 handle: cudarc::cublas::sys::cublasHandle_t,
400 side: cublasSideMode_t,
401 uplo: cublasFillMode_t,
402 trans: cublasOperation_t,
403 diag: cublasDiagType_t,
404 m: i32,
405 n: i32,
406 alpha: *const Self,
407 a: CUdeviceptr,
408 lda: i32,
409 b: CUdeviceptr,
410 ldb: i32,
411 ) -> Result<(), GpuError> {
412 syscublas::strsm(handle, side, uplo, trans, diag, m, n, alpha, a, lda, b, ldb)
413 }
414}
415
416impl TrsmCall for f64 {
417 unsafe fn call(
418 handle: cudarc::cublas::sys::cublasHandle_t,
419 side: cublasSideMode_t,
420 uplo: cublasFillMode_t,
421 trans: cublasOperation_t,
422 diag: cublasDiagType_t,
423 m: i32,
424 n: i32,
425 alpha: *const Self,
426 a: CUdeviceptr,
427 lda: i32,
428 b: CUdeviceptr,
429 ldb: i32,
430 ) -> Result<(), GpuError> {
431 syscublas::dtrsm(handle, side, uplo, trans, diag, m, n, alpha, a, lda, b, ldb)
432 }
433}
434
435fn dispatch_trsm<T>(req: TrsmRequest<T>, ctx: &BlasDispatchCtx<'_>)
436where
437 T: TrsmSupported + TrsmCall + Copy,
438{
439 let TrsmRequest {
440 side,
441 uplo,
442 trans,
443 diag,
444 m,
445 n,
446 alpha,
447 a,
448 lda,
449 b,
450 ldb,
451 reply,
452 } = req;
453 let (a_slice, b_slice) = match envelope::access_all_2(&a, &b) {
454 Ok(t) => t,
455 Err(e) => {
456 let _ = reply.send(Err(e));
457 return;
458 }
459 };
460 let mut b_owned = match Arc::try_unwrap(b_slice) {
461 Ok(s) => s,
462 Err(_) => {
463 let _ = reply.send(Err(GpuError::Unrecoverable(
464 "TRSM target buffer B has more than one live reference".into(),
465 )));
466 return;
467 }
468 };
469 b.record_write(ctx.stream);
470 let cublas = ctx.cublas.clone();
471 let stream = ctx.stream.clone();
472 envelope::run_kernel(LIB, ctx.stream, ctx.completion, (), reply, move || {
473 let res = {
474 let (a_ptr, _a_rec) = (*a_slice).device_ptr(&stream);
475 let (b_ptr, _b_rec) = b_owned.device_ptr_mut(&stream);
476 unsafe {
477 T::call(
478 *cublas.handle(),
479 side,
480 uplo,
481 trans,
482 diag,
483 m,
484 n,
485 (&alpha) as *const T,
486 a_ptr,
487 lda,
488 b_ptr,
489 ldb,
490 )
491 }
492 };
493 match res {
494 Ok(()) => Ok((cublas, a_slice, b_owned)),
495 Err(e) => Err(e),
496 }
497 });
498}
499
500impl BlasL3Dispatch for TrsmRequest<f32> {
501 fn dtype_name(&self) -> &'static str {
502 <f32 as atomr_accel::AccelDtype>::NAME
503 }
504 fn op_name(&self) -> &'static str {
505 "trsm"
506 }
507 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
508 dispatch_trsm::<f32>(*self, ctx);
509 }
510}
511
512impl BlasL3Dispatch for TrsmRequest<f64> {
513 fn dtype_name(&self) -> &'static str {
514 <f64 as atomr_accel::AccelDtype>::NAME
515 }
516 fn op_name(&self) -> &'static str {
517 "trsm"
518 }
519 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
520 dispatch_trsm::<f64>(*self, ctx);
521 }
522}
523
#[cfg(test)]
mod tests {
    //! Label checks for the GEAM / SYRK / TRSM request types.
    //!
    //! Every request kind is boxed as `dyn BlasL3Dispatch` for both `f32` and
    //! `f64`, and both `op_name` and `dtype_name` are asserted, so a copy-paste
    //! slip in any of the six trait impls is caught. Nothing is dispatched, so
    //! no device is needed.

    use super::super::gemm::tests_helpers::gpu_ref_stub;
    use super::*;
    use tokio::sync::oneshot;

    /// Asserts both labels through the trait object, then leaks the box.
    ///
    /// The leak mirrors what these tests have always done with stub-backed
    /// requests. NOTE(review): presumably the stub `GpuRef`s must not be
    /// dropped without a real device — confirm against `gpu_ref_stub`.
    fn assert_labels(boxed: Box<dyn BlasL3Dispatch>, op: &str, dtype: &str) {
        assert_eq!(boxed.op_name(), op);
        assert_eq!(boxed.dtype_name(), dtype);
        Box::leak(boxed);
    }

    // 4x4 GEAM request over stub buffers for element type `$t`.
    macro_rules! geam_request {
        ($t:ty, $reply:expr) => {
            GeamRequest::<$t> {
                trans_a: cublasOperation_t::CUBLAS_OP_N,
                trans_b: cublasOperation_t::CUBLAS_OP_N,
                m: 4,
                n: 4,
                alpha: 1.0,
                a: gpu_ref_stub::<$t>(),
                lda: 4,
                beta: 1.0,
                b: gpu_ref_stub::<$t>(),
                ldb: 4,
                c: gpu_ref_stub::<$t>(),
                ldc: 4,
                reply: $reply,
            }
        };
    }

    // 4x4 lower-triangle SYRK request over stub buffers for element type `$t`.
    macro_rules! syrk_request {
        ($t:ty, $reply:expr) => {
            SyrkRequest::<$t> {
                uplo: cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
                trans: cublasOperation_t::CUBLAS_OP_N,
                n: 4,
                k: 4,
                alpha: 1.0,
                a: gpu_ref_stub::<$t>(),
                lda: 4,
                beta: 0.0,
                c: gpu_ref_stub::<$t>(),
                ldc: 4,
                reply: $reply,
            }
        };
    }

    // 4x4 left-side, lower, non-unit TRSM request over stub buffers for `$t`.
    macro_rules! trsm_request {
        ($t:ty, $reply:expr) => {
            TrsmRequest::<$t> {
                side: cublasSideMode_t::CUBLAS_SIDE_LEFT,
                uplo: cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
                trans: cublasOperation_t::CUBLAS_OP_N,
                diag: cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
                m: 4,
                n: 4,
                alpha: 1.0,
                a: gpu_ref_stub::<$t>(),
                lda: 4,
                b: gpu_ref_stub::<$t>(),
                ldb: 4,
                reply: $reply,
            }
        };
    }

    #[test]
    fn geam_request_round_trip() {
        let (tx, _rx) = oneshot::channel();
        assert_labels(Box::new(geam_request!(f32, tx)), "geam", "f32");

        let (tx, _rx) = oneshot::channel();
        assert_labels(Box::new(geam_request!(f64, tx)), "geam", "f64");
    }

    #[test]
    fn syrk_request_round_trip() {
        let (tx, _rx) = oneshot::channel();
        assert_labels(Box::new(syrk_request!(f32, tx)), "syrk", "f32");

        let (tx, _rx) = oneshot::channel();
        assert_labels(Box::new(syrk_request!(f64, tx)), "syrk", "f64");
    }

    #[test]
    fn trsm_request_round_trip() {
        let (tx, _rx) = oneshot::channel();
        assert_labels(Box::new(trsm_request!(f32, tx)), "trsm", "f32");

        let (tx, _rx) = oneshot::channel();
        assert_labels(Box::new(trsm_request!(f64, tx)), "trsm", "f64");
    }
}