Skip to main content

atomr_accel_cuda/kernel/blas/
l3.rs

1//! Typed L3 ops other than gemm: geam (matrix add/scale), syrk
2//! (symmetric rank-k update), trsm (triangular solve).
3//!
4//! Each of these drops to the local sys-level wrappers in
5//! [`crate::sys::cublas`] because cudarc 0.19 has no safe trait for
6//! them.
7
8use std::sync::Arc;
9
10use cudarc::cublas::sys::{
11    cublasDiagType_t, cublasFillMode_t, cublasOperation_t, cublasSideMode_t,
12};
13use cudarc::driver::{sys::CUdeviceptr, DevicePtr, DevicePtrMut};
14use tokio::sync::oneshot;
15
16use crate::dtype::{GeamSupported, SyrkSupported, TrsmSupported};
17use crate::error::GpuError;
18use crate::gpu_ref::GpuRef;
19use crate::kernel::dispatch::{BlasDispatchCtx, BlasL3Dispatch};
20use crate::kernel::envelope;
21use crate::sys::cublas as syscublas;
22
/// Library label passed as the first argument to `envelope::run_kernel` by
/// every dispatcher in this module.
const LIB: &str = "cublas";
24
25// ─────────────────────── GEAM ───────────────────────
26
/// `C = α·op(A) + β·op(B)`.
///
/// The fields mirror the arguments of cuBLAS `geam` and are forwarded to it
/// unchanged by `dispatch_geam`; the dispatcher itself performs no check of
/// the dimensions against the buffer sizes.
///
/// The `c` buffer must have no other live reference when the request is
/// dispatched, otherwise the request fails with `GpuError::Unrecoverable`.
pub struct GeamRequest<T: GeamSupported> {
    /// `op` applied to `A` (cuBLAS `transa`).
    pub trans_a: cublasOperation_t,
    /// `op` applied to `B` (cuBLAS `transb`).
    pub trans_b: cublasOperation_t,
    /// Number of rows of `op(A)` and `C` (cuBLAS `m`).
    pub m: i32,
    /// Number of columns of `op(B)` and `C` (cuBLAS `n`).
    pub n: i32,
    /// Scalar `α`; handed to cuBLAS through a pointer to a host-side copy.
    pub alpha: T,
    /// Device buffer holding `A`.
    pub a: GpuRef<T>,
    /// Leading dimension of `A`.
    pub lda: i32,
    /// Scalar `β`; handed to cuBLAS through a pointer to a host-side copy.
    pub beta: T,
    /// Device buffer holding `B`.
    pub b: GpuRef<T>,
    /// Leading dimension of `B`.
    pub ldb: i32,
    /// Device buffer that receives `C`.
    pub c: GpuRef<T>,
    /// Leading dimension of `C`.
    pub ldc: i32,
    /// Channel for the outcome. Buffer-access failures are sent here directly;
    /// otherwise the sender is handed to `envelope::run_kernel`.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
}
43
/// Per-dtype entry point for GEAM. Implemented in this module for `f32` and
/// `f64`, each forwarding to the matching wrapper in [`crate::sys::cublas`].
trait GeamCall: GeamSupported {
    /// Issues the GEAM call with the given raw arguments.
    ///
    /// `alpha` and `beta` are pointers to scalars of the element type; `a`,
    /// `b` and `c` are raw device pointers.
    ///
    /// # Safety
    /// All pointers must be valid for the encoded sizes.
    unsafe fn call(
        handle: cudarc::cublas::sys::cublasHandle_t,
        transa: cublasOperation_t,
        transb: cublasOperation_t,
        m: i32,
        n: i32,
        alpha: *const Self,
        a: CUdeviceptr,
        lda: i32,
        beta: *const Self,
        b: CUdeviceptr,
        ldb: i32,
        c: CUdeviceptr,
        ldc: i32,
    ) -> Result<(), GpuError>;
}
63
64impl GeamCall for f32 {
65    unsafe fn call(
66        handle: cudarc::cublas::sys::cublasHandle_t,
67        transa: cublasOperation_t,
68        transb: cublasOperation_t,
69        m: i32,
70        n: i32,
71        alpha: *const Self,
72        a: CUdeviceptr,
73        lda: i32,
74        beta: *const Self,
75        b: CUdeviceptr,
76        ldb: i32,
77        c: CUdeviceptr,
78        ldc: i32,
79    ) -> Result<(), GpuError> {
80        syscublas::sgeam(
81            handle, transa, transb, m, n, alpha, a, lda, beta, b, ldb, c, ldc,
82        )
83    }
84}
85
86impl GeamCall for f64 {
87    unsafe fn call(
88        handle: cudarc::cublas::sys::cublasHandle_t,
89        transa: cublasOperation_t,
90        transb: cublasOperation_t,
91        m: i32,
92        n: i32,
93        alpha: *const Self,
94        a: CUdeviceptr,
95        lda: i32,
96        beta: *const Self,
97        b: CUdeviceptr,
98        ldb: i32,
99        c: CUdeviceptr,
100        ldc: i32,
101    ) -> Result<(), GpuError> {
102        syscublas::dgeam(
103            handle, transa, transb, m, n, alpha, a, lda, beta, b, ldb, c, ldc,
104        )
105    }
106}
107
108fn dispatch_geam<T>(req: GeamRequest<T>, ctx: &BlasDispatchCtx<'_>)
109where
110    T: GeamSupported + GeamCall + Copy,
111{
112    let GeamRequest {
113        trans_a,
114        trans_b,
115        m,
116        n,
117        alpha,
118        a,
119        lda,
120        beta,
121        b,
122        ldb,
123        c,
124        ldc,
125        reply,
126    } = req;
127    let (a_slice, b_slice, c_slice) = match envelope::access_all_3(&a, &b, &c) {
128        Ok(t) => t,
129        Err(e) => {
130            let _ = reply.send(Err(e));
131            return;
132        }
133    };
134    let mut c_owned = match Arc::try_unwrap(c_slice) {
135        Ok(s) => s,
136        Err(_) => {
137            let _ = reply.send(Err(GpuError::Unrecoverable(
138                "GEAM target buffer C has more than one live reference".into(),
139            )));
140            return;
141        }
142    };
143    c.record_write(ctx.stream);
144    let cublas = ctx.cublas.clone();
145    let stream = ctx.stream.clone();
146    envelope::run_kernel(LIB, ctx.stream, ctx.completion, (), reply, move || {
147        let res = {
148            let (a_ptr, _a_rec) = (*a_slice).device_ptr(&stream);
149            let (b_ptr, _b_rec) = (*b_slice).device_ptr(&stream);
150            let (c_ptr, _c_rec) = c_owned.device_ptr_mut(&stream);
151            unsafe {
152                T::call(
153                    *cublas.handle(),
154                    trans_a,
155                    trans_b,
156                    m,
157                    n,
158                    (&alpha) as *const T,
159                    a_ptr,
160                    lda,
161                    (&beta) as *const T,
162                    b_ptr,
163                    ldb,
164                    c_ptr,
165                    ldc,
166                )
167            }
168        };
169        match res {
170            Ok(()) => Ok((cublas, a_slice, b_slice, c_owned)),
171            Err(e) => Err(e),
172        }
173    });
174}
175
176impl BlasL3Dispatch for GeamRequest<f32> {
177    fn dtype_name(&self) -> &'static str {
178        <f32 as atomr_accel::AccelDtype>::NAME
179    }
180    fn op_name(&self) -> &'static str {
181        "geam"
182    }
183    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
184        dispatch_geam::<f32>(*self, ctx);
185    }
186}
187
188impl BlasL3Dispatch for GeamRequest<f64> {
189    fn dtype_name(&self) -> &'static str {
190        <f64 as atomr_accel::AccelDtype>::NAME
191    }
192    fn op_name(&self) -> &'static str {
193        "geam"
194    }
195    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
196        dispatch_geam::<f64>(*self, ctx);
197    }
198}
199
200// ─────────────────────── SYRK ───────────────────────
201
/// `C := α·op(A)·op(A)^T + β·C`. `op` is `N` or `T`. Updates either
/// the upper or lower triangle of `C`.
///
/// The fields mirror the arguments of cuBLAS `syrk` and are forwarded to it
/// unchanged by `dispatch_syrk`; the dispatcher itself performs no check of
/// the dimensions against the buffer sizes.
///
/// The `c` buffer must have no other live reference when the request is
/// dispatched, otherwise the request fails with `GpuError::Unrecoverable`.
pub struct SyrkRequest<T: SyrkSupported> {
    /// Which triangle of `C` is referenced and updated.
    pub uplo: cublasFillMode_t,
    /// `op` applied to `A`.
    pub trans: cublasOperation_t,
    /// Number of rows of `op(A)` and order of `C` (cuBLAS `n`).
    pub n: i32,
    /// Number of columns of `op(A)` (cuBLAS `k`).
    pub k: i32,
    /// Scalar `α`; handed to cuBLAS through a pointer to a host-side copy.
    pub alpha: T,
    /// Device buffer holding `A`.
    pub a: GpuRef<T>,
    /// Leading dimension of `A`.
    pub lda: i32,
    /// Scalar `β`; handed to cuBLAS through a pointer to a host-side copy.
    pub beta: T,
    /// Device buffer holding `C`; updated in place.
    pub c: GpuRef<T>,
    /// Leading dimension of `C`.
    pub ldc: i32,
    /// Channel for the outcome. Buffer-access failures are sent here directly;
    /// otherwise the sender is handed to `envelope::run_kernel`.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
}
217
/// Per-dtype entry point for SYRK. Implemented in this module for `f32` and
/// `f64`, each forwarding to the matching wrapper in [`crate::sys::cublas`].
trait SyrkCall: SyrkSupported {
    /// Issues the SYRK call with the given raw arguments.
    ///
    /// `alpha` and `beta` are pointers to scalars of the element type; `a`
    /// and `c` are raw device pointers.
    ///
    /// # Safety
    /// All pointers must be valid for the encoded sizes.
    unsafe fn call(
        handle: cudarc::cublas::sys::cublasHandle_t,
        uplo: cublasFillMode_t,
        trans: cublasOperation_t,
        n: i32,
        k: i32,
        alpha: *const Self,
        a: CUdeviceptr,
        lda: i32,
        beta: *const Self,
        c: CUdeviceptr,
        ldc: i32,
    ) -> Result<(), GpuError>;
}
235
236impl SyrkCall for f32 {
237    unsafe fn call(
238        handle: cudarc::cublas::sys::cublasHandle_t,
239        uplo: cublasFillMode_t,
240        trans: cublasOperation_t,
241        n: i32,
242        k: i32,
243        alpha: *const Self,
244        a: CUdeviceptr,
245        lda: i32,
246        beta: *const Self,
247        c: CUdeviceptr,
248        ldc: i32,
249    ) -> Result<(), GpuError> {
250        syscublas::ssyrk(handle, uplo, trans, n, k, alpha, a, lda, beta, c, ldc)
251    }
252}
253
254impl SyrkCall for f64 {
255    unsafe fn call(
256        handle: cudarc::cublas::sys::cublasHandle_t,
257        uplo: cublasFillMode_t,
258        trans: cublasOperation_t,
259        n: i32,
260        k: i32,
261        alpha: *const Self,
262        a: CUdeviceptr,
263        lda: i32,
264        beta: *const Self,
265        c: CUdeviceptr,
266        ldc: i32,
267    ) -> Result<(), GpuError> {
268        syscublas::dsyrk(handle, uplo, trans, n, k, alpha, a, lda, beta, c, ldc)
269    }
270}
271
272fn dispatch_syrk<T>(req: SyrkRequest<T>, ctx: &BlasDispatchCtx<'_>)
273where
274    T: SyrkSupported + SyrkCall + Copy,
275{
276    let SyrkRequest {
277        uplo,
278        trans,
279        n,
280        k,
281        alpha,
282        a,
283        lda,
284        beta,
285        c,
286        ldc,
287        reply,
288    } = req;
289    let (a_slice, c_slice) = match envelope::access_all_2(&a, &c) {
290        Ok(t) => t,
291        Err(e) => {
292            let _ = reply.send(Err(e));
293            return;
294        }
295    };
296    let mut c_owned = match Arc::try_unwrap(c_slice) {
297        Ok(s) => s,
298        Err(_) => {
299            let _ = reply.send(Err(GpuError::Unrecoverable(
300                "SYRK target buffer C has more than one live reference".into(),
301            )));
302            return;
303        }
304    };
305    c.record_write(ctx.stream);
306    let cublas = ctx.cublas.clone();
307    let stream = ctx.stream.clone();
308    envelope::run_kernel(LIB, ctx.stream, ctx.completion, (), reply, move || {
309        let res = {
310            let (a_ptr, _a_rec) = (*a_slice).device_ptr(&stream);
311            let (c_ptr, _c_rec) = c_owned.device_ptr_mut(&stream);
312            unsafe {
313                T::call(
314                    *cublas.handle(),
315                    uplo,
316                    trans,
317                    n,
318                    k,
319                    (&alpha) as *const T,
320                    a_ptr,
321                    lda,
322                    (&beta) as *const T,
323                    c_ptr,
324                    ldc,
325                )
326            }
327        };
328        match res {
329            Ok(()) => Ok((cublas, a_slice, c_owned)),
330            Err(e) => Err(e),
331        }
332    });
333}
334
335impl BlasL3Dispatch for SyrkRequest<f32> {
336    fn dtype_name(&self) -> &'static str {
337        <f32 as atomr_accel::AccelDtype>::NAME
338    }
339    fn op_name(&self) -> &'static str {
340        "syrk"
341    }
342    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
343        dispatch_syrk::<f32>(*self, ctx);
344    }
345}
346
347impl BlasL3Dispatch for SyrkRequest<f64> {
348    fn dtype_name(&self) -> &'static str {
349        <f64 as atomr_accel::AccelDtype>::NAME
350    }
351    fn op_name(&self) -> &'static str {
352        "syrk"
353    }
354    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
355        dispatch_syrk::<f64>(*self, ctx);
356    }
357}
358
359// ─────────────────────── TRSM ───────────────────────
360
/// Triangular solve: `op(A) · X = α·B` (or `X · op(A) = α·B`).
/// Solution is written in-place over `B`.
///
/// The fields mirror the arguments of cuBLAS `trsm` and are forwarded to it
/// unchanged by `dispatch_trsm`; the dispatcher itself performs no check of
/// the dimensions against the buffer sizes.
///
/// The `b` buffer must have no other live reference when the request is
/// dispatched, otherwise the request fails with `GpuError::Unrecoverable`.
pub struct TrsmRequest<T: TrsmSupported> {
    /// Whether `op(A)` multiplies `X` from the left or from the right.
    pub side: cublasSideMode_t,
    /// Which triangle of `A` is referenced.
    pub uplo: cublasFillMode_t,
    /// `op` applied to `A`.
    pub trans: cublasOperation_t,
    /// Whether the diagonal of `A` is taken to be unit.
    pub diag: cublasDiagType_t,
    /// Number of rows of `B` (cuBLAS `m`).
    pub m: i32,
    /// Number of columns of `B` (cuBLAS `n`).
    pub n: i32,
    /// Scalar `α`; handed to cuBLAS through a pointer to a host-side copy.
    pub alpha: T,
    /// Device buffer holding the triangular matrix `A`.
    pub a: GpuRef<T>,
    /// Leading dimension of `A`.
    pub lda: i32,
    /// Device buffer holding `B` on entry; overwritten with the solution `X`.
    pub b: GpuRef<T>,
    /// Leading dimension of `B`.
    pub ldb: i32,
    /// Channel for the outcome. Buffer-access failures are sent here directly;
    /// otherwise the sender is handed to `envelope::run_kernel`.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
}
377
/// Per-dtype entry point for TRSM. Implemented in this module for `f32` and
/// `f64`, each forwarding to the matching wrapper in [`crate::sys::cublas`].
trait TrsmCall: TrsmSupported {
    /// Issues the TRSM call with the given raw arguments.
    ///
    /// `alpha` is a pointer to a scalar of the element type; `a` and `b` are
    /// raw device pointers.
    ///
    /// # Safety
    /// All pointers must be valid for the encoded sizes.
    unsafe fn call(
        handle: cudarc::cublas::sys::cublasHandle_t,
        side: cublasSideMode_t,
        uplo: cublasFillMode_t,
        trans: cublasOperation_t,
        diag: cublasDiagType_t,
        m: i32,
        n: i32,
        alpha: *const Self,
        a: CUdeviceptr,
        lda: i32,
        b: CUdeviceptr,
        ldb: i32,
    ) -> Result<(), GpuError>;
}
396
397impl TrsmCall for f32 {
398    unsafe fn call(
399        handle: cudarc::cublas::sys::cublasHandle_t,
400        side: cublasSideMode_t,
401        uplo: cublasFillMode_t,
402        trans: cublasOperation_t,
403        diag: cublasDiagType_t,
404        m: i32,
405        n: i32,
406        alpha: *const Self,
407        a: CUdeviceptr,
408        lda: i32,
409        b: CUdeviceptr,
410        ldb: i32,
411    ) -> Result<(), GpuError> {
412        syscublas::strsm(handle, side, uplo, trans, diag, m, n, alpha, a, lda, b, ldb)
413    }
414}
415
416impl TrsmCall for f64 {
417    unsafe fn call(
418        handle: cudarc::cublas::sys::cublasHandle_t,
419        side: cublasSideMode_t,
420        uplo: cublasFillMode_t,
421        trans: cublasOperation_t,
422        diag: cublasDiagType_t,
423        m: i32,
424        n: i32,
425        alpha: *const Self,
426        a: CUdeviceptr,
427        lda: i32,
428        b: CUdeviceptr,
429        ldb: i32,
430    ) -> Result<(), GpuError> {
431        syscublas::dtrsm(handle, side, uplo, trans, diag, m, n, alpha, a, lda, b, ldb)
432    }
433}
434
435fn dispatch_trsm<T>(req: TrsmRequest<T>, ctx: &BlasDispatchCtx<'_>)
436where
437    T: TrsmSupported + TrsmCall + Copy,
438{
439    let TrsmRequest {
440        side,
441        uplo,
442        trans,
443        diag,
444        m,
445        n,
446        alpha,
447        a,
448        lda,
449        b,
450        ldb,
451        reply,
452    } = req;
453    let (a_slice, b_slice) = match envelope::access_all_2(&a, &b) {
454        Ok(t) => t,
455        Err(e) => {
456            let _ = reply.send(Err(e));
457            return;
458        }
459    };
460    let mut b_owned = match Arc::try_unwrap(b_slice) {
461        Ok(s) => s,
462        Err(_) => {
463            let _ = reply.send(Err(GpuError::Unrecoverable(
464                "TRSM target buffer B has more than one live reference".into(),
465            )));
466            return;
467        }
468    };
469    b.record_write(ctx.stream);
470    let cublas = ctx.cublas.clone();
471    let stream = ctx.stream.clone();
472    envelope::run_kernel(LIB, ctx.stream, ctx.completion, (), reply, move || {
473        let res = {
474            let (a_ptr, _a_rec) = (*a_slice).device_ptr(&stream);
475            let (b_ptr, _b_rec) = b_owned.device_ptr_mut(&stream);
476            unsafe {
477                T::call(
478                    *cublas.handle(),
479                    side,
480                    uplo,
481                    trans,
482                    diag,
483                    m,
484                    n,
485                    (&alpha) as *const T,
486                    a_ptr,
487                    lda,
488                    b_ptr,
489                    ldb,
490                )
491            }
492        };
493        match res {
494            Ok(()) => Ok((cublas, a_slice, b_owned)),
495            Err(e) => Err(e),
496        }
497    });
498}
499
500impl BlasL3Dispatch for TrsmRequest<f32> {
501    fn dtype_name(&self) -> &'static str {
502        <f32 as atomr_accel::AccelDtype>::NAME
503    }
504    fn op_name(&self) -> &'static str {
505        "trsm"
506    }
507    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
508        dispatch_trsm::<f32>(*self, ctx);
509    }
510}
511
512impl BlasL3Dispatch for TrsmRequest<f64> {
513    fn dtype_name(&self) -> &'static str {
514        <f64 as atomr_accel::AccelDtype>::NAME
515    }
516    fn op_name(&self) -> &'static str {
517        "trsm"
518    }
519    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
520        dispatch_trsm::<f64>(*self, ctx);
521    }
522}
523
#[cfg(test)]
mod tests {
    use super::super::gemm::tests_helpers::gpu_ref_stub;
    use super::*;
    use tokio::sync::oneshot;

