1use cudarc::cusolver::sys as cs;
15
16use crate::dtype::SolverSupported;
17use crate::error::GpuError;
18
/// Library tag stored in the `lib` field of `GpuError::LibraryError` values built by
/// `status_to_result`.
pub const LIB: &str = "cusolver";
20
21pub fn status_to_result(status: cs::cusolverStatus_t, op: &'static str) -> Result<(), GpuError> {
23 if status == cs::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
24 Ok(())
25 } else {
26 Err(GpuError::LibraryError {
27 lib: LIB,
28 msg: format!("{op}: {status:?}"),
29 })
30 }
31}
32
33pub trait SolverScalar: SolverSupported {
51 unsafe fn geqrf_buffer_size(
53 handle: cs::cusolverDnHandle_t,
54 m: i32,
55 n: i32,
56 a: *mut Self,
57 lda: i32,
58 lwork: *mut i32,
59 ) -> cs::cusolverStatus_t;
60 unsafe fn geqrf(
61 handle: cs::cusolverDnHandle_t,
62 m: i32,
63 n: i32,
64 a: *mut Self,
65 lda: i32,
66 tau: *mut Self,
67 work: *mut Self,
68 lwork: i32,
69 info: *mut i32,
70 ) -> cs::cusolverStatus_t;
71
72 unsafe fn getrf_buffer_size(
74 handle: cs::cusolverDnHandle_t,
75 m: i32,
76 n: i32,
77 a: *mut Self,
78 lda: i32,
79 lwork: *mut i32,
80 ) -> cs::cusolverStatus_t;
81 unsafe fn getrf(
82 handle: cs::cusolverDnHandle_t,
83 m: i32,
84 n: i32,
85 a: *mut Self,
86 lda: i32,
87 work: *mut Self,
88 ipiv: *mut i32,
89 info: *mut i32,
90 ) -> cs::cusolverStatus_t;
91 unsafe fn getrs(
92 handle: cs::cusolverDnHandle_t,
93 trans: cs::cublasOperation_t,
94 n: i32,
95 nrhs: i32,
96 a: *const Self,
97 lda: i32,
98 ipiv: *const i32,
99 b: *mut Self,
100 ldb: i32,
101 info: *mut i32,
102 ) -> cs::cusolverStatus_t;
103
104 unsafe fn potrf_buffer_size(
106 handle: cs::cusolverDnHandle_t,
107 uplo: cs::cublasFillMode_t,
108 n: i32,
109 a: *mut Self,
110 lda: i32,
111 lwork: *mut i32,
112 ) -> cs::cusolverStatus_t;
113 unsafe fn potrf(
114 handle: cs::cusolverDnHandle_t,
115 uplo: cs::cublasFillMode_t,
116 n: i32,
117 a: *mut Self,
118 lda: i32,
119 work: *mut Self,
120 lwork: i32,
121 info: *mut i32,
122 ) -> cs::cusolverStatus_t;
123
124 unsafe fn gesvd_buffer_size(
126 handle: cs::cusolverDnHandle_t,
127 m: i32,
128 n: i32,
129 lwork: *mut i32,
130 ) -> cs::cusolverStatus_t;
131 unsafe fn gesvd(
132 handle: cs::cusolverDnHandle_t,
133 jobu: i8,
134 jobvt: i8,
135 m: i32,
136 n: i32,
137 a: *mut Self,
138 lda: i32,
139 s: *mut Self,
140 u: *mut Self,
141 ldu: i32,
142 vt: *mut Self,
143 ldvt: i32,
144 work: *mut Self,
145 lwork: i32,
146 rwork: *mut Self,
147 info: *mut i32,
148 ) -> cs::cusolverStatus_t;
149
150 unsafe fn syevd_buffer_size(
152 handle: cs::cusolverDnHandle_t,
153 jobz: cs::cusolverEigMode_t,
154 uplo: cs::cublasFillMode_t,
155 n: i32,
156 a: *const Self,
157 lda: i32,
158 w: *const Self,
159 lwork: *mut i32,
160 ) -> cs::cusolverStatus_t;
161 unsafe fn syevd(
162 handle: cs::cusolverDnHandle_t,
163 jobz: cs::cusolverEigMode_t,
164 uplo: cs::cublasFillMode_t,
165 n: i32,
166 a: *mut Self,
167 lda: i32,
168 w: *mut Self,
169 work: *mut Self,
170 lwork: i32,
171 info: *mut i32,
172 ) -> cs::cusolverStatus_t;
173
174 unsafe fn sygvd_buffer_size(
179 handle: cs::cusolverDnHandle_t,
180 itype: cs::cusolverEigType_t,
181 jobz: cs::cusolverEigMode_t,
182 uplo: cs::cublasFillMode_t,
183 n: i32,
184 a: *const Self,
185 lda: i32,
186 b: *const Self,
187 ldb: i32,
188 w: *const Self,
189 lwork: *mut i32,
190 ) -> cs::cusolverStatus_t;
191 unsafe fn sygvd(
192 handle: cs::cusolverDnHandle_t,
193 itype: cs::cusolverEigType_t,
194 jobz: cs::cusolverEigMode_t,
195 uplo: cs::cublasFillMode_t,
196 n: i32,
197 a: *mut Self,
198 lda: i32,
199 b: *mut Self,
200 ldb: i32,
201 w: *mut Self,
202 work: *mut Self,
203 lwork: i32,
204 info: *mut i32,
205 ) -> cs::cusolverStatus_t;
206
207 unsafe fn potrf_batched(
211 handle: cs::cusolverDnHandle_t,
212 uplo: cs::cublasFillMode_t,
213 n: i32,
214 a_array: *mut *mut Self,
215 lda: i32,
216 info_array: *mut i32,
217 batch_size: i32,
218 ) -> cs::cusolverStatus_t;
219
220 unsafe fn gesvdj_batched_buffer_size(
222 handle: cs::cusolverDnHandle_t,
223 jobz: cs::cusolverEigMode_t,
224 m: i32,
225 n: i32,
226 a: *const Self,
227 lda: i32,
228 s: *const Self,
229 u: *const Self,
230 ldu: i32,
231 v: *const Self,
232 ldv: i32,
233 lwork: *mut i32,
234 params: cs::gesvdjInfo_t,
235 batch_size: i32,
236 ) -> cs::cusolverStatus_t;
237 unsafe fn gesvdj_batched(
238 handle: cs::cusolverDnHandle_t,
239 jobz: cs::cusolverEigMode_t,
240 m: i32,
241 n: i32,
242 a: *mut Self,
243 lda: i32,
244 s: *mut Self,
245 u: *mut Self,
246 ldu: i32,
247 v: *mut Self,
248 ldv: i32,
249 work: *mut Self,
250 lwork: i32,
251 info: *mut i32,
252 params: cs::gesvdjInfo_t,
253 batch_size: i32,
254 ) -> cs::cusolverStatus_t;
255}
256
// Generates `impl SolverScalar for $T` by binding each trait method to the named
// `cudarc::cusolver::sys` function. Every generated method forwards its arguments unchanged,
// in declaration order, and returns the raw status.
//
// Each FFI call sits in its own `unsafe` block so the generated code stays clean under the
// `unsafe_op_in_unsafe_fn` lint; the obligation itself is simply passed on to the caller.
macro_rules! impl_solver_scalar {
    (
        $T:ty,
        geqrf: $geqrf:ident, $geqrf_bs:ident;
        getrf: $getrf:ident, $getrf_bs:ident, $getrs:ident;
        potrf: $potrf:ident, $potrf_bs:ident;
        gesvd: $gesvd:ident, $gesvd_bs:ident;
        syevd: $syevd:ident, $syevd_bs:ident;
        sygvd: $sygvd:ident, $sygvd_bs:ident;
        potrf_batched: $potrf_b:ident;
        gesvdj_batched: $gesvdj_b:ident, $gesvdj_b_bs:ident;
    ) => {
        impl SolverScalar for $T {
            unsafe fn geqrf_buffer_size(
                handle: cs::cusolverDnHandle_t,
                m: i32,
                n: i32,
                a: *mut Self,
                lda: i32,
                lwork: *mut i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: the caller of this `unsafe fn` upholds the cuSOLVER contract;
                // arguments are forwarded unchanged.
                unsafe { cs::$geqrf_bs(handle, m, n, a, lda, lwork) }
            }

            unsafe fn geqrf(
                handle: cs::cusolverDnHandle_t,
                m: i32,
                n: i32,
                a: *mut Self,
                lda: i32,
                tau: *mut Self,
                work: *mut Self,
                lwork: i32,
                info: *mut i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: the caller of this `unsafe fn` upholds the cuSOLVER contract;
                // arguments are forwarded unchanged.
                unsafe { cs::$geqrf(handle, m, n, a, lda, tau, work, lwork, info) }
            }

            unsafe fn getrf_buffer_size(
                handle: cs::cusolverDnHandle_t,
                m: i32,
                n: i32,
                a: *mut Self,
                lda: i32,
                lwork: *mut i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: the caller of this `unsafe fn` upholds the cuSOLVER contract;
                // arguments are forwarded unchanged.
                unsafe { cs::$getrf_bs(handle, m, n, a, lda, lwork) }
            }

            unsafe fn getrf(
                handle: cs::cusolverDnHandle_t,
                m: i32,
                n: i32,
                a: *mut Self,
                lda: i32,
                work: *mut Self,
                ipiv: *mut i32,
                info: *mut i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: the caller of this `unsafe fn` upholds the cuSOLVER contract;
                // arguments are forwarded unchanged.
                unsafe { cs::$getrf(handle, m, n, a, lda, work, ipiv, info) }
            }

            unsafe fn getrs(
                handle: cs::cusolverDnHandle_t,
                trans: cs::cublasOperation_t,
                n: i32,
                nrhs: i32,
                a: *const Self,
                lda: i32,
                ipiv: *const i32,
                b: *mut Self,
                ldb: i32,
                info: *mut i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: the caller of this `unsafe fn` upholds the cuSOLVER contract;
                // arguments are forwarded unchanged.
                unsafe { cs::$getrs(handle, trans, n, nrhs, a, lda, ipiv, b, ldb, info) }
            }

            unsafe fn potrf_buffer_size(
                handle: cs::cusolverDnHandle_t,
                uplo: cs::cublasFillMode_t,
                n: i32,
                a: *mut Self,
                lda: i32,
                lwork: *mut i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: the caller of this `unsafe fn` upholds the cuSOLVER contract;
                // arguments are forwarded unchanged.
                unsafe { cs::$potrf_bs(handle, uplo, n, a, lda, lwork) }
            }

            unsafe fn potrf(
                handle: cs::cusolverDnHandle_t,
                uplo: cs::cublasFillMode_t,
                n: i32,
                a: *mut Self,
                lda: i32,
                work: *mut Self,
                lwork: i32,
                info: *mut i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: the caller of this `unsafe fn` upholds the cuSOLVER contract;
                // arguments are forwarded unchanged.
                unsafe { cs::$potrf(handle, uplo, n, a, lda, work, lwork, info) }
            }

            unsafe fn gesvd_buffer_size(
                handle: cs::cusolverDnHandle_t,
                m: i32,
                n: i32,
                lwork: *mut i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: the caller of this `unsafe fn` upholds the cuSOLVER contract;
                // arguments are forwarded unchanged.
                unsafe { cs::$gesvd_bs(handle, m, n, lwork) }
            }

            unsafe fn gesvd(
                handle: cs::cusolverDnHandle_t,
                jobu: i8,
                jobvt: i8,
                m: i32,
                n: i32,
                a: *mut Self,
                lda: i32,
                s: *mut Self,
                u: *mut Self,
                ldu: i32,
                vt: *mut Self,
                ldvt: i32,
                work: *mut Self,
                lwork: i32,
                rwork: *mut Self,
                info: *mut i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: the caller of this `unsafe fn` upholds the cuSOLVER contract;
                // arguments are forwarded unchanged.
                unsafe {
                    cs::$gesvd(
                        handle, jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, work, lwork,
                        rwork, info,
                    )
                }
            }

            unsafe fn syevd_buffer_size(
                handle: cs::cusolverDnHandle_t,
                jobz: cs::cusolverEigMode_t,
                uplo: cs::cublasFillMode_t,
                n: i32,
                a: *const Self,
                lda: i32,
                w: *const Self,
                lwork: *mut i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: the caller of this `unsafe fn` upholds the cuSOLVER contract;
                // arguments are forwarded unchanged.
                unsafe { cs::$syevd_bs(handle, jobz, uplo, n, a, lda, w, lwork) }
            }

            unsafe fn syevd(
                handle: cs::cusolverDnHandle_t,
                jobz: cs::cusolverEigMode_t,
                uplo: cs::cublasFillMode_t,
                n: i32,
                a: *mut Self,
                lda: i32,
                w: *mut Self,
                work: *mut Self,
                lwork: i32,
                info: *mut i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: the caller of this `unsafe fn` upholds the cuSOLVER contract;
                // arguments are forwarded unchanged.
                unsafe { cs::$syevd(handle, jobz, uplo, n, a, lda, w, work, lwork, info) }
            }

            unsafe fn sygvd_buffer_size(
                handle: cs::cusolverDnHandle_t,
                itype: cs::cusolverEigType_t,
                jobz: cs::cusolverEigMode_t,
                uplo: cs::cublasFillMode_t,
                n: i32,
                a: *const Self,
                lda: i32,
                b: *const Self,
                ldb: i32,
                w: *const Self,
                lwork: *mut i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: the caller of this `unsafe fn` upholds the cuSOLVER contract;
                // arguments are forwarded unchanged.
                unsafe { cs::$sygvd_bs(handle, itype, jobz, uplo, n, a, lda, b, ldb, w, lwork) }
            }

            unsafe fn sygvd(
                handle: cs::cusolverDnHandle_t,
                itype: cs::cusolverEigType_t,
                jobz: cs::cusolverEigMode_t,
                uplo: cs::cublasFillMode_t,
                n: i32,
                a: *mut Self,
                lda: i32,
                b: *mut Self,
                ldb: i32,
                w: *mut Self,
                work: *mut Self,
                lwork: i32,
                info: *mut i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: the caller of this `unsafe fn` upholds the cuSOLVER contract;
                // arguments are forwarded unchanged.
                unsafe {
                    cs::$sygvd(
                        handle, itype, jobz, uplo, n, a, lda, b, ldb, w, work, lwork, info,
                    )
                }
            }

            unsafe fn potrf_batched(
                handle: cs::cusolverDnHandle_t,
                uplo: cs::cublasFillMode_t,
                n: i32,
                a_array: *mut *mut Self,
                lda: i32,
                info_array: *mut i32,
                batch_size: i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: the caller of this `unsafe fn` upholds the cuSOLVER contract;
                // arguments are forwarded unchanged.
                unsafe { cs::$potrf_b(handle, uplo, n, a_array, lda, info_array, batch_size) }
            }

            unsafe fn gesvdj_batched_buffer_size(
                handle: cs::cusolverDnHandle_t,
                jobz: cs::cusolverEigMode_t,
                m: i32,
                n: i32,
                a: *const Self,
                lda: i32,
                s: *const Self,
                u: *const Self,
                ldu: i32,
                v: *const Self,
                ldv: i32,
                lwork: *mut i32,
                params: cs::gesvdjInfo_t,
                batch_size: i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: the caller of this `unsafe fn` upholds the cuSOLVER contract;
                // arguments are forwarded unchanged.
                unsafe {
                    cs::$gesvdj_b_bs(
                        handle, jobz, m, n, a, lda, s, u, ldu, v, ldv, lwork, params,
                        batch_size,
                    )
                }
            }

            unsafe fn gesvdj_batched(
                handle: cs::cusolverDnHandle_t,
                jobz: cs::cusolverEigMode_t,
                m: i32,
                n: i32,
                a: *mut Self,
                lda: i32,
                s: *mut Self,
                u: *mut Self,
                ldu: i32,
                v: *mut Self,
                ldv: i32,
                work: *mut Self,
                lwork: i32,
                info: *mut i32,
                params: cs::gesvdjInfo_t,
                batch_size: i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: the caller of this `unsafe fn` upholds the cuSOLVER contract;
                // arguments are forwarded unchanged.
                unsafe {
                    cs::$gesvdj_b(
                        handle, jobz, m, n, a, lda, s, u, ldu, v, ldv, work, lwork, info,
                        params, batch_size,
                    )
                }
            }
        }
    };
}
506
// `f32` implementation of `SolverScalar`, bound to the `cusolverDnS*` functions.
impl_solver_scalar!(
    f32,
    geqrf: cusolverDnSgeqrf, cusolverDnSgeqrf_bufferSize;
    getrf: cusolverDnSgetrf, cusolverDnSgetrf_bufferSize, cusolverDnSgetrs;
    potrf: cusolverDnSpotrf, cusolverDnSpotrf_bufferSize;
    gesvd: cusolverDnSgesvd, cusolverDnSgesvd_bufferSize;
    syevd: cusolverDnSsyevd, cusolverDnSsyevd_bufferSize;
    sygvd: cusolverDnSsygvd, cusolverDnSsygvd_bufferSize;
    potrf_batched: cusolverDnSpotrfBatched;
    gesvdj_batched: cusolverDnSgesvdjBatched, cusolverDnSgesvdjBatched_bufferSize;
);
518
// `f64` implementation of `SolverScalar`, bound to the `cusolverDnD*` functions.
impl_solver_scalar!(
    f64,
    geqrf: cusolverDnDgeqrf, cusolverDnDgeqrf_bufferSize;
    getrf: cusolverDnDgetrf, cusolverDnDgetrf_bufferSize, cusolverDnDgetrs;
    potrf: cusolverDnDpotrf, cusolverDnDpotrf_bufferSize;
    gesvd: cusolverDnDgesvd, cusolverDnDgesvd_bufferSize;
    syevd: cusolverDnDsyevd, cusolverDnDsyevd_bufferSize;
    sygvd: cusolverDnDsygvd, cusolverDnDsygvd_bufferSize;
    potrf_batched: cusolverDnDpotrfBatched;
    gesvdj_batched: cusolverDnDgesvdjBatched, cusolverDnDgesvdjBatched_bufferSize;
);
530
/// Scalar element types with sparse cuSOLVER (`cusolverSp*`) bindings.
///
/// Only compiled when the `cusolver-sp` feature is enabled. The implementations in this file
/// (`f32` and `f64`) forward each method to the matching `cusolverSp{S,D}csrlsv*` FFI function
/// and return the raw `cusolverStatus_t`.
///
/// `tol` is always accepted as `f64`; those implementations convert it with `as` to the
/// tolerance type of the underlying routine before the call.
///
/// # Safety contract
///
/// Both methods are `unsafe` because their arguments reach cuSOLVER without any validation.
/// The caller must satisfy every requirement the underlying cuSOLVER routine places on the
/// handle, the matrix descriptor, the pointers (validity, memory space, size) and the
/// dimensions. Consult the cuSOLVER documentation for the routine in question.
#[cfg(feature = "cusolver-sp")]
// The parameter lists mirror the C API one-to-one, so the argument count is not ours to reduce.
#[allow(clippy::too_many_arguments)]
pub trait SparseSolverScalar: SolverSupported {
    /// Binding for the cuSOLVER `csrlsvchol` routine (`cusolverSp{S,D}csrlsvchol`).
    ///
    /// # Safety
    ///
    /// See the trait-level safety contract.
    unsafe fn csrlsvchol(
        handle: cs::cusolverSpHandle_t,
        m: i32,
        nnz: i32,
        descr_a: cs::cusparseMatDescr_t,
        csr_val: *const Self,
        csr_row_ptr: *const i32,
        csr_col_ind: *const i32,
        b: *const Self,
        tol: f64,
        reorder: i32,
        x: *mut Self,
        singularity: *mut i32,
    ) -> cs::cusolverStatus_t;

