Skip to main content

atomr_accel_cuda/sys/
cusolver.rs

1//! Crate-private FFI thunks around `cudarc::cusolver::sys`.
2//!
3//! cudarc 0.19's safe layer covers handle construction (`DnHandle`,
4//! `SpHandle`) but leaves every dense / sparse / batched factorisation
5//! behind raw `extern "C"` declarations under `cusolver::sys::lib`. The
//! [`SolverScalar`] trait lifts the per-prefix C entry points
//! (`Sgeqrf`, `Dgeqrf`, …) onto a uniform Rust surface so the actor's
//! per-op handlers can be written generically over `T: SolverScalar`
//! (a refinement of `SolverSupported`).
//!
//! All raw FFI calls are contained inside this module — handlers in
//! `kernel::solver::*` only see the typed (still `unsafe`) wrappers and a
//! single [`status_to_result`] adapter.
13
14use cudarc::cusolver::sys as cs;
15
16use crate::dtype::SolverSupported;
17use crate::error::GpuError;
18
/// Library name recorded in the `lib` field of the
/// [`GpuError::LibraryError`] values built by [`status_to_result`].
pub const LIB: &str = "cusolver";
20
21/// Translate a `cusolverStatus_t` into our typed error.
22pub fn status_to_result(status: cs::cusolverStatus_t, op: &'static str) -> Result<(), GpuError> {
23    if status == cs::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
24        Ok(())
25    } else {
26        Err(GpuError::LibraryError {
27            lib: LIB,
28            msg: format!("{op}: {status:?}"),
29        })
30    }
31}
32
33/// Per-dtype dispatcher to the cuSOLVER `S/D/C/Z` entry points.
34///
35/// Methods take raw pointers because the actor side has already
36/// extracted them via `device_ptr_mut`; threading lifetimes through
37/// `&mut CudaSlice<T>` would force every wrapper to be generic over
38/// keep-alive guards. The actor is responsible for keeping the slices
39/// alive for the duration of the call.
40///
41/// # Safety
42///
43/// Each `unsafe` method is safe to call when:
44/// - the `handle` is a live `DnHandle::cu()` value bound to the same
45///   stream the slices were allocated/written through,
46/// - all device pointers reference the buffer sizes implied by the
47///   `(m, n, lda)` triple per the cuSOLVER reference,
48/// - `lwork` matches what the corresponding `*_buffer_size` returned,
49/// - `info` points to at least one writable `i32`.
50pub trait SolverScalar: SolverSupported {
51    /// QR `geqrf`: workspace query.
52    unsafe fn geqrf_buffer_size(
53        handle: cs::cusolverDnHandle_t,
54        m: i32,
55        n: i32,
56        a: *mut Self,
57        lda: i32,
58        lwork: *mut i32,
59    ) -> cs::cusolverStatus_t;
60    unsafe fn geqrf(
61        handle: cs::cusolverDnHandle_t,
62        m: i32,
63        n: i32,
64        a: *mut Self,
65        lda: i32,
66        tau: *mut Self,
67        work: *mut Self,
68        lwork: i32,
69        info: *mut i32,
70    ) -> cs::cusolverStatus_t;
71
72    /// LU `getrf`: workspace query.
73    unsafe fn getrf_buffer_size(
74        handle: cs::cusolverDnHandle_t,
75        m: i32,
76        n: i32,
77        a: *mut Self,
78        lda: i32,
79        lwork: *mut i32,
80    ) -> cs::cusolverStatus_t;
81    unsafe fn getrf(
82        handle: cs::cusolverDnHandle_t,
83        m: i32,
84        n: i32,
85        a: *mut Self,
86        lda: i32,
87        work: *mut Self,
88        ipiv: *mut i32,
89        info: *mut i32,
90    ) -> cs::cusolverStatus_t;
91    unsafe fn getrs(
92        handle: cs::cusolverDnHandle_t,
93        trans: cs::cublasOperation_t,
94        n: i32,
95        nrhs: i32,
96        a: *const Self,
97        lda: i32,
98        ipiv: *const i32,
99        b: *mut Self,
100        ldb: i32,
101        info: *mut i32,
102    ) -> cs::cusolverStatus_t;
103
104    /// Cholesky `potrf`.
105    unsafe fn potrf_buffer_size(
106        handle: cs::cusolverDnHandle_t,
107        uplo: cs::cublasFillMode_t,
108        n: i32,
109        a: *mut Self,
110        lda: i32,
111        lwork: *mut i32,
112    ) -> cs::cusolverStatus_t;
113    unsafe fn potrf(
114        handle: cs::cusolverDnHandle_t,
115        uplo: cs::cublasFillMode_t,
116        n: i32,
117        a: *mut Self,
118        lda: i32,
119        work: *mut Self,
120        lwork: i32,
121        info: *mut i32,
122    ) -> cs::cusolverStatus_t;
123
124    /// SVD `gesvd`. (`gesvd` workspace query takes only m/n.)
125    unsafe fn gesvd_buffer_size(
126        handle: cs::cusolverDnHandle_t,
127        m: i32,
128        n: i32,
129        lwork: *mut i32,
130    ) -> cs::cusolverStatus_t;
131    unsafe fn gesvd(
132        handle: cs::cusolverDnHandle_t,
133        jobu: i8,
134        jobvt: i8,
135        m: i32,
136        n: i32,
137        a: *mut Self,
138        lda: i32,
139        s: *mut Self,
140        u: *mut Self,
141        ldu: i32,
142        vt: *mut Self,
143        ldvt: i32,
144        work: *mut Self,
145        lwork: i32,
146        rwork: *mut Self,
147        info: *mut i32,
148    ) -> cs::cusolverStatus_t;
149
150    /// Symmetric eigendecomposition `syevd`.
151    unsafe fn syevd_buffer_size(
152        handle: cs::cusolverDnHandle_t,
153        jobz: cs::cusolverEigMode_t,
154        uplo: cs::cublasFillMode_t,
155        n: i32,
156        a: *const Self,
157        lda: i32,
158        w: *const Self,
159        lwork: *mut i32,
160    ) -> cs::cusolverStatus_t;
161    unsafe fn syevd(
162        handle: cs::cusolverDnHandle_t,
163        jobz: cs::cusolverEigMode_t,
164        uplo: cs::cublasFillMode_t,
165        n: i32,
166        a: *mut Self,
167        lda: i32,
168        w: *mut Self,
169        work: *mut Self,
170        lwork: i32,
171        info: *mut i32,
172    ) -> cs::cusolverStatus_t;
173
174    /// Generalized symmetric eigendecomposition `sygvd` (real)
175    /// / `hegvd` (complex). Phase 1 routes both messages through
176    /// the same trait method since the f32/f64 surface is purely
177    /// real.
178    unsafe fn sygvd_buffer_size(
179        handle: cs::cusolverDnHandle_t,
180        itype: cs::cusolverEigType_t,
181        jobz: cs::cusolverEigMode_t,
182        uplo: cs::cublasFillMode_t,
183        n: i32,
184        a: *const Self,
185        lda: i32,
186        b: *const Self,
187        ldb: i32,
188        w: *const Self,
189        lwork: *mut i32,
190    ) -> cs::cusolverStatus_t;
191    unsafe fn sygvd(
192        handle: cs::cusolverDnHandle_t,
193        itype: cs::cusolverEigType_t,
194        jobz: cs::cusolverEigMode_t,
195        uplo: cs::cublasFillMode_t,
196        n: i32,
197        a: *mut Self,
198        lda: i32,
199        b: *mut Self,
200        ldb: i32,
201        w: *mut Self,
202        work: *mut Self,
203        lwork: i32,
204        info: *mut i32,
205    ) -> cs::cusolverStatus_t;
206
207    /// Batched Cholesky `potrfBatched`. cuSOLVER takes an
208    /// array-of-pointers to `n × n` matrices; the actor stages
209    /// them into a contiguous device buffer of pointers.
210    unsafe fn potrf_batched(
211        handle: cs::cusolverDnHandle_t,
212        uplo: cs::cublasFillMode_t,
213        n: i32,
214        a_array: *mut *mut Self,
215        lda: i32,
216        info_array: *mut i32,
217        batch_size: i32,
218    ) -> cs::cusolverStatus_t;
219
220    /// Batched Jacobi SVD `gesvdjBatched` workspace query + launch.
221    unsafe fn gesvdj_batched_buffer_size(
222        handle: cs::cusolverDnHandle_t,
223        jobz: cs::cusolverEigMode_t,
224        m: i32,
225        n: i32,
226        a: *const Self,
227        lda: i32,
228        s: *const Self,
229        u: *const Self,
230        ldu: i32,
231        v: *const Self,
232        ldv: i32,
233        lwork: *mut i32,
234        params: cs::gesvdjInfo_t,
235        batch_size: i32,
236    ) -> cs::cusolverStatus_t;
237    unsafe fn gesvdj_batched(
238        handle: cs::cusolverDnHandle_t,
239        jobz: cs::cusolverEigMode_t,
240        m: i32,
241        n: i32,
242        a: *mut Self,
243        lda: i32,
244        s: *mut Self,
245        u: *mut Self,
246        ldu: i32,
247        v: *mut Self,
248        ldv: i32,
249        work: *mut Self,
250        lwork: i32,
251        info: *mut i32,
252        params: cs::gesvdjInfo_t,
253        batch_size: i32,
254    ) -> cs::cusolverStatus_t;
255}
256
/// Implements [`SolverScalar`] for one scalar type by forwarding every
/// method, argument-for-argument, to the named `cs::cusolverDn*` function.
///
/// Each clause names the launch entry point first and, where one exists,
/// its `_bufferSize` workspace query second; the `getrf` clause also names
/// the `getrs` solve.
macro_rules! impl_solver_scalar {
    (
        $T:ty,
        geqrf:           $geqrf:ident,        $geqrf_bs:ident;
        getrf:           $getrf:ident,        $getrf_bs:ident,        $getrs:ident;
        potrf:           $potrf:ident,        $potrf_bs:ident;
        gesvd:           $gesvd:ident,        $gesvd_bs:ident;
        syevd:           $syevd:ident,        $syevd_bs:ident;
        sygvd:           $sygvd:ident,        $sygvd_bs:ident;
        potrf_batched:   $potrf_b:ident;
        gesvdj_batched:  $gesvdj_b:ident,     $gesvdj_b_bs:ident;
    ) => {
        // Every FFI call sits in an explicit `unsafe` block so the impl stays
        // warning-free under `unsafe_op_in_unsafe_fn` (warn-by-default from
        // the 2024 edition); in earlier editions the block is harmless.
        impl SolverScalar for $T {
            unsafe fn geqrf_buffer_size(
                handle: cs::cusolverDnHandle_t,
                m: i32,
                n: i32,
                a: *mut Self,
                lda: i32,
                lwork: *mut i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: arguments forwarded unchanged; the caller upholds
                // the `SolverScalar` safety contract.
                unsafe { cs::$geqrf_bs(handle, m, n, a, lda, lwork) }
            }
            unsafe fn geqrf(
                handle: cs::cusolverDnHandle_t,
                m: i32,
                n: i32,
                a: *mut Self,
                lda: i32,
                tau: *mut Self,
                work: *mut Self,
                lwork: i32,
                info: *mut i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: arguments forwarded unchanged; the caller upholds
                // the `SolverScalar` safety contract.
                unsafe { cs::$geqrf(handle, m, n, a, lda, tau, work, lwork, info) }
            }

            unsafe fn getrf_buffer_size(
                handle: cs::cusolverDnHandle_t,
                m: i32,
                n: i32,
                a: *mut Self,
                lda: i32,
                lwork: *mut i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: arguments forwarded unchanged; the caller upholds
                // the `SolverScalar` safety contract.
                unsafe { cs::$getrf_bs(handle, m, n, a, lda, lwork) }
            }
            unsafe fn getrf(
                handle: cs::cusolverDnHandle_t,
                m: i32,
                n: i32,
                a: *mut Self,
                lda: i32,
                work: *mut Self,
                ipiv: *mut i32,
                info: *mut i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: arguments forwarded unchanged; the caller upholds
                // the `SolverScalar` safety contract.
                unsafe { cs::$getrf(handle, m, n, a, lda, work, ipiv, info) }
            }
            unsafe fn getrs(
                handle: cs::cusolverDnHandle_t,
                trans: cs::cublasOperation_t,
                n: i32,
                nrhs: i32,
                a: *const Self,
                lda: i32,
                ipiv: *const i32,
                b: *mut Self,
                ldb: i32,
                info: *mut i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: arguments forwarded unchanged; the caller upholds
                // the `SolverScalar` safety contract.
                unsafe { cs::$getrs(handle, trans, n, nrhs, a, lda, ipiv, b, ldb, info) }
            }

            unsafe fn potrf_buffer_size(
                handle: cs::cusolverDnHandle_t,
                uplo: cs::cublasFillMode_t,
                n: i32,
                a: *mut Self,
                lda: i32,
                lwork: *mut i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: arguments forwarded unchanged; the caller upholds
                // the `SolverScalar` safety contract.
                unsafe { cs::$potrf_bs(handle, uplo, n, a, lda, lwork) }
            }
            unsafe fn potrf(
                handle: cs::cusolverDnHandle_t,
                uplo: cs::cublasFillMode_t,
                n: i32,
                a: *mut Self,
                lda: i32,
                work: *mut Self,
                lwork: i32,
                info: *mut i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: arguments forwarded unchanged; the caller upholds
                // the `SolverScalar` safety contract.
                unsafe { cs::$potrf(handle, uplo, n, a, lda, work, lwork, info) }
            }

            unsafe fn gesvd_buffer_size(
                handle: cs::cusolverDnHandle_t,
                m: i32,
                n: i32,
                lwork: *mut i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: arguments forwarded unchanged; the caller upholds
                // the `SolverScalar` safety contract.
                unsafe { cs::$gesvd_bs(handle, m, n, lwork) }
            }
            unsafe fn gesvd(
                handle: cs::cusolverDnHandle_t,
                jobu: i8,
                jobvt: i8,
                m: i32,
                n: i32,
                a: *mut Self,
                lda: i32,
                s: *mut Self,
                u: *mut Self,
                ldu: i32,
                vt: *mut Self,
                ldvt: i32,
                work: *mut Self,
                lwork: i32,
                rwork: *mut Self,
                info: *mut i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: arguments forwarded unchanged; the caller upholds
                // the `SolverScalar` safety contract.
                unsafe {
                    cs::$gesvd(
                        handle, jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, work, lwork,
                        rwork, info,
                    )
                }
            }

            unsafe fn syevd_buffer_size(
                handle: cs::cusolverDnHandle_t,
                jobz: cs::cusolverEigMode_t,
                uplo: cs::cublasFillMode_t,
                n: i32,
                a: *const Self,
                lda: i32,
                w: *const Self,
                lwork: *mut i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: arguments forwarded unchanged; the caller upholds
                // the `SolverScalar` safety contract.
                unsafe { cs::$syevd_bs(handle, jobz, uplo, n, a, lda, w, lwork) }
            }
            unsafe fn syevd(
                handle: cs::cusolverDnHandle_t,
                jobz: cs::cusolverEigMode_t,
                uplo: cs::cublasFillMode_t,
                n: i32,
                a: *mut Self,
                lda: i32,
                w: *mut Self,
                work: *mut Self,
                lwork: i32,
                info: *mut i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: arguments forwarded unchanged; the caller upholds
                // the `SolverScalar` safety contract.
                unsafe { cs::$syevd(handle, jobz, uplo, n, a, lda, w, work, lwork, info) }
            }

            unsafe fn sygvd_buffer_size(
                handle: cs::cusolverDnHandle_t,
                itype: cs::cusolverEigType_t,
                jobz: cs::cusolverEigMode_t,
                uplo: cs::cublasFillMode_t,
                n: i32,
                a: *const Self,
                lda: i32,
                b: *const Self,
                ldb: i32,
                w: *const Self,
                lwork: *mut i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: arguments forwarded unchanged; the caller upholds
                // the `SolverScalar` safety contract.
                unsafe { cs::$sygvd_bs(handle, itype, jobz, uplo, n, a, lda, b, ldb, w, lwork) }
            }
            unsafe fn sygvd(
                handle: cs::cusolverDnHandle_t,
                itype: cs::cusolverEigType_t,
                jobz: cs::cusolverEigMode_t,
                uplo: cs::cublasFillMode_t,
                n: i32,
                a: *mut Self,
                lda: i32,
                b: *mut Self,
                ldb: i32,
                w: *mut Self,
                work: *mut Self,
                lwork: i32,
                info: *mut i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: arguments forwarded unchanged; the caller upholds
                // the `SolverScalar` safety contract.
                unsafe {
                    cs::$sygvd(
                        handle, itype, jobz, uplo, n, a, lda, b, ldb, w, work, lwork, info,
                    )
                }
            }

            unsafe fn potrf_batched(
                handle: cs::cusolverDnHandle_t,
                uplo: cs::cublasFillMode_t,
                n: i32,
                a_array: *mut *mut Self,
                lda: i32,
                info_array: *mut i32,
                batch_size: i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: arguments forwarded unchanged; the caller upholds
                // the `SolverScalar` safety contract.
                unsafe { cs::$potrf_b(handle, uplo, n, a_array, lda, info_array, batch_size) }
            }

            unsafe fn gesvdj_batched_buffer_size(
                handle: cs::cusolverDnHandle_t,
                jobz: cs::cusolverEigMode_t,
                m: i32,
                n: i32,
                a: *const Self,
                lda: i32,
                s: *const Self,
                u: *const Self,
                ldu: i32,
                v: *const Self,
                ldv: i32,
                lwork: *mut i32,
                params: cs::gesvdjInfo_t,
                batch_size: i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: arguments forwarded unchanged; the caller upholds
                // the `SolverScalar` safety contract.
                unsafe {
                    cs::$gesvdj_b_bs(
                        handle, jobz, m, n, a, lda, s, u, ldu, v, ldv, lwork, params, batch_size,
                    )
                }
            }
            unsafe fn gesvdj_batched(
                handle: cs::cusolverDnHandle_t,
                jobz: cs::cusolverEigMode_t,
                m: i32,
                n: i32,
                a: *mut Self,
                lda: i32,
                s: *mut Self,
                u: *mut Self,
                ldu: i32,
                v: *mut Self,
                ldv: i32,
                work: *mut Self,
                lwork: i32,
                info: *mut i32,
                params: cs::gesvdjInfo_t,
                batch_size: i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: arguments forwarded unchanged; the caller upholds
                // the `SolverScalar` safety contract.
                unsafe {
                    cs::$gesvdj_b(
                        handle, jobz, m, n, a, lda, s, u, ldu, v, ldv, work, lwork, info, params,
                        batch_size,
                    )
                }
            }
        }
    };
}
506
// Single precision: `f32` dispatches to the `cusolverDnS*` entry points.
impl_solver_scalar!(
    f32,
    geqrf:          cusolverDnSgeqrf,            cusolverDnSgeqrf_bufferSize;
    getrf:          cusolverDnSgetrf,            cusolverDnSgetrf_bufferSize, cusolverDnSgetrs;
    potrf:          cusolverDnSpotrf,            cusolverDnSpotrf_bufferSize;
    gesvd:          cusolverDnSgesvd,            cusolverDnSgesvd_bufferSize;
    syevd:          cusolverDnSsyevd,            cusolverDnSsyevd_bufferSize;
    sygvd:          cusolverDnSsygvd,            cusolverDnSsygvd_bufferSize;
    potrf_batched:  cusolverDnSpotrfBatched;
    gesvdj_batched: cusolverDnSgesvdjBatched,    cusolverDnSgesvdjBatched_bufferSize;
);