    /// Asserts both labels reported through the trait object, then leaks the
    /// request instead of dropping it.
    ///
    /// NOTE(review): the leak is carried over from the original tests;
    /// presumably the stub `GpuRef`s must not be dropped without a device —
    /// confirm against `gpu_ref_stub`.
    fn assert_labels(boxed: Box<dyn BlasL3Dispatch>, op: &str, dtype: &str) {
        assert_eq!(boxed.op_name(), op);
        assert_eq!(boxed.dtype_name(), dtype);
        Box::leak(boxed);
    }

    /// GEAM requests of both element types box into the dispatch trait and
    /// report the expected labels.
    #[test]
    fn geam_request_round_trip() {
        let (tx, _rx) = oneshot::channel();
        let req = GeamRequest::<f32> {
            trans_a: cublasOperation_t::CUBLAS_OP_N,
            trans_b: cublasOperation_t::CUBLAS_OP_N,
            m: 4,
            n: 4,
            alpha: 1.0,
            a: gpu_ref_stub::<f32>(),
            lda: 4,
            beta: 1.0,
            b: gpu_ref_stub::<f32>(),
            ldb: 4,
            c: gpu_ref_stub::<f32>(),
            ldc: 4,
            reply: tx,
        };
        assert_labels(Box::new(req), "geam", "f32");

        let (tx, _rx) = oneshot::channel();
        let req = GeamRequest::<f64> {
            trans_a: cublasOperation_t::CUBLAS_OP_N,
            trans_b: cublasOperation_t::CUBLAS_OP_N,
            m: 4,
            n: 4,
            alpha: 1.0,
            a: gpu_ref_stub::<f64>(),
            lda: 4,
            beta: 1.0,
            b: gpu_ref_stub::<f64>(),
            ldb: 4,
            c: gpu_ref_stub::<f64>(),
            ldc: 4,
            reply: tx,
        };
        assert_labels(Box::new(req), "geam", "f64");
    }

