1#![allow(non_snake_case)]
19
20use core::ffi::{c_int, c_longlong};
21
22use cudarc::cublas::sys::{
23 self, cublasComputeType_t, cublasDiagType_t, cublasFillMode_t, cublasGemmAlgo_t,
24 cublasHandle_t, cublasOperation_t, cublasSideMode_t, cudaDataType,
25};
26use cudarc::driver::sys::CUdeviceptr;
27
28use crate::error::GpuError;
29
30const LIB: &str = "cublas";
31
32#[inline]
33fn check(status: sys::cublasStatus_t, op: &'static str) -> Result<(), GpuError> {
34 match status {
35 sys::cublasStatus_t::CUBLAS_STATUS_SUCCESS => Ok(()),
36 e => Err(GpuError::LibraryError {
37 lib: LIB,
38 msg: format!("{op}: {e:?}"),
39 }),
40 }
41}
42
43#[allow(clippy::too_many_arguments)]
52pub unsafe fn gemm_ex(
53 handle: cublasHandle_t,
54 transa: cublasOperation_t,
55 transb: cublasOperation_t,
56 m: c_int,
57 n: c_int,
58 k: c_int,
59 alpha: *const core::ffi::c_void,
60 a: CUdeviceptr,
61 a_type: cudaDataType,
62 lda: c_int,
63 b: CUdeviceptr,
64 b_type: cudaDataType,
65 ldb: c_int,
66 beta: *const core::ffi::c_void,
67 c: CUdeviceptr,
68 c_type: cudaDataType,
69 ldc: c_int,
70 compute_type: cublasComputeType_t,
71 algo: cublasGemmAlgo_t,
72) -> Result<(), GpuError> {
73 let status = sys::cublasGemmEx(
74 handle,
75 transa,
76 transb,
77 m,
78 n,
79 k,
80 alpha,
81 a as *const _,
82 a_type,
83 lda,
84 b as *const _,
85 b_type,
86 ldb,
87 beta,
88 c as *mut _,
89 c_type,
90 ldc,
91 compute_type,
92 algo,
93 );
94 check(status, "cublasGemmEx")
95}
96
97#[allow(clippy::too_many_arguments)]
103pub unsafe fn gemm_strided_batched_ex(
104 handle: cublasHandle_t,
105 transa: cublasOperation_t,
106 transb: cublasOperation_t,
107 m: c_int,
108 n: c_int,
109 k: c_int,
110 alpha: *const core::ffi::c_void,
111 a: CUdeviceptr,
112 a_type: cudaDataType,
113 lda: c_int,
114 stride_a: c_longlong,
115 b: CUdeviceptr,
116 b_type: cudaDataType,
117 ldb: c_int,
118 stride_b: c_longlong,
119 beta: *const core::ffi::c_void,
120 c: CUdeviceptr,
121 c_type: cudaDataType,
122 ldc: c_int,
123 stride_c: c_longlong,
124 batch_count: c_int,
125 compute_type: cublasComputeType_t,
126 algo: cublasGemmAlgo_t,
127) -> Result<(), GpuError> {
128 let status = sys::cublasGemmStridedBatchedEx(
129 handle,
130 transa,
131 transb,
132 m,
133 n,
134 k,
135 alpha,
136 a as *const _,
137 a_type,
138 lda,
139 stride_a,
140 b as *const _,
141 b_type,
142 ldb,
143 stride_b,
144 beta,
145 c as *mut _,
146 c_type,
147 ldc,
148 stride_c,
149 batch_count,
150 compute_type,
151 algo,
152 );
153 check(status, "cublasGemmStridedBatchedEx")
154}
155
156#[allow(clippy::too_many_arguments)]
162pub unsafe fn sgeam(
163 handle: cublasHandle_t,
164 transa: cublasOperation_t,
165 transb: cublasOperation_t,
166 m: c_int,
167 n: c_int,
168 alpha: *const f32,
169 a: CUdeviceptr,
170 lda: c_int,
171 beta: *const f32,
172 b: CUdeviceptr,
173 ldb: c_int,
174 c: CUdeviceptr,
175 ldc: c_int,
176) -> Result<(), GpuError> {
177 let status = sys::cublasSgeam(
178 handle,
179 transa,
180 transb,
181 m,
182 n,
183 alpha,
184 a as *const _,
185 lda,
186 beta,
187 b as *const _,
188 ldb,
189 c as *mut _,
190 ldc,
191 );
192 check(status, "cublasSgeam")
193}
194
195#[allow(clippy::too_many_arguments)]
196pub unsafe fn dgeam(
197 handle: cublasHandle_t,
198 transa: cublasOperation_t,
199 transb: cublasOperation_t,
200 m: c_int,
201 n: c_int,
202 alpha: *const f64,
203 a: CUdeviceptr,
204 lda: c_int,
205 beta: *const f64,
206 b: CUdeviceptr,
207 ldb: c_int,
208 c: CUdeviceptr,
209 ldc: c_int,
210) -> Result<(), GpuError> {
211 let status = sys::cublasDgeam(
212 handle,
213 transa,
214 transb,
215 m,
216 n,
217 alpha,
218 a as *const _,
219 lda,
220 beta,
221 b as *const _,
222 ldb,
223 c as *mut _,
224 ldc,
225 );
226 check(status, "cublasDgeam")
227}
228
229#[allow(clippy::too_many_arguments)]
230pub unsafe fn ssyrk(
231 handle: cublasHandle_t,
232 uplo: cublasFillMode_t,
233 trans: cublasOperation_t,
234 n: c_int,
235 k: c_int,
236 alpha: *const f32,
237 a: CUdeviceptr,
238 lda: c_int,
239 beta: *const f32,
240 c: CUdeviceptr,
241 ldc: c_int,
242) -> Result<(), GpuError> {
243 let status = sys::cublasSsyrk_v2(
244 handle,
245 uplo,
246 trans,
247 n,
248 k,
249 alpha,
250 a as *const _,
251 lda,
252 beta,
253 c as *mut _,
254 ldc,
255 );
256 check(status, "cublasSsyrk_v2")
257}
258
259#[allow(clippy::too_many_arguments)]
260pub unsafe fn dsyrk(
261 handle: cublasHandle_t,
262 uplo: cublasFillMode_t,
263 trans: cublasOperation_t,
264 n: c_int,
265 k: c_int,
266 alpha: *const f64,
267 a: CUdeviceptr,
268 lda: c_int,
269 beta: *const f64,
270 c: CUdeviceptr,
271 ldc: c_int,
272) -> Result<(), GpuError> {
273 let status = sys::cublasDsyrk_v2(
274 handle,
275 uplo,
276 trans,
277 n,
278 k,
279 alpha,
280 a as *const _,
281 lda,
282 beta,
283 c as *mut _,
284 ldc,
285 );
286 check(status, "cublasDsyrk_v2")
287}
288
289#[allow(clippy::too_many_arguments)]
290pub unsafe fn strsm(
291 handle: cublasHandle_t,
292 side: cublasSideMode_t,
293 uplo: cublasFillMode_t,
294 trans: cublasOperation_t,
295 diag: cublasDiagType_t,
296 m: c_int,
297 n: c_int,
298 alpha: *const f32,
299 a: CUdeviceptr,
300 lda: c_int,
301 b: CUdeviceptr,
302 ldb: c_int,
303) -> Result<(), GpuError> {
304 let status = sys::cublasStrsm_v2(
305 handle,
306 side,
307 uplo,
308 trans,
309 diag,
310 m,
311 n,
312 alpha,
313 a as *const _,
314 lda,
315 b as *mut _,
316 ldb,
317 );
318 check(status, "cublasStrsm_v2")
319}
320
321#[allow(clippy::too_many_arguments)]
322pub unsafe fn dtrsm(
323 handle: cublasHandle_t,
324 side: cublasSideMode_t,
325 uplo: cublasFillMode_t,
326 trans: cublasOperation_t,
327 diag: cublasDiagType_t,
328 m: c_int,
329 n: c_int,
330 alpha: *const f64,
331 a: CUdeviceptr,
332 lda: c_int,
333 b: CUdeviceptr,
334 ldb: c_int,
335) -> Result<(), GpuError> {
336 let status = sys::cublasDtrsm_v2(
337 handle,
338 side,
339 uplo,
340 trans,
341 diag,
342 m,
343 n,
344 alpha,
345 a as *const _,
346 lda,
347 b as *mut _,
348 ldb,
349 );
350 check(status, "cublasDtrsm_v2")
351}
352
353#[allow(clippy::too_many_arguments)]
356pub unsafe fn sgemv(
357 handle: cublasHandle_t,
358 trans: cublasOperation_t,
359 m: c_int,
360 n: c_int,
361 alpha: *const f32,
362 a: CUdeviceptr,
363 lda: c_int,
364 x: CUdeviceptr,
365 incx: c_int,
366 beta: *const f32,
367 y: CUdeviceptr,
368 incy: c_int,
369) -> Result<(), GpuError> {
370 let status = sys::cublasSgemv_v2(
371 handle,
372 trans,
373 m,
374 n,
375 alpha,
376 a as *const _,
377 lda,
378 x as *const _,
379 incx,
380 beta,
381 y as *mut _,
382 incy,
383 );
384 check(status, "cublasSgemv_v2")
385}
386
387#[allow(clippy::too_many_arguments)]
388pub unsafe fn dgemv(
389 handle: cublasHandle_t,
390 trans: cublasOperation_t,
391 m: c_int,
392 n: c_int,
393 alpha: *const f64,
394 a: CUdeviceptr,
395 lda: c_int,
396 x: CUdeviceptr,
397 incx: c_int,
398 beta: *const f64,
399 y: CUdeviceptr,
400 incy: c_int,
401) -> Result<(), GpuError> {
402 let status = sys::cublasDgemv_v2(
403 handle,
404 trans,
405 m,
406 n,
407 alpha,
408 a as *const _,
409 lda,
410 x as *const _,
411 incx,
412 beta,
413 y as *mut _,
414 incy,
415 );
416 check(status, "cublasDgemv_v2")
417}
418
419#[allow(clippy::too_many_arguments)]
420pub unsafe fn sger(
421 handle: cublasHandle_t,
422 m: c_int,
423 n: c_int,
424 alpha: *const f32,
425 x: CUdeviceptr,
426 incx: c_int,
427 y: CUdeviceptr,
428 incy: c_int,
429 a: CUdeviceptr,
430 lda: c_int,
431) -> Result<(), GpuError> {
432 let status = sys::cublasSger_v2(
433 handle,
434 m,
435 n,
436 alpha,
437 x as *const _,
438 incx,
439 y as *const _,
440 incy,
441 a as *mut _,
442 lda,
443 );
444 check(status, "cublasSger_v2")
445}
446
447#[allow(clippy::too_many_arguments)]
448pub unsafe fn dger(
449 handle: cublasHandle_t,
450 m: c_int,
451 n: c_int,
452 alpha: *const f64,
453 x: CUdeviceptr,
454 incx: c_int,
455 y: CUdeviceptr,
456 incy: c_int,
457 a: CUdeviceptr,
458 lda: c_int,
459) -> Result<(), GpuError> {
460 let status = sys::cublasDger_v2(
461 handle,
462 m,
463 n,
464 alpha,
465 x as *const _,
466 incx,
467 y as *const _,
468 incy,
469 a as *mut _,
470 lda,
471 );
472 check(status, "cublasDger_v2")
473}
474
475#[allow(clippy::too_many_arguments)]
478pub unsafe fn axpy_ex(
479 handle: cublasHandle_t,
480 n: c_int,
481 alpha: *const core::ffi::c_void,
482 alpha_type: cudaDataType,
483 x: CUdeviceptr,
484 x_type: cudaDataType,
485 incx: c_int,
486 y: CUdeviceptr,
487 y_type: cudaDataType,
488 incy: c_int,
489 execution_type: cudaDataType,
490) -> Result<(), GpuError> {
491 let status = sys::cublasAxpyEx(
492 handle,
493 n,
494 alpha,
495 alpha_type,
496 x as *const _,
497 x_type,
498 incx,
499 y as *mut _,
500 y_type,
501 incy,
502 execution_type,
503 );
504 check(status, "cublasAxpyEx")
505}
506
507#[allow(clippy::too_many_arguments)]
508pub unsafe fn scal_ex(
509 handle: cublasHandle_t,
510 n: c_int,
511 alpha: *const core::ffi::c_void,
512 alpha_type: cudaDataType,
513 x: CUdeviceptr,
514 x_type: cudaDataType,
515 incx: c_int,
516 execution_type: cudaDataType,
517) -> Result<(), GpuError> {
518 let status = sys::cublasScalEx(
519 handle,
520 n,
521 alpha,
522 alpha_type,
523 x as *mut _,
524 x_type,
525 incx,
526 execution_type,
527 );
528 check(status, "cublasScalEx")
529}
530
531#[allow(clippy::too_many_arguments)]
532pub unsafe fn nrm2_ex(
533 handle: cublasHandle_t,
534 n: c_int,
535 x: CUdeviceptr,
536 x_type: cudaDataType,
537 incx: c_int,
538 result: *mut core::ffi::c_void,
539 result_type: cudaDataType,
540 execution_type: cudaDataType,
541) -> Result<(), GpuError> {
542 let status = sys::cublasNrm2Ex(
543 handle,
544 n,
545 x as *const _,
546 x_type,
547 incx,
548 result,
549 result_type,
550 execution_type,
551 );
552 check(status, "cublasNrm2Ex")
553}
554
555#[allow(clippy::too_many_arguments)]
556pub unsafe fn dot_ex(
557 handle: cublasHandle_t,
558 n: c_int,
559 x: CUdeviceptr,
560 x_type: cudaDataType,
561 incx: c_int,
562 y: CUdeviceptr,
563 y_type: cudaDataType,
564 incy: c_int,
565 result: *mut core::ffi::c_void,
566 result_type: cudaDataType,
567 execution_type: cudaDataType,
568) -> Result<(), GpuError> {
569 let status = sys::cublasDotEx(
570 handle,
571 n,
572 x as *const _,
573 x_type,
574 incx,
575 y as *const _,
576 y_type,
577 incy,
578 result,
579 result_type,
580 execution_type,
581 );
582 check(status, "cublasDotEx")
583}
584
585#[allow(clippy::too_many_arguments)]
586pub unsafe fn iamax_ex(
587 handle: cublasHandle_t,
588 n: c_int,
589 x: CUdeviceptr,
590 x_type: cudaDataType,
591 incx: c_int,
592 result: *mut c_int,
593) -> Result<(), GpuError> {
594 let status = sys::cublasIamaxEx(handle, n, x as *const _, x_type, incx, result);
595 check(status, "cublasIamaxEx")
596}
597
598#[allow(clippy::too_many_arguments)]
599pub unsafe fn iamin_ex(
600 handle: cublasHandle_t,
601 n: c_int,
602 x: CUdeviceptr,
603 x_type: cudaDataType,
604 incx: c_int,
605 result: *mut c_int,
606) -> Result<(), GpuError> {
607 let status = sys::cublasIaminEx(handle, n, x as *const _, x_type, incx, result);
608 check(status, "cublasIaminEx")
609}
610
611#[allow(clippy::too_many_arguments)]
612pub unsafe fn asum_ex(
613 handle: cublasHandle_t,
614 n: c_int,
615 x: CUdeviceptr,
616 x_type: cudaDataType,
617 incx: c_int,
618 result: *mut core::ffi::c_void,
619 result_type: cudaDataType,
620 execution_type: cudaDataType,
621) -> Result<(), GpuError> {
622 let status = sys::cublasAsumEx(
623 handle,
624 n,
625 x as *const _,
626 x_type,
627 incx,
628 result,
629 result_type,
630 execution_type,
631 );
632 check(status, "cublasAsumEx")
633}
634
635#[allow(clippy::too_many_arguments)]
636pub unsafe fn copy_ex(
637 handle: cublasHandle_t,
638 n: c_int,
639 x: CUdeviceptr,
640 x_type: cudaDataType,
641 incx: c_int,
642 y: CUdeviceptr,
643 y_type: cudaDataType,
644 incy: c_int,
645) -> Result<(), GpuError> {
646 let status = sys::cublasCopyEx(
647 handle,
648 n,
649 x as *const _,
650 x_type,
651 incx,
652 y as *mut _,
653 y_type,
654 incy,
655 );
656 check(status, "cublasCopyEx")
657}
658
659#[allow(clippy::too_many_arguments)]
660pub unsafe fn swap_ex(
661 handle: cublasHandle_t,
662 n: c_int,
663 x: CUdeviceptr,
664 x_type: cudaDataType,
665 incx: c_int,
666 y: CUdeviceptr,
667 y_type: cudaDataType,
668 incy: c_int,
669) -> Result<(), GpuError> {
670 let status = sys::cublasSwapEx(
671 handle,
672 n,
673 x as *mut _,
674 x_type,
675 incx,
676 y as *mut _,
677 y_type,
678 incy,
679 );
680 check(status, "cublasSwapEx")
681}
682
683#[allow(clippy::too_many_arguments)]
684pub unsafe fn rot_ex(
685 handle: cublasHandle_t,
686 n: c_int,
687 x: CUdeviceptr,
688 x_type: cudaDataType,
689 incx: c_int,
690 y: CUdeviceptr,
691 y_type: cudaDataType,
692 incy: c_int,
693 cs: *const core::ffi::c_void,
694 s: *const core::ffi::c_void,
695 cs_type: cudaDataType,
696 execution_type: cudaDataType,
697) -> Result<(), GpuError> {
698 let status = sys::cublasRotEx(
699 handle,
700 n,
701 x as *mut _,
702 x_type,
703 incx,
704 y as *mut _,
705 y_type,
706 incy,
707 cs,
708 s,
709 cs_type,
710 execution_type,
711 );
712 check(status, "cublasRotEx")
713}