    /// Binding for the cuSOLVER `csrlsvqr` routine (`cusolverSp{S,D}csrlsvqr`).
    ///
    /// # Safety
    ///
    /// See the trait-level safety contract.
    unsafe fn csrlsvqr(
        handle: cs::cusolverSpHandle_t,
        m: i32,
        nnz: i32,
        descr_a: cs::cusparseMatDescr_t,
        csr_val: *const Self,
        csr_row_ptr: *const i32,
        csr_col_ind: *const i32,
        b: *const Self,
        tol: f64,
        reorder: i32,
        x: *mut Self,
        singularity: *mut i32,
    ) -> cs::cusolverStatus_t;
}
570
// Generates `impl SparseSolverScalar for $T`. `$tol` is the tolerance type expected by the FFI
// functions; the trait's `f64` tolerance is converted to it with `as` at the call site. All
// other arguments are forwarded unchanged.
//
// Each FFI call sits in its own `unsafe` block so the generated code stays clean under the
// `unsafe_op_in_unsafe_fn` lint; the obligation itself is simply passed on to the caller.
#[cfg(feature = "cusolver-sp")]
macro_rules! impl_sparse_solver_scalar {
    ($T:ty, $tol:ty, chol: $chol:ident, qr: $qr:ident) => {
        impl SparseSolverScalar for $T {
            unsafe fn csrlsvchol(
                handle: cs::cusolverSpHandle_t,
                m: i32,
                nnz: i32,
                descr_a: cs::cusparseMatDescr_t,
                csr_val: *const Self,
                csr_row_ptr: *const i32,
                csr_col_ind: *const i32,
                b: *const Self,
                tol: f64,
                reorder: i32,
                x: *mut Self,
                singularity: *mut i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: the caller of this `unsafe fn` upholds the cuSOLVER contract; apart
                // from the tolerance conversion, arguments are forwarded unchanged.
                unsafe {
                    cs::$chol(
                        handle,
                        m,
                        nnz,
                        descr_a,
                        csr_val,
                        csr_row_ptr,
                        csr_col_ind,
                        b,
                        tol as $tol,
                        reorder,
                        x,
                        singularity,
                    )
                }
            }

            unsafe fn csrlsvqr(
                handle: cs::cusolverSpHandle_t,
                m: i32,
                nnz: i32,
                descr_a: cs::cusparseMatDescr_t,
                csr_val: *const Self,
                csr_row_ptr: *const i32,
                csr_col_ind: *const i32,
                b: *const Self,
                tol: f64,
                reorder: i32,
                x: *mut Self,
                singularity: *mut i32,
            ) -> cs::cusolverStatus_t {
                // SAFETY: the caller of this `unsafe fn` upholds the cuSOLVER contract; apart
                // from the tolerance conversion, arguments are forwarded unchanged.
                unsafe {
                    cs::$qr(
                        handle,
                        m,
                        nnz,
                        descr_a,
                        csr_val,
                        csr_row_ptr,
                        csr_col_ind,
                        b,
                        tol as $tol,
                        reorder,
                        x,
                        singularity,
                    )
                }
            }
        }
    };
}
637
// `f32` sparse implementation: tolerance is converted to `f32`, bound to `cusolverSpS*`.
#[cfg(feature = "cusolver-sp")]
impl_sparse_solver_scalar!(f32, f32, chol: cusolverSpScsrlsvchol, qr: cusolverSpScsrlsvqr);
// `f64` sparse implementation: tolerance stays `f64`, bound to `cusolverSpD*`.
#[cfg(feature = "cusolver-sp")]
impl_sparse_solver_scalar!(f64, f64, chol: cusolverSpDcsrlsvchol, qr: cusolverSpDcsrlsvqr);