    /// SYRK requests of both element types box into the dispatch trait and
    /// report the expected labels.
    #[test]
    fn syrk_request_round_trip() {
        let (tx, _rx) = oneshot::channel();
        let req = SyrkRequest::<f32> {
            uplo: cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
            trans: cublasOperation_t::CUBLAS_OP_N,
            n: 4,
            k: 4,
            alpha: 1.0,
            a: gpu_ref_stub::<f32>(),
            lda: 4,
            beta: 0.0,
            c: gpu_ref_stub::<f32>(),
            ldc: 4,
            reply: tx,
        };
        assert_labels(Box::new(req), "syrk", "f32");

        let (tx, _rx) = oneshot::channel();
        let req = SyrkRequest::<f64> {
            uplo: cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
            trans: cublasOperation_t::CUBLAS_OP_N,
            n: 4,
            k: 4,
            alpha: 1.0,
            a: gpu_ref_stub::<f64>(),
            lda: 4,
            beta: 0.0,
            c: gpu_ref_stub::<f64>(),
            ldc: 4,
            reply: tx,
        };
        assert_labels(Box::new(req), "syrk", "f64");
    }

    /// TRSM requests of both element types box into the dispatch trait and
    /// report the expected labels.
    #[test]
    fn trsm_request_round_trip() {
        let (tx, _rx) = oneshot::channel();
        let req = TrsmRequest::<f32> {
            side: cublasSideMode_t::CUBLAS_SIDE_LEFT,
            uplo: cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
            trans: cublasOperation_t::CUBLAS_OP_N,
            diag: cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
            m: 4,
            n: 4,
            alpha: 1.0,
            a: gpu_ref_stub::<f32>(),
            lda: 4,
            b: gpu_ref_stub::<f32>(),
            ldb: 4,
            reply: tx,
        };
        assert_labels(Box::new(req), "trsm", "f32");

        let (tx, _rx) = oneshot::channel();
        let req = TrsmRequest::<f64> {
            side: cublasSideMode_t::CUBLAS_SIDE_LEFT,
            uplo: cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
            trans: cublasOperation_t::CUBLAS_OP_N,
            diag: cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
            m: 4,
            n: 4,
            alpha: 1.0,
            a: gpu_ref_stub::<f64>(),
            lda: 4,
            b: gpu_ref_stub::<f64>(),
            ldb: 4,
            reply: tx,
        };
        assert_labels(Box::new(req), "trsm", "f64");
    }
}