// Double precision: `f64` dispatches to the `cusolverDnD*` entry points.
impl_solver_scalar!(
    f64,
    geqrf:          cusolverDnDgeqrf,            cusolverDnDgeqrf_bufferSize;
    getrf:          cusolverDnDgetrf,            cusolverDnDgetrf_bufferSize, cusolverDnDgetrs;
    potrf:          cusolverDnDpotrf,            cusolverDnDpotrf_bufferSize;
    gesvd:          cusolverDnDgesvd,            cusolverDnDgesvd_bufferSize;
    syevd:          cusolverDnDsyevd,            cusolverDnDsyevd_bufferSize;
    sygvd:          cusolverDnDsygvd,            cusolverDnDsygvd_bufferSize;
    potrf_batched:  cusolverDnDpotrfBatched;
    gesvdj_batched: cusolverDnDgesvdjBatched,    cusolverDnDgesvdjBatched_bufferSize;
);
530
/// Sparse cuSOLVER (`cusolverSp`) entry points.
///
/// Only the device-side `csrlsv*` pair (Cholesky / QR) is exposed
/// here — `csrlsvluHost` is host-only and out of scope. Each method
/// solves `A x = b` in one shot for an `m × m` CSR matrix.
///
/// `tol` is taken as `f64` for every scalar type so callers see a single
/// signature; the impls in this module cast it to the precision of the
/// underlying C entry point (`f32` for the `S` routines).
///
/// # Safety
///
/// Each `unsafe` method is safe to call when:
/// - `handle` is a live cuSOLVER sparse handle,
/// - `descr_a` is a valid cuSPARSE matrix descriptor for `A`,
/// - `csr_val` and `csr_col_ind` reference `nnz` elements, `csr_row_ptr`
///   references `m + 1` elements, and `b` / `x` reference `m` elements each,
/// - `singularity` points to one writable `i32` (the cuSOLVER reference
///   lists it as a host-side output — confirm against the caller).
// Argument lists mirror the cuSOLVER C prototypes one-to-one, so the
// argument count is dictated by the C API.
#[cfg(feature = "cusolver-sp")]
#[allow(clippy::too_many_arguments)]
pub trait SparseSolverScalar: SolverSupported {
    /// `csrlsvchol`: SPD CSR system via Cholesky.
    unsafe fn csrlsvchol(
        handle: cs::cusolverSpHandle_t,
        m: i32,
        nnz: i32,
        descr_a: cs::cusparseMatDescr_t,
        csr_val: *const Self,
        csr_row_ptr: *const i32,
        csr_col_ind: *const i32,
        b: *const Self,
        tol: f64,
        reorder: i32,
        x: *mut Self,
        singularity: *mut i32,
    ) -> cs::cusolverStatus_t;

