Skip to main content

atomr_accel_cuda/sys/
cublas.rs

1//! Sys-level safe wrappers for the cuBLAS entry points cudarc 0.19
2//! doesn't expose through its safe layer.
3//!
4//! Wrapped today (Phase 1 cuBLAS slice):
5//! - `cublasGemmEx`, `cublasGemmStridedBatchedEx`
6//! - `cublasAxpyEx`, `cublasScalEx`, `cublasNrm2Ex`, `cublasDotEx`
7//! - `cublasIamaxEx`, `cublasIaminEx`, `cublasAsumEx`
8//! - `cublasCopyEx`, `cublasSwapEx`, `cublasRotEx`
//! - `cublasSgemv_v2`/`cublasDgemv_v2`, `cublasSger_v2`/`cublasDger_v2`
10//! - `cublasSgeam`/`cublasDgeam`
11//! - `cublasSsyrk_v2`/`cublasDsyrk_v2`
12//! - `cublasStrsm_v2`/`cublasDtrsm_v2`
13//!
14//! All callers must hold the cuBLAS handle's stream current on the
15//! same OS thread. The atomr-accel-cuda actor pipeline guarantees
16//! that via `GpuDispatcher`.
17
18#![allow(non_snake_case)]
19
20use core::ffi::{c_int, c_longlong};
21
22use cudarc::cublas::sys::{
23    self, cublasComputeType_t, cublasDiagType_t, cublasFillMode_t, cublasGemmAlgo_t,
24    cublasHandle_t, cublasOperation_t, cublasSideMode_t, cudaDataType,
25};
26use cudarc::driver::sys::CUdeviceptr;
27
28use crate::error::GpuError;
29
30const LIB: &str = "cublas";
31
32#[inline]
33fn check(status: sys::cublasStatus_t, op: &'static str) -> Result<(), GpuError> {
34    match status {
35        sys::cublasStatus_t::CUBLAS_STATUS_SUCCESS => Ok(()),
36        e => Err(GpuError::LibraryError {
37            lib: LIB,
38            msg: format!("{op}: {e:?}"),
39        }),
40    }
41}
42
43// ───────────────────────────── L3 ─────────────────────────────
44
45/// `cublasGemmEx` — typed-erased gemm with a separate compute type.
46///
47/// # Safety
48/// `a`/`b`/`c` must point to device buffers with the dtypes encoded
49/// in `a_type`/`b_type`/`c_type` and the sizes implied by
50/// `(m,n,k,lda,ldb,ldc)`.
51#[allow(clippy::too_many_arguments)]
52pub unsafe fn gemm_ex(
53    handle: cublasHandle_t,
54    transa: cublasOperation_t,
55    transb: cublasOperation_t,
56    m: c_int,
57    n: c_int,
58    k: c_int,
59    alpha: *const core::ffi::c_void,
60    a: CUdeviceptr,
61    a_type: cudaDataType,
62    lda: c_int,
63    b: CUdeviceptr,
64    b_type: cudaDataType,
65    ldb: c_int,
66    beta: *const core::ffi::c_void,
67    c: CUdeviceptr,
68    c_type: cudaDataType,
69    ldc: c_int,
70    compute_type: cublasComputeType_t,
71    algo: cublasGemmAlgo_t,
72) -> Result<(), GpuError> {
73    let status = sys::cublasGemmEx(
74        handle,
75        transa,
76        transb,
77        m,
78        n,
79        k,
80        alpha,
81        a as *const _,
82        a_type,
83        lda,
84        b as *const _,
85        b_type,
86        ldb,
87        beta,
88        c as *mut _,
89        c_type,
90        ldc,
91        compute_type,
92        algo,
93    );
94    check(status, "cublasGemmEx")
95}
96
97/// `cublasGemmStridedBatchedEx` — typed-erased strided-batched gemm.
98///
99/// # Safety
100/// Like [`gemm_ex`], plus `stride_*` describes the byte stride between
101/// consecutive batch entries inside a single allocation.
102#[allow(clippy::too_many_arguments)]
103pub unsafe fn gemm_strided_batched_ex(
104    handle: cublasHandle_t,
105    transa: cublasOperation_t,
106    transb: cublasOperation_t,
107    m: c_int,
108    n: c_int,
109    k: c_int,
110    alpha: *const core::ffi::c_void,
111    a: CUdeviceptr,
112    a_type: cudaDataType,
113    lda: c_int,
114    stride_a: c_longlong,
115    b: CUdeviceptr,
116    b_type: cudaDataType,
117    ldb: c_int,
118    stride_b: c_longlong,
119    beta: *const core::ffi::c_void,
120    c: CUdeviceptr,
121    c_type: cudaDataType,
122    ldc: c_int,
123    stride_c: c_longlong,
124    batch_count: c_int,
125    compute_type: cublasComputeType_t,
126    algo: cublasGemmAlgo_t,
127) -> Result<(), GpuError> {
128    let status = sys::cublasGemmStridedBatchedEx(
129        handle,
130        transa,
131        transb,
132        m,
133        n,
134        k,
135        alpha,
136        a as *const _,
137        a_type,
138        lda,
139        stride_a,
140        b as *const _,
141        b_type,
142        ldb,
143        stride_b,
144        beta,
145        c as *mut _,
146        c_type,
147        ldc,
148        stride_c,
149        batch_count,
150        compute_type,
151        algo,
152    );
153    check(status, "cublasGemmStridedBatchedEx")
154}
155
156/// `cublasSgeam` / `cublasDgeam` — matrix add/scale: `C = α·op(A) + β·op(B)`.
157///
158/// # Safety
159/// Pointers must be valid for `(m,n)` matrices with leading dims
160/// `lda`/`ldb`/`ldc` in column-major layout.
161#[allow(clippy::too_many_arguments)]
162pub unsafe fn sgeam(
163    handle: cublasHandle_t,
164    transa: cublasOperation_t,
165    transb: cublasOperation_t,
166    m: c_int,
167    n: c_int,
168    alpha: *const f32,
169    a: CUdeviceptr,
170    lda: c_int,
171    beta: *const f32,
172    b: CUdeviceptr,
173    ldb: c_int,
174    c: CUdeviceptr,
175    ldc: c_int,
176) -> Result<(), GpuError> {
177    let status = sys::cublasSgeam(
178        handle,
179        transa,
180        transb,
181        m,
182        n,
183        alpha,
184        a as *const _,
185        lda,
186        beta,
187        b as *const _,
188        ldb,
189        c as *mut _,
190        ldc,
191    );
192    check(status, "cublasSgeam")
193}
194
195#[allow(clippy::too_many_arguments)]
196pub unsafe fn dgeam(
197    handle: cublasHandle_t,
198    transa: cublasOperation_t,
199    transb: cublasOperation_t,
200    m: c_int,
201    n: c_int,
202    alpha: *const f64,
203    a: CUdeviceptr,
204    lda: c_int,
205    beta: *const f64,
206    b: CUdeviceptr,
207    ldb: c_int,
208    c: CUdeviceptr,
209    ldc: c_int,
210) -> Result<(), GpuError> {
211    let status = sys::cublasDgeam(
212        handle,
213        transa,
214        transb,
215        m,
216        n,
217        alpha,
218        a as *const _,
219        lda,
220        beta,
221        b as *const _,
222        ldb,
223        c as *mut _,
224        ldc,
225    );
226    check(status, "cublasDgeam")
227}
228
229#[allow(clippy::too_many_arguments)]
230pub unsafe fn ssyrk(
231    handle: cublasHandle_t,
232    uplo: cublasFillMode_t,
233    trans: cublasOperation_t,
234    n: c_int,
235    k: c_int,
236    alpha: *const f32,
237    a: CUdeviceptr,
238    lda: c_int,
239    beta: *const f32,
240    c: CUdeviceptr,
241    ldc: c_int,
242) -> Result<(), GpuError> {
243    let status = sys::cublasSsyrk_v2(
244        handle,
245        uplo,
246        trans,
247        n,
248        k,
249        alpha,
250        a as *const _,
251        lda,
252        beta,
253        c as *mut _,
254        ldc,
255    );
256    check(status, "cublasSsyrk_v2")
257}
258
259#[allow(clippy::too_many_arguments)]
260pub unsafe fn dsyrk(
261    handle: cublasHandle_t,
262    uplo: cublasFillMode_t,
263    trans: cublasOperation_t,
264    n: c_int,
265    k: c_int,
266    alpha: *const f64,
267    a: CUdeviceptr,
268    lda: c_int,
269    beta: *const f64,
270    c: CUdeviceptr,
271    ldc: c_int,
272) -> Result<(), GpuError> {
273    let status = sys::cublasDsyrk_v2(
274        handle,
275        uplo,
276        trans,
277        n,
278        k,
279        alpha,
280        a as *const _,
281        lda,
282        beta,
283        c as *mut _,
284        ldc,
285    );
286    check(status, "cublasDsyrk_v2")
287}
288
289#[allow(clippy::too_many_arguments)]
290pub unsafe fn strsm(
291    handle: cublasHandle_t,
292    side: cublasSideMode_t,
293    uplo: cublasFillMode_t,
294    trans: cublasOperation_t,
295    diag: cublasDiagType_t,
296    m: c_int,
297    n: c_int,
298    alpha: *const f32,
299    a: CUdeviceptr,
300    lda: c_int,
301    b: CUdeviceptr,
302    ldb: c_int,
303) -> Result<(), GpuError> {
304    let status = sys::cublasStrsm_v2(
305        handle,
306        side,
307        uplo,
308        trans,
309        diag,
310        m,
311        n,
312        alpha,
313        a as *const _,
314        lda,
315        b as *mut _,
316        ldb,
317    );
318    check(status, "cublasStrsm_v2")
319}
320
321#[allow(clippy::too_many_arguments)]
322pub unsafe fn dtrsm(
323    handle: cublasHandle_t,
324    side: cublasSideMode_t,
325    uplo: cublasFillMode_t,
326    trans: cublasOperation_t,
327    diag: cublasDiagType_t,
328    m: c_int,
329    n: c_int,
330    alpha: *const f64,
331    a: CUdeviceptr,
332    lda: c_int,
333    b: CUdeviceptr,
334    ldb: c_int,
335) -> Result<(), GpuError> {
336    let status = sys::cublasDtrsm_v2(
337        handle,
338        side,
339        uplo,
340        trans,
341        diag,
342        m,
343        n,
344        alpha,
345        a as *const _,
346        lda,
347        b as *mut _,
348        ldb,
349    );
350    check(status, "cublasDtrsm_v2")
351}
352
353// ───────────────────────────── L2 ─────────────────────────────
354
355#[allow(clippy::too_many_arguments)]
356pub unsafe fn sgemv(
357    handle: cublasHandle_t,
358    trans: cublasOperation_t,
359    m: c_int,
360    n: c_int,
361    alpha: *const f32,
362    a: CUdeviceptr,
363    lda: c_int,
364    x: CUdeviceptr,
365    incx: c_int,
366    beta: *const f32,
367    y: CUdeviceptr,
368    incy: c_int,
369) -> Result<(), GpuError> {
370    let status = sys::cublasSgemv_v2(
371        handle,
372        trans,
373        m,
374        n,
375        alpha,
376        a as *const _,
377        lda,
378        x as *const _,
379        incx,
380        beta,
381        y as *mut _,
382        incy,
383    );
384    check(status, "cublasSgemv_v2")
385}
386
387#[allow(clippy::too_many_arguments)]
388pub unsafe fn dgemv(
389    handle: cublasHandle_t,
390    trans: cublasOperation_t,
391    m: c_int,
392    n: c_int,
393    alpha: *const f64,
394    a: CUdeviceptr,
395    lda: c_int,
396    x: CUdeviceptr,
397    incx: c_int,
398    beta: *const f64,
399    y: CUdeviceptr,
400    incy: c_int,
401) -> Result<(), GpuError> {
402    let status = sys::cublasDgemv_v2(
403        handle,
404        trans,
405        m,
406        n,
407        alpha,
408        a as *const _,
409        lda,
410        x as *const _,
411        incx,
412        beta,
413        y as *mut _,
414        incy,
415    );
416    check(status, "cublasDgemv_v2")
417}
418
419#[allow(clippy::too_many_arguments)]
420pub unsafe fn sger(
421    handle: cublasHandle_t,
422    m: c_int,
423    n: c_int,
424    alpha: *const f32,
425    x: CUdeviceptr,
426    incx: c_int,
427    y: CUdeviceptr,
428    incy: c_int,
429    a: CUdeviceptr,
430    lda: c_int,
431) -> Result<(), GpuError> {
432    let status = sys::cublasSger_v2(
433        handle,
434        m,
435        n,
436        alpha,
437        x as *const _,
438        incx,
439        y as *const _,
440        incy,
441        a as *mut _,
442        lda,
443    );
444    check(status, "cublasSger_v2")
445}
446
447#[allow(clippy::too_many_arguments)]
448pub unsafe fn dger(
449    handle: cublasHandle_t,
450    m: c_int,
451    n: c_int,
452    alpha: *const f64,
453    x: CUdeviceptr,
454    incx: c_int,
455    y: CUdeviceptr,
456    incy: c_int,
457    a: CUdeviceptr,
458    lda: c_int,
459) -> Result<(), GpuError> {
460    let status = sys::cublasDger_v2(
461        handle,
462        m,
463        n,
464        alpha,
465        x as *const _,
466        incx,
467        y as *const _,
468        incy,
469        a as *mut _,
470        lda,
471    );
472    check(status, "cublasDger_v2")
473}
474
475// ───────────────────────────── L1 ─────────────────────────────
476
477#[allow(clippy::too_many_arguments)]
478pub unsafe fn axpy_ex(
479    handle: cublasHandle_t,
480    n: c_int,
481    alpha: *const core::ffi::c_void,
482    alpha_type: cudaDataType,
483    x: CUdeviceptr,
484    x_type: cudaDataType,
485    incx: c_int,
486    y: CUdeviceptr,
487    y_type: cudaDataType,
488    incy: c_int,
489    execution_type: cudaDataType,
490) -> Result<(), GpuError> {
491    let status = sys::cublasAxpyEx(
492        handle,
493        n,
494        alpha,
495        alpha_type,
496        x as *const _,
497        x_type,
498        incx,
499        y as *mut _,
500        y_type,
501        incy,
502        execution_type,
503    );
504    check(status, "cublasAxpyEx")
505}
506
507#[allow(clippy::too_many_arguments)]
508pub unsafe fn scal_ex(
509    handle: cublasHandle_t,
510    n: c_int,
511    alpha: *const core::ffi::c_void,
512    alpha_type: cudaDataType,
513    x: CUdeviceptr,
514    x_type: cudaDataType,
515    incx: c_int,
516    execution_type: cudaDataType,
517) -> Result<(), GpuError> {
518    let status = sys::cublasScalEx(
519        handle,
520        n,
521        alpha,
522        alpha_type,
523        x as *mut _,
524        x_type,
525        incx,
526        execution_type,
527    );
528    check(status, "cublasScalEx")
529}
530
531#[allow(clippy::too_many_arguments)]
532pub unsafe fn nrm2_ex(
533    handle: cublasHandle_t,
534    n: c_int,
535    x: CUdeviceptr,
536    x_type: cudaDataType,
537    incx: c_int,
538    result: *mut core::ffi::c_void,
539    result_type: cudaDataType,
540    execution_type: cudaDataType,
541) -> Result<(), GpuError> {
542    let status = sys::cublasNrm2Ex(
543        handle,
544        n,
545        x as *const _,
546        x_type,
547        incx,
548        result,
549        result_type,
550        execution_type,
551    );
552    check(status, "cublasNrm2Ex")
553}
554
555#[allow(clippy::too_many_arguments)]
556pub unsafe fn dot_ex(
557    handle: cublasHandle_t,
558    n: c_int,
559    x: CUdeviceptr,
560    x_type: cudaDataType,
561    incx: c_int,
562    y: CUdeviceptr,
563    y_type: cudaDataType,
564    incy: c_int,
565    result: *mut core::ffi::c_void,
566    result_type: cudaDataType,
567    execution_type: cudaDataType,
568) -> Result<(), GpuError> {
569    let status = sys::cublasDotEx(
570        handle,
571        n,
572        x as *const _,
573        x_type,
574        incx,
575        y as *const _,
576        y_type,
577        incy,
578        result,
579        result_type,
580        execution_type,
581    );
582    check(status, "cublasDotEx")
583}
584
585#[allow(clippy::too_many_arguments)]
586pub unsafe fn iamax_ex(
587    handle: cublasHandle_t,
588    n: c_int,
589    x: CUdeviceptr,
590    x_type: cudaDataType,
591    incx: c_int,
592    result: *mut c_int,
593) -> Result<(), GpuError> {
594    let status = sys::cublasIamaxEx(handle, n, x as *const _, x_type, incx, result);
595    check(status, "cublasIamaxEx")
596}
597
598#[allow(clippy::too_many_arguments)]
599pub unsafe fn iamin_ex(
600    handle: cublasHandle_t,
601    n: c_int,
602    x: CUdeviceptr,
603    x_type: cudaDataType,
604    incx: c_int,
605    result: *mut c_int,
606) -> Result<(), GpuError> {
607    let status = sys::cublasIaminEx(handle, n, x as *const _, x_type, incx, result);
608    check(status, "cublasIaminEx")
609}
610
611#[allow(clippy::too_many_arguments)]
612pub unsafe fn asum_ex(
613    handle: cublasHandle_t,
614    n: c_int,
615    x: CUdeviceptr,
616    x_type: cudaDataType,
617    incx: c_int,
618    result: *mut core::ffi::c_void,
619    result_type: cudaDataType,
620    execution_type: cudaDataType,
621) -> Result<(), GpuError> {
622    let status = sys::cublasAsumEx(
623        handle,
624        n,
625        x as *const _,
626        x_type,
627        incx,
628        result,
629        result_type,
630        execution_type,
631    );
632    check(status, "cublasAsumEx")
633}
634
635#[allow(clippy::too_many_arguments)]
636pub unsafe fn copy_ex(
637    handle: cublasHandle_t,
638    n: c_int,
639    x: CUdeviceptr,
640    x_type: cudaDataType,
641    incx: c_int,
642    y: CUdeviceptr,
643    y_type: cudaDataType,
644    incy: c_int,
645) -> Result<(), GpuError> {
646    let status = sys::cublasCopyEx(
647        handle,
648        n,
649        x as *const _,
650        x_type,
651        incx,
652        y as *mut _,
653        y_type,
654        incy,
655    );
656    check(status, "cublasCopyEx")
657}
658
659#[allow(clippy::too_many_arguments)]
660pub unsafe fn swap_ex(
661    handle: cublasHandle_t,
662    n: c_int,
663    x: CUdeviceptr,
664    x_type: cudaDataType,
665    incx: c_int,
666    y: CUdeviceptr,
667    y_type: cudaDataType,
668    incy: c_int,
669) -> Result<(), GpuError> {
670    let status = sys::cublasSwapEx(
671        handle,
672        n,
673        x as *mut _,
674        x_type,
675        incx,
676        y as *mut _,
677        y_type,
678        incy,
679    );
680    check(status, "cublasSwapEx")
681}
682
683#[allow(clippy::too_many_arguments)]
684pub unsafe fn rot_ex(
685    handle: cublasHandle_t,
686    n: c_int,
687    x: CUdeviceptr,
688    x_type: cudaDataType,
689    incx: c_int,
690    y: CUdeviceptr,
691    y_type: cudaDataType,
692    incy: c_int,
693    cs: *const core::ffi::c_void,
694    s: *const core::ffi::c_void,
695    cs_type: cudaDataType,
696    execution_type: cudaDataType,
697) -> Result<(), GpuError> {
698    let status = sys::cublasRotEx(
699        handle,
700        n,
701        x as *mut _,
702        x_type,
703        incx,
704        y as *mut _,
705        y_type,
706        incy,
707        cs,
708        s,
709        cs_type,
710        execution_type,
711    );
712    check(status, "cublasRotEx")
713}