1use std::marker::PhantomData;
7use std::sync::Arc;
8
9use cudarc::cusolver::sys as cs;
10use cudarc::driver::{DevicePtr, DevicePtrMut};
11use tokio::sync::oneshot;
12
13use crate::dtype::SolverSupported;
14use crate::error::GpuError;
15use crate::gpu_ref::GpuRef;
16use crate::kernel::envelope;
17use crate::sys::cusolver::{status_to_result, SolverScalar, LIB};
18
19use super::workspace::{check_info, ensure_workspace_bytes, lwork_bytes};
20use super::{SolverCells, SolverDispatch, Uplo};
21
22pub struct QrRequest<T: SolverSupported> {
27 pub a: GpuRef<T>,
28 pub m: i32,
29 pub n: i32,
30 pub tau: GpuRef<T>,
31 pub reply: oneshot::Sender<Result<(), GpuError>>,
32}
33
34impl<T> SolverDispatch for QrRequest<T>
35where
36 T: SolverSupported + SolverScalar,
37{
38 fn dispatch(self: Box<Self>, cells: SolverCells<'_>) {
39 let QrRequest {
40 a,
41 m,
42 n,
43 tau,
44 reply,
45 } = *self;
46 run_qr::<T>(cells, a, m, n, tau, reply);
47 }
48
49 fn dispatch_mock(self: Box<Self>) {
50 let _ = self.reply.send(Err(GpuError::Unrecoverable(
51 "SolverActor in mock mode".into(),
52 )));
53 }
54}
55
56fn run_qr<T: SolverScalar>(
57 cells: SolverCells<'_>,
58 a: GpuRef<T>,
59 m: i32,
60 n: i32,
61 tau: GpuRef<T>,
62 reply: oneshot::Sender<Result<(), GpuError>>,
63) {
64 let SolverCells {
65 handle,
66 stream,
67 completion,
68 workspace,
69 info,
70 ..
71 } = cells;
72
73 let (a_slice, tau_slice) = match envelope::access_all_2(&a, &tau) {
74 Ok(t) => t,
75 Err(e) => {
76 let _ = reply.send(Err(e));
77 return;
78 }
79 };
80 let mut a_owned = match Arc::try_unwrap(a_slice) {
81 Ok(s) => s,
82 Err(_) => {
83 let _ = reply.send(Err(GpuError::Unrecoverable(
84 "QR a has multiple live references".into(),
85 )));
86 return;
87 }
88 };
89 let mut tau_owned = match Arc::try_unwrap(tau_slice) {
90 Ok(s) => s,
91 Err(_) => {
92 let _ = reply.send(Err(GpuError::Unrecoverable(
93 "QR tau has multiple live references".into(),
94 )));
95 return;
96 }
97 };
98
99 let mut lwork = 0i32;
100 {
101 let h = handle.lock();
102 let (a_ptr, _g) = a_owned.device_ptr_mut(stream);
103 let status = unsafe {
104 T::geqrf_buffer_size(h.0.cu(), m, n, a_ptr as *mut T, m, &mut lwork as *mut _)
105 };
106 drop(_g);
107 if let Err(e) = status_to_result(status, "geqrf_bufferSize") {
108 let _ = reply.send(Err(e));
109 return;
110 }
111 }
112
113 if let Err(e) = ensure_workspace_bytes(workspace, stream, lwork_bytes::<T>(lwork)) {
114 let _ = reply.send(Err(e));
115 return;
116 }
117
118 a.record_write(stream);
119 tau.record_write(stream);
120
121 let stream_for_check = stream.clone();
122 envelope::run_kernel(LIB, stream, completion, (), reply, move || {
123 let h = handle.lock();
124 let mut ws = workspace.lock();
125 let mut info_lock = info.lock();
126 let (a_ptr, _g1) = a_owned.device_ptr_mut(&stream_for_check);
127 let (tau_ptr, _g2) = tau_owned.device_ptr_mut(&stream_for_check);
128 let ws_slice = ws.as_mut().expect("workspace ensured");
129 let (ws_ptr, _g3) = ws_slice.device_ptr_mut(&stream_for_check);
130 let (info_ptr, _g4) = info_lock.device_ptr_mut(&stream_for_check);
131 let status = unsafe {
132 T::geqrf(
133 h.0.cu(),
134 m,
135 n,
136 a_ptr as *mut T,
137 m,
138 tau_ptr as *mut T,
139 ws_ptr as *mut T,
140 lwork,
141 info_ptr as *mut i32,
142 )
143 };
144 drop((_g1, _g2, _g3, _g4));
145 status_to_result(status, "geqrf")?;
146 check_info(info, &stream_for_check, "geqrf")?;
147 Ok((a_owned, tau_owned))
148 });
149}
150
151pub struct LuRequest<T: SolverSupported> {
156 pub a: GpuRef<T>,
157 pub m: i32,
158 pub n: i32,
159 pub ipiv: GpuRef<i32>,
160 pub reply: oneshot::Sender<Result<(), GpuError>>,
161}
162
163pub struct LuSolveRequest<T: SolverSupported> {
164 pub lu: GpuRef<T>,
165 pub ipiv: GpuRef<i32>,
166 pub b: GpuRef<T>,
167 pub n: i32,
168 pub nrhs: i32,
169 pub trans: bool,
170 pub reply: oneshot::Sender<Result<(), GpuError>>,
171}
172
173impl<T> SolverDispatch for LuRequest<T>
174where
175 T: SolverSupported + SolverScalar,
176{
177 fn dispatch(self: Box<Self>, cells: SolverCells<'_>) {
178 let LuRequest {
179 a,
180 m,
181 n,
182 ipiv,
183 reply,
184 } = *self;
185 run_lu::<T>(cells, a, m, n, ipiv, reply);
186 }
187
188 fn dispatch_mock(self: Box<Self>) {
189 let _ = self.reply.send(Err(GpuError::Unrecoverable(
190 "SolverActor in mock mode".into(),
191 )));
192 }
193}
194
195impl<T> SolverDispatch for LuSolveRequest<T>
196where
197 T: SolverSupported + SolverScalar,
198{
199 fn dispatch(self: Box<Self>, cells: SolverCells<'_>) {
200 let LuSolveRequest {
201 lu,
202 ipiv,
203 b,
204 n,
205 nrhs,
206 trans,
207 reply,
208 } = *self;
209 run_lu_solve::<T>(cells, lu, ipiv, b, n, nrhs, trans, reply);
210 }
211
212 fn dispatch_mock(self: Box<Self>) {
213 let _ = self.reply.send(Err(GpuError::Unrecoverable(
214 "SolverActor in mock mode".into(),
215 )));
216 }
217}
218
219fn run_lu<T: SolverScalar>(
220 cells: SolverCells<'_>,
221 a: GpuRef<T>,
222 m: i32,
223 n: i32,
224 ipiv: GpuRef<i32>,
225 reply: oneshot::Sender<Result<(), GpuError>>,
226) {
227 let SolverCells {
228 handle,
229 stream,
230 completion,
231 workspace,
232 info,
233 ..
234 } = cells;
235
236 let a_slice = match a.access() {
237 Ok(s) => s.clone(),
238 Err(e) => {
239 let _ = reply.send(Err(e));
240 return;
241 }
242 };
243 let ipiv_slice = match ipiv.access() {
244 Ok(s) => s.clone(),
245 Err(e) => {
246 let _ = reply.send(Err(e));
247 return;
248 }
249 };
250 let mut a_owned = match Arc::try_unwrap(a_slice) {
251 Ok(s) => s,
252 Err(_) => {
253 let _ = reply.send(Err(GpuError::Unrecoverable(
254 "LU a has multiple live references".into(),
255 )));
256 return;
257 }
258 };
259 let mut ipiv_owned = match Arc::try_unwrap(ipiv_slice) {
260 Ok(s) => s,
261 Err(_) => {
262 let _ = reply.send(Err(GpuError::Unrecoverable(
263 "LU ipiv has multiple live references".into(),
264 )));
265 return;
266 }
267 };
268
269 let mut lwork = 0i32;
270 {
271 let h = handle.lock();
272 let (a_ptr, _g) = a_owned.device_ptr_mut(stream);
273 let status = unsafe {
274 T::getrf_buffer_size(h.0.cu(), m, n, a_ptr as *mut T, m, &mut lwork as *mut _)
275 };
276 drop(_g);
277 if let Err(e) = status_to_result(status, "getrf_bufferSize") {
278 let _ = reply.send(Err(e));
279 return;
280 }
281 }
282 if let Err(e) = ensure_workspace_bytes(workspace, stream, lwork_bytes::<T>(lwork)) {
283 let _ = reply.send(Err(e));
284 return;
285 }
286
287 a.record_write(stream);
288 ipiv.record_write(stream);
289
290 let stream_for_check = stream.clone();
291 envelope::run_kernel(LIB, stream, completion, (), reply, move || {
292 let h = handle.lock();
293 let mut ws = workspace.lock();
294 let mut info_lock = info.lock();
295 let (a_ptr, _g1) = a_owned.device_ptr_mut(&stream_for_check);
296 let (ipiv_ptr, _g2) = ipiv_owned.device_ptr_mut(&stream_for_check);
297 let ws_slice = ws.as_mut().expect("workspace ensured");
298 let (ws_ptr, _g3) = ws_slice.device_ptr_mut(&stream_for_check);
299 let (info_ptr, _g4) = info_lock.device_ptr_mut(&stream_for_check);
300 let status = unsafe {
301 T::getrf(
302 h.0.cu(),
303 m,
304 n,
305 a_ptr as *mut T,
306 m,
307 ws_ptr as *mut T,
308 ipiv_ptr as *mut i32,
309 info_ptr as *mut i32,
310 )
311 };
312 drop((_g1, _g2, _g3, _g4));
313 status_to_result(status, "getrf")?;
314 check_info(info, &stream_for_check, "getrf")?;
315 Ok((a_owned, ipiv_owned))
316 });
317}
318
319fn run_lu_solve<T: SolverScalar>(
320 cells: SolverCells<'_>,
321 lu: GpuRef<T>,
322 ipiv: GpuRef<i32>,
323 b: GpuRef<T>,
324 n: i32,
325 nrhs: i32,
326 trans: bool,
327 reply: oneshot::Sender<Result<(), GpuError>>,
328) {
329 let SolverCells {
330 handle,
331 stream,
332 completion,
333 info,
334 ..
335 } = cells;
336
337 let lu_slice = match lu.access() {
338 Ok(s) => s.clone(),
339 Err(e) => {
340 let _ = reply.send(Err(e));
341 return;
342 }
343 };
344 let ipiv_slice = match ipiv.access() {
345 Ok(s) => s.clone(),
346 Err(e) => {
347 let _ = reply.send(Err(e));
348 return;
349 }
350 };
351 let b_slice = match b.access() {
352 Ok(s) => s.clone(),
353 Err(e) => {
354 let _ = reply.send(Err(e));
355 return;
356 }
357 };
358 let mut b_owned = match Arc::try_unwrap(b_slice) {
359 Ok(s) => s,
360 Err(_) => {
361 let _ = reply.send(Err(GpuError::Unrecoverable(
362 "LU b has multiple live references".into(),
363 )));
364 return;
365 }
366 };
367 let trans_op = if trans {
368 cs::cublasOperation_t::CUBLAS_OP_T
369 } else {
370 cs::cublasOperation_t::CUBLAS_OP_N
371 };
372 b.record_write(stream);
373
374 let stream_for_check = stream.clone();
375 envelope::run_kernel(LIB, stream, completion, (), reply, move || {
376 let h = handle.lock();
377 let mut info_lock = info.lock();
378 let (lu_ptr, _g1) = lu_slice.device_ptr(&stream_for_check);
379 let (ipiv_ptr, _g2) = ipiv_slice.device_ptr(&stream_for_check);
380 let (b_ptr, _g3) = b_owned.device_ptr_mut(&stream_for_check);
381 let (info_ptr, _g4) = info_lock.device_ptr_mut(&stream_for_check);
382 let status = unsafe {
383 T::getrs(
384 h.0.cu(),
385 trans_op,
386 n,
387 nrhs,
388 lu_ptr as *const T,
389 n,
390 ipiv_ptr as *const i32,
391 b_ptr as *mut T,
392 n,
393 info_ptr as *mut i32,
394 )
395 };
396 drop((_g1, _g2, _g3, _g4));
397 status_to_result(status, "getrs")?;
398 check_info(info, &stream_for_check, "getrs")?;
399 Ok((lu_slice, ipiv_slice, b_owned))
400 });
401}
402
403pub struct CholeskyRequest<T: SolverSupported> {
408 pub a: GpuRef<T>,
409 pub n: i32,
410 pub uplo: Uplo,
411 pub reply: oneshot::Sender<Result<(), GpuError>>,
412}
413
414impl<T> SolverDispatch for CholeskyRequest<T>
415where
416 T: SolverSupported + SolverScalar,
417{
418 fn dispatch(self: Box<Self>, cells: SolverCells<'_>) {
419 let CholeskyRequest { a, n, uplo, reply } = *self;
420 run_cholesky::<T>(cells, a, n, uplo, reply);
421 }
422
423 fn dispatch_mock(self: Box<Self>) {
424 let _ = self.reply.send(Err(GpuError::Unrecoverable(
425 "SolverActor in mock mode".into(),
426 )));
427 }
428}
429
430fn run_cholesky<T: SolverScalar>(
431 cells: SolverCells<'_>,
432 a: GpuRef<T>,
433 n: i32,
434 uplo: Uplo,
435 reply: oneshot::Sender<Result<(), GpuError>>,
436) {
437 let SolverCells {
438 handle,
439 stream,
440 completion,
441 workspace,
442 info,
443 ..
444 } = cells;
445
446 let a_slice = match a.access() {
447 Ok(s) => s.clone(),
448 Err(e) => {
449 let _ = reply.send(Err(e));
450 return;
451 }
452 };
453 let mut a_owned = match Arc::try_unwrap(a_slice) {
454 Ok(s) => s,
455 Err(_) => {
456 let _ = reply.send(Err(GpuError::Unrecoverable(
457 "Cholesky a has multiple live references".into(),
458 )));
459 return;
460 }
461 };
462 let fill = uplo.as_cusolver_fill();
463
464 let mut lwork = 0i32;
465 {
466 let h = handle.lock();
467 let (a_ptr, _g) = a_owned.device_ptr_mut(stream);
468 let status = unsafe {
469 T::potrf_buffer_size(h.0.cu(), fill, n, a_ptr as *mut T, n, &mut lwork as *mut _)
470 };
471 drop(_g);
472 if let Err(e) = status_to_result(status, "potrf_bufferSize") {
473 let _ = reply.send(Err(e));
474 return;
475 }
476 }
477 if let Err(e) = ensure_workspace_bytes(workspace, stream, lwork_bytes::<T>(lwork)) {
478 let _ = reply.send(Err(e));
479 return;
480 }
481 a.record_write(stream);
482
483 let stream_for_check = stream.clone();
484 envelope::run_kernel(LIB, stream, completion, (), reply, move || {
485 let h = handle.lock();
486 let mut ws = workspace.lock();
487 let mut info_lock = info.lock();
488 let (a_ptr, _g1) = a_owned.device_ptr_mut(&stream_for_check);
489 let ws_slice = ws.as_mut().expect("workspace ensured");
490 let (ws_ptr, _g2) = ws_slice.device_ptr_mut(&stream_for_check);
491 let (info_ptr, _g3) = info_lock.device_ptr_mut(&stream_for_check);
492 let status = unsafe {
493 T::potrf(
494 h.0.cu(),
495 fill,
496 n,
497 a_ptr as *mut T,
498 n,
499 ws_ptr as *mut T,
500 lwork,
501 info_ptr as *mut i32,
502 )
503 };
504 drop((_g1, _g2, _g3));
505 status_to_result(status, "potrf")?;
506 check_info(info, &stream_for_check, "potrf")?;
507 Ok((a_owned,))
508 });
509}
510
511pub struct SvdRequest<T: SolverSupported> {
516 pub a: GpuRef<T>,
517 pub m: i32,
518 pub n: i32,
519 pub s: GpuRef<T>,
520 pub u: Option<GpuRef<T>>,
521 pub vt: Option<GpuRef<T>>,
522 pub reply: oneshot::Sender<Result<(), GpuError>>,
523}
524
525impl<T> SolverDispatch for SvdRequest<T>
526where
527 T: SolverSupported + SolverScalar,
528{
529 fn dispatch(self: Box<Self>, cells: SolverCells<'_>) {
530 let SvdRequest {
531 a,
532 m,
533 n,
534 s,
535 u,
536 vt,
537 reply,
538 } = *self;
539 run_svd::<T>(cells, a, m, n, s, u, vt, reply);
540 }
541
542 fn dispatch_mock(self: Box<Self>) {
543 let _ = self.reply.send(Err(GpuError::Unrecoverable(
544 "SolverActor in mock mode".into(),
545 )));
546 }
547}
548
549fn run_svd<T: SolverScalar>(
550 cells: SolverCells<'_>,
551 a: GpuRef<T>,
552 m: i32,
553 n: i32,
554 s: GpuRef<T>,
555 u: Option<GpuRef<T>>,
556 vt: Option<GpuRef<T>>,
557 reply: oneshot::Sender<Result<(), GpuError>>,
558) {
559 let SolverCells {
560 handle,
561 stream,
562 completion,
563 workspace,
564 info,
565 ..
566 } = cells;
567
568 let a_slice = match a.access() {
569 Ok(sl) => sl.clone(),
570 Err(e) => {
571 let _ = reply.send(Err(e));
572 return;
573 }
574 };
575 let s_slice = match s.access() {
576 Ok(sl) => sl.clone(),
577 Err(e) => {
578 let _ = reply.send(Err(e));
579 return;
580 }
581 };
582 let mut a_owned = match Arc::try_unwrap(a_slice) {
583 Ok(sl) => sl,
584 Err(_) => {
585 let _ = reply.send(Err(GpuError::Unrecoverable(
586 "SVD a has multiple live references".into(),
587 )));
588 return;
589 }
590 };
591 let mut s_owned = match Arc::try_unwrap(s_slice) {
592 Ok(sl) => sl,
593 Err(_) => {
594 let _ = reply.send(Err(GpuError::Unrecoverable(
595 "SVD s has multiple live references".into(),
596 )));
597 return;
598 }
599 };
600 let u_slice = match u.as_ref().map(|g| g.access().map(|sl| sl.clone())) {
601 Some(Ok(sl)) => Some(sl),
602 Some(Err(e)) => {
603 let _ = reply.send(Err(e));
604 return;
605 }
606 None => None,
607 };
608 let vt_slice = match vt.as_ref().map(|g| g.access().map(|sl| sl.clone())) {
609 Some(Ok(sl)) => Some(sl),
610 Some(Err(e)) => {
611 let _ = reply.send(Err(e));
612 return;
613 }
614 None => None,
615 };
616 let mut u_owned = match u_slice {
617 Some(sl) => match Arc::try_unwrap(sl) {
618 Ok(o) => Some(o),
619 Err(_) => {
620 let _ = reply.send(Err(GpuError::Unrecoverable(
621 "SVD u has multiple live references".into(),
622 )));
623 return;
624 }
625 },
626 None => None,
627 };
628 let mut vt_owned = match vt_slice {
629 Some(sl) => match Arc::try_unwrap(sl) {
630 Ok(o) => Some(o),
631 Err(_) => {
632 let _ = reply.send(Err(GpuError::Unrecoverable(
633 "SVD vt has multiple live references".into(),
634 )));
635 return;
636 }
637 },
638 None => None,
639 };
640
641 let mut lwork = 0i32;
642 {
643 let h = handle.lock();
644 let status = unsafe { T::gesvd_buffer_size(h.0.cu(), m, n, &mut lwork as *mut _) };
645 if let Err(e) = status_to_result(status, "gesvd_bufferSize") {
646 let _ = reply.send(Err(e));
647 return;
648 }
649 }
650 if let Err(e) = ensure_workspace_bytes(workspace, stream, lwork_bytes::<T>(lwork)) {
651 let _ = reply.send(Err(e));
652 return;
653 }
654
655 a.record_write(stream);
656 s.record_write(stream);
657 if let Some(g) = &u {
658 g.record_write(stream);
659 }
660 if let Some(g) = &vt {
661 g.record_write(stream);
662 }
663
664 let jobu = if u_owned.is_some() {
665 b'A' as i8
666 } else {
667 b'N' as i8
668 };
669 let jobvt = if vt_owned.is_some() {
670 b'A' as i8
671 } else {
672 b'N' as i8
673 };
674 let stream_for_check = stream.clone();
675
676 envelope::run_kernel(LIB, stream, completion, (), reply, move || {
677 let h = handle.lock();
678 let mut ws = workspace.lock();
679 let mut info_lock = info.lock();
680 let (a_ptr, _g1) = a_owned.device_ptr_mut(&stream_for_check);
681 let (s_ptr, _g2) = s_owned.device_ptr_mut(&stream_for_check);
682 let (u_ptr, _gu_opt): (*mut T, _) = match u_owned.as_mut() {
683 Some(o) => {
684 let (p, g) = o.device_ptr_mut(&stream_for_check);
685 (p as *mut T, Some(g))
686 }
687 None => (std::ptr::null_mut(), None),
688 };
689 let (vt_ptr, _gvt_opt): (*mut T, _) = match vt_owned.as_mut() {
690 Some(o) => {
691 let (p, g) = o.device_ptr_mut(&stream_for_check);
692 (p as *mut T, Some(g))
693 }
694 None => (std::ptr::null_mut(), None),
695 };
696 let ws_slice = ws.as_mut().expect("workspace ensured");
697 let (ws_ptr, _g5) = ws_slice.device_ptr_mut(&stream_for_check);
698 let (info_ptr, _g6) = info_lock.device_ptr_mut(&stream_for_check);
699 let ldu = m;
700 let ldvt = n;
701 let status = unsafe {
702 T::gesvd(
703 h.0.cu(),
704 jobu,
705 jobvt,
706 m,
707 n,
708 a_ptr as *mut T,
709 m,
710 s_ptr as *mut T,
711 u_ptr,
712 ldu,
713 vt_ptr,
714 ldvt,
715 ws_ptr as *mut T,
716 lwork,
717 std::ptr::null_mut(),
718 info_ptr as *mut i32,
719 )
720 };
721 drop((_g1, _g2, _g5, _g6, _gu_opt, _gvt_opt));
722 status_to_result(status, "gesvd")?;
723 check_info(info, &stream_for_check, "gesvd")?;
724 Ok((a_owned, s_owned, u_owned, vt_owned))
725 });
726}
727
728pub struct SyevdRequest<T: SolverSupported> {
733 pub a: GpuRef<T>,
734 pub n: i32,
735 pub uplo: Uplo,
736 pub w: GpuRef<T>,
737 pub compute_vectors: bool,
738 pub reply: oneshot::Sender<Result<(), GpuError>>,
739}
740
741impl<T> SolverDispatch for SyevdRequest<T>
742where
743 T: SolverSupported + SolverScalar,
744{
745 fn dispatch(self: Box<Self>, cells: SolverCells<'_>) {
746 let SyevdRequest {
747 a,
748 n,
749 uplo,
750 w,
751 compute_vectors,
752 reply,
753 } = *self;
754 run_syevd::<T>(cells, a, n, uplo, w, compute_vectors, reply);
755 }
756
757 fn dispatch_mock(self: Box<Self>) {
758 let _ = self.reply.send(Err(GpuError::Unrecoverable(
759 "SolverActor in mock mode".into(),
760 )));
761 }
762}
763
764fn run_syevd<T: SolverScalar>(
765 cells: SolverCells<'_>,
766 a: GpuRef<T>,
767 n: i32,
768 uplo: Uplo,
769 w: GpuRef<T>,
770 compute_vectors: bool,
771 reply: oneshot::Sender<Result<(), GpuError>>,
772) {
773 let SolverCells {
774 handle,
775 stream,
776 completion,
777 workspace,
778 info,
779 ..
780 } = cells;
781
782 let a_slice = match a.access() {
783 Ok(sl) => sl.clone(),
784 Err(e) => {
785 let _ = reply.send(Err(e));
786 return;
787 }
788 };
789 let w_slice = match w.access() {
790 Ok(sl) => sl.clone(),
791 Err(e) => {
792 let _ = reply.send(Err(e));
793 return;
794 }
795 };
796 let mut a_owned = match Arc::try_unwrap(a_slice) {
797 Ok(sl) => sl,
798 Err(_) => {
799 let _ = reply.send(Err(GpuError::Unrecoverable(
800 "Syevd a has multiple live references".into(),
801 )));
802 return;
803 }
804 };
805 let mut w_owned = match Arc::try_unwrap(w_slice) {
806 Ok(sl) => sl,
807 Err(_) => {
808 let _ = reply.send(Err(GpuError::Unrecoverable(
809 "Syevd w has multiple live references".into(),
810 )));
811 return;
812 }
813 };
814 let fill = uplo.as_cusolver_fill();
815 let jobz = if compute_vectors {
816 cs::cusolverEigMode_t::CUSOLVER_EIG_MODE_VECTOR
817 } else {
818 cs::cusolverEigMode_t::CUSOLVER_EIG_MODE_NOVECTOR
819 };
820
821 let mut lwork = 0i32;
822 {
823 let h = handle.lock();
824 let (a_ptr, _ga) = a_owned.device_ptr_mut(stream);
825 let (w_ptr, _gw) = w_owned.device_ptr_mut(stream);
826 let status = unsafe {
827 T::syevd_buffer_size(
828 h.0.cu(),
829 jobz,
830 fill,
831 n,
832 a_ptr as *const T,
833 n,
834 w_ptr as *const T,
835 &mut lwork as *mut _,
836 )
837 };
838 drop((_ga, _gw));
839 if let Err(e) = status_to_result(status, "syevd_bufferSize") {
840 let _ = reply.send(Err(e));
841 return;
842 }
843 }
844 if let Err(e) = ensure_workspace_bytes(workspace, stream, lwork_bytes::<T>(lwork)) {
845 let _ = reply.send(Err(e));
846 return;
847 }
848
849 a.record_write(stream);
850 w.record_write(stream);
851
852 let stream_for_check = stream.clone();
853 envelope::run_kernel(LIB, stream, completion, (), reply, move || {
854 let h = handle.lock();
855 let mut ws = workspace.lock();
856 let mut info_lock = info.lock();
857 let (a_ptr, _g1) = a_owned.device_ptr_mut(&stream_for_check);
858 let (w_ptr, _g2) = w_owned.device_ptr_mut(&stream_for_check);
859 let ws_slice = ws.as_mut().expect("workspace ensured");
860 let (ws_ptr, _g3) = ws_slice.device_ptr_mut(&stream_for_check);
861 let (info_ptr, _g4) = info_lock.device_ptr_mut(&stream_for_check);
862 let status = unsafe {
863 T::syevd(
864 h.0.cu(),
865 jobz,
866 fill,
867 n,
868 a_ptr as *mut T,
869 n,
870 w_ptr as *mut T,
871 ws_ptr as *mut T,
872 lwork,
873 info_ptr as *mut i32,
874 )
875 };
876 drop((_g1, _g2, _g3, _g4));
877 status_to_result(status, "syevd")?;
878 check_info(info, &stream_for_check, "syevd")?;
879 Ok((a_owned, w_owned))
880 });
881}
882
883#[allow(dead_code)]
885fn _phantom<T: SolverSupported>() -> PhantomData<T> {
886 PhantomData
887}
888
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that every request type implements `SolverDispatch`
    /// for both `f32` and `f64`.
    ///
    /// Despite the name, nothing is executed on a device: the helper has an empty
    /// body, so the test passes as soon as the trait bounds resolve.
    #[test]
    fn qr_lu_cholesky_svd_syevd_round_trip_f32_f64() {
        fn requires_dispatch<R>()
        where
            R: SolverDispatch,
        {
        }

        // Single precision.
        requires_dispatch::<QrRequest<f32>>();
        requires_dispatch::<LuRequest<f32>>();
        requires_dispatch::<LuSolveRequest<f32>>();
        requires_dispatch::<CholeskyRequest<f32>>();
        requires_dispatch::<SvdRequest<f32>>();
        requires_dispatch::<SyevdRequest<f32>>();

        // Double precision.
        requires_dispatch::<QrRequest<f64>>();
        requires_dispatch::<LuRequest<f64>>();
        requires_dispatch::<LuSolveRequest<f64>>();
        requires_dispatch::<CholeskyRequest<f64>>();
        requires_dispatch::<SvdRequest<f64>>();
        requires_dispatch::<SyevdRequest<f64>>();
    }
}