    /// `csrlsvqr`: general CSR system via QR.
    unsafe fn csrlsvqr(
        handle: cs::cusolverSpHandle_t,
        m: i32,
        nnz: i32,
        descr_a: cs::cusparseMatDescr_t,
        csr_val: *const Self,
        csr_row_ptr: *const i32,
        csr_col_ind: *const i32,
        b: *const Self,
        tol: f64,
        reorder: i32,
        x: *mut Self,
        singularity: *mut i32,
    ) -> cs::cusolverStatus_t;
}
570
/// Implements [`SparseSolverScalar`] for `$T` by forwarding to the named
/// `cs::cusolverSp*` functions.
///
/// `$tol` is the C type of the `tol` argument for that precision; the
/// trait's `f64` tolerance is cast to it with `as`, which narrows for
/// `f32`.
#[cfg(feature = "cusolver-sp")]
macro_rules! impl_sparse_solver_scalar {
    ($T:ty, $tol:ty, chol: $chol:ident, qr: $qr:ident) => {
        // Each FFI call sits in an explicit `unsafe` block so the impl stays
        // warning-free under `unsafe_op_in_unsafe_fn` (warn-by-default from
        // the 2024 edition); in earlier editions the block is harmless.
        impl SparseSolverScalar for $T {
            unsafe fn csrlsvchol(
                handle: cs::cusolverSpHandle_t,
                m: i32,
                nnz: i32,
                descr_a: cs::cusparseMatDescr_t,
                csr_val: *const Self,
                csr_row_ptr: *const i32,
                csr_col_ind: *const i32,
                b: *const Self,
                tol: f64,
                reorder: i32,
                x: *mut Self,
                singularity: *mut i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: arguments forwarded unchanged (apart from the
                // `tol` precision cast); the caller upholds the
                // `SparseSolverScalar` safety contract.
                unsafe {
                    cs::$chol(
                        handle,
                        m,
                        nnz,
                        descr_a,
                        csr_val,
                        csr_row_ptr,
                        csr_col_ind,
                        b,
                        tol as $tol,
                        reorder,
                        x,
                        singularity,
                    )
                }
            }

            unsafe fn csrlsvqr(
                handle: cs::cusolverSpHandle_t,
                m: i32,
                nnz: i32,
                descr_a: cs::cusparseMatDescr_t,
                csr_val: *const Self,
                csr_row_ptr: *const i32,
                csr_col_ind: *const i32,
                b: *const Self,
                tol: f64,
                reorder: i32,
                x: *mut Self,
                singularity: *mut i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: arguments forwarded unchanged (apart from the
                // `tol` precision cast); the caller upholds the
                // `SparseSolverScalar` safety contract.
                unsafe {
                    cs::$qr(
                        handle,
                        m,
                        nnz,
                        descr_a,
                        csr_val,
                        csr_row_ptr,
                        csr_col_ind,
                        b,
                        tol as $tol,
                        reorder,
                        x,
                        singularity,
                    )
                }
            }
        }
    };
}
637
// Single precision: `tol` is narrowed to `f32` for the `cusolverSpS*` solvers.
#[cfg(feature = "cusolver-sp")]
impl_sparse_solver_scalar!(f32, f32, chol: cusolverSpScsrlsvchol, qr: cusolverSpScsrlsvqr);
// Double precision: `tol` is passed through as `f64` to the `cusolverSpD*` solvers.
#[cfg(feature = "cusolver-sp")]
impl_sparse_solver_scalar!(f64, f64, chol: cusolverSpDcsrlsvchol, qr: cusolverSpDcsrlsvqr);