Skip to main content

atomr_accel_cuda/kernel/solver/
dense.rs

1//! Dense cuSOLVER ops: QR, LU (factorize / solve), Cholesky, SVD,
2//! Syevd. Each request struct is generic over `T: SolverSupported`
3//! (f32, f64) and dispatches through
4//! [`crate::sys::cusolver::SolverScalar`].
5
6use std::marker::PhantomData;
7use std::sync::Arc;
8
9use cudarc::cusolver::sys as cs;
10use cudarc::driver::{DevicePtr, DevicePtrMut};
11use tokio::sync::oneshot;
12
13use crate::dtype::SolverSupported;
14use crate::error::GpuError;
15use crate::gpu_ref::GpuRef;
16use crate::kernel::envelope;
17use crate::sys::cusolver::{status_to_result, SolverScalar, LIB};
18
19use super::workspace::{check_info, ensure_workspace_bytes, lwork_bytes};
20use super::{SolverCells, SolverDispatch, Uplo};
21
22// =====================================================================
23// QR factorisation
24// =====================================================================
25
/// Request for a QR factorisation (`geqrf`), generic over the scalar type.
///
/// `SolverDispatch::dispatch` forwards every field unchanged to `run_qr`.
pub struct QrRequest<T: SolverSupported> {
    /// Device buffer holding the matrix; passed mutably to `geqrf` with
    /// `m` as its leading dimension.
    pub a: GpuRef<T>,
    /// First dimension passed to `geqrf`; also used as the leading
    /// dimension of `a`.
    pub m: i32,
    /// Second dimension passed to `geqrf`.
    pub n: i32,
    /// Device buffer passed mutably to `geqrf` as its `tau` argument.
    pub tau: GpuRef<T>,
    /// One-shot channel that receives the outcome of the request.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
}
33
34impl<T> SolverDispatch for QrRequest<T>
35where
36    T: SolverSupported + SolverScalar,
37{
38    fn dispatch(self: Box<Self>, cells: SolverCells<'_>) {
39        let QrRequest {
40            a,
41            m,
42            n,
43            tau,
44            reply,
45        } = *self;
46        run_qr::<T>(cells, a, m, n, tau, reply);
47    }
48
49    fn dispatch_mock(self: Box<Self>) {
50        let _ = self.reply.send(Err(GpuError::Unrecoverable(
51            "SolverActor in mock mode".into(),
52        )));
53    }
54}
55
56fn run_qr<T: SolverScalar>(
57    cells: SolverCells<'_>,
58    a: GpuRef<T>,
59    m: i32,
60    n: i32,
61    tau: GpuRef<T>,
62    reply: oneshot::Sender<Result<(), GpuError>>,
63) {
64    let SolverCells {
65        handle,
66        stream,
67        completion,
68        workspace,
69        info,
70        ..
71    } = cells;
72
73    let (a_slice, tau_slice) = match envelope::access_all_2(&a, &tau) {
74        Ok(t) => t,
75        Err(e) => {
76            let _ = reply.send(Err(e));
77            return;
78        }
79    };
80    let mut a_owned = match Arc::try_unwrap(a_slice) {
81        Ok(s) => s,
82        Err(_) => {
83            let _ = reply.send(Err(GpuError::Unrecoverable(
84                "QR a has multiple live references".into(),
85            )));
86            return;
87        }
88    };
89    let mut tau_owned = match Arc::try_unwrap(tau_slice) {
90        Ok(s) => s,
91        Err(_) => {
92            let _ = reply.send(Err(GpuError::Unrecoverable(
93                "QR tau has multiple live references".into(),
94            )));
95            return;
96        }
97    };
98
99    let mut lwork = 0i32;
100    {
101        let h = handle.lock();
102        let (a_ptr, _g) = a_owned.device_ptr_mut(stream);
103        let status = unsafe {
104            T::geqrf_buffer_size(h.0.cu(), m, n, a_ptr as *mut T, m, &mut lwork as *mut _)
105        };
106        drop(_g);
107        if let Err(e) = status_to_result(status, "geqrf_bufferSize") {
108            let _ = reply.send(Err(e));
109            return;
110        }
111    }
112
113    if let Err(e) = ensure_workspace_bytes(workspace, stream, lwork_bytes::<T>(lwork)) {
114        let _ = reply.send(Err(e));
115        return;
116    }
117
118    a.record_write(stream);
119    tau.record_write(stream);
120
121    let stream_for_check = stream.clone();
122    envelope::run_kernel(LIB, stream, completion, (), reply, move || {
123        let h = handle.lock();
124        let mut ws = workspace.lock();
125        let mut info_lock = info.lock();
126        let (a_ptr, _g1) = a_owned.device_ptr_mut(&stream_for_check);
127        let (tau_ptr, _g2) = tau_owned.device_ptr_mut(&stream_for_check);
128        let ws_slice = ws.as_mut().expect("workspace ensured");
129        let (ws_ptr, _g3) = ws_slice.device_ptr_mut(&stream_for_check);
130        let (info_ptr, _g4) = info_lock.device_ptr_mut(&stream_for_check);
131        let status = unsafe {
132            T::geqrf(
133                h.0.cu(),
134                m,
135                n,
136                a_ptr as *mut T,
137                m,
138                tau_ptr as *mut T,
139                ws_ptr as *mut T,
140                lwork,
141                info_ptr as *mut i32,
142            )
143        };
144        drop((_g1, _g2, _g3, _g4));
145        status_to_result(status, "geqrf")?;
146        check_info(info, &stream_for_check, "geqrf")?;
147        Ok((a_owned, tau_owned))
148    });
149}
150
151// =====================================================================
152// LU factorisation (`getrf`) and solve (`getrs`)
153// =====================================================================
154
/// Request for an LU factorisation (`getrf`), generic over the scalar type.
///
/// `SolverDispatch::dispatch` forwards every field unchanged to `run_lu`.
pub struct LuRequest<T: SolverSupported> {
    /// Device buffer holding the matrix; passed mutably to `getrf` with
    /// `m` as its leading dimension.
    pub a: GpuRef<T>,
    /// First dimension passed to `getrf`; also used as the leading
    /// dimension of `a`.
    pub m: i32,
    /// Second dimension passed to `getrf`.
    pub n: i32,
    /// Device buffer passed mutably to `getrf` as its pivot-index argument.
    pub ipiv: GpuRef<i32>,
    /// One-shot channel that receives the outcome of the request.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
}
162
/// Request for a solve (`getrs`) against an existing LU factorisation.
///
/// `SolverDispatch::dispatch` forwards every field unchanged to
/// `run_lu_solve`.
pub struct LuSolveRequest<T: SolverSupported> {
    /// Device buffer passed read-only to `getrs` as the factorised matrix,
    /// with `n` as its leading dimension.
    pub lu: GpuRef<T>,
    /// Device buffer passed read-only to `getrs` as the pivot indices.
    pub ipiv: GpuRef<i32>,
    /// Device buffer passed mutably to `getrs` as the right-hand side,
    /// with `n` as its leading dimension.
    pub b: GpuRef<T>,
    /// System size passed to `getrs`; also used as both leading dimensions.
    pub n: i32,
    /// Right-hand-side count passed to `getrs`.
    pub nrhs: i32,
    /// `true` selects `CUBLAS_OP_T`, `false` selects `CUBLAS_OP_N`.
    pub trans: bool,
    /// One-shot channel that receives the outcome of the request.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
}
172
173impl<T> SolverDispatch for LuRequest<T>
174where
175    T: SolverSupported + SolverScalar,
176{
177    fn dispatch(self: Box<Self>, cells: SolverCells<'_>) {
178        let LuRequest {
179            a,
180            m,
181            n,
182            ipiv,
183            reply,
184        } = *self;
185        run_lu::<T>(cells, a, m, n, ipiv, reply);
186    }
187
188    fn dispatch_mock(self: Box<Self>) {
189        let _ = self.reply.send(Err(GpuError::Unrecoverable(
190            "SolverActor in mock mode".into(),
191        )));
192    }
193}
194
195impl<T> SolverDispatch for LuSolveRequest<T>
196where
197    T: SolverSupported + SolverScalar,
198{
199    fn dispatch(self: Box<Self>, cells: SolverCells<'_>) {
200        let LuSolveRequest {
201            lu,
202            ipiv,
203            b,
204            n,
205            nrhs,
206            trans,
207            reply,
208        } = *self;
209        run_lu_solve::<T>(cells, lu, ipiv, b, n, nrhs, trans, reply);
210    }
211
212    fn dispatch_mock(self: Box<Self>) {
213        let _ = self.reply.send(Err(GpuError::Unrecoverable(
214            "SolverActor in mock mode".into(),
215        )));
216    }
217}
218
219fn run_lu<T: SolverScalar>(
220    cells: SolverCells<'_>,
221    a: GpuRef<T>,
222    m: i32,
223    n: i32,
224    ipiv: GpuRef<i32>,
225    reply: oneshot::Sender<Result<(), GpuError>>,
226) {
227    let SolverCells {
228        handle,
229        stream,
230        completion,
231        workspace,
232        info,
233        ..
234    } = cells;
235
236    let a_slice = match a.access() {
237        Ok(s) => s.clone(),
238        Err(e) => {
239            let _ = reply.send(Err(e));
240            return;
241        }
242    };
243    let ipiv_slice = match ipiv.access() {
244        Ok(s) => s.clone(),
245        Err(e) => {
246            let _ = reply.send(Err(e));
247            return;
248        }
249    };
250    let mut a_owned = match Arc::try_unwrap(a_slice) {
251        Ok(s) => s,
252        Err(_) => {
253            let _ = reply.send(Err(GpuError::Unrecoverable(
254                "LU a has multiple live references".into(),
255            )));
256            return;
257        }
258    };
259    let mut ipiv_owned = match Arc::try_unwrap(ipiv_slice) {
260        Ok(s) => s,
261        Err(_) => {
262            let _ = reply.send(Err(GpuError::Unrecoverable(
263                "LU ipiv has multiple live references".into(),
264            )));
265            return;
266        }
267    };
268
269    let mut lwork = 0i32;
270    {
271        let h = handle.lock();
272        let (a_ptr, _g) = a_owned.device_ptr_mut(stream);
273        let status = unsafe {
274            T::getrf_buffer_size(h.0.cu(), m, n, a_ptr as *mut T, m, &mut lwork as *mut _)
275        };
276        drop(_g);
277        if let Err(e) = status_to_result(status, "getrf_bufferSize") {
278            let _ = reply.send(Err(e));
279            return;
280        }
281    }
282    if let Err(e) = ensure_workspace_bytes(workspace, stream, lwork_bytes::<T>(lwork)) {
283        let _ = reply.send(Err(e));
284        return;
285    }
286
287    a.record_write(stream);
288    ipiv.record_write(stream);
289
290    let stream_for_check = stream.clone();
291    envelope::run_kernel(LIB, stream, completion, (), reply, move || {
292        let h = handle.lock();
293        let mut ws = workspace.lock();
294        let mut info_lock = info.lock();
295        let (a_ptr, _g1) = a_owned.device_ptr_mut(&stream_for_check);
296        let (ipiv_ptr, _g2) = ipiv_owned.device_ptr_mut(&stream_for_check);
297        let ws_slice = ws.as_mut().expect("workspace ensured");
298        let (ws_ptr, _g3) = ws_slice.device_ptr_mut(&stream_for_check);
299        let (info_ptr, _g4) = info_lock.device_ptr_mut(&stream_for_check);
300        let status = unsafe {
301            T::getrf(
302                h.0.cu(),
303                m,
304                n,
305                a_ptr as *mut T,
306                m,
307                ws_ptr as *mut T,
308                ipiv_ptr as *mut i32,
309                info_ptr as *mut i32,
310            )
311        };
312        drop((_g1, _g2, _g3, _g4));
313        status_to_result(status, "getrf")?;
314        check_info(info, &stream_for_check, "getrf")?;
315        Ok((a_owned, ipiv_owned))
316    });
317}
318
319fn run_lu_solve<T: SolverScalar>(
320    cells: SolverCells<'_>,
321    lu: GpuRef<T>,
322    ipiv: GpuRef<i32>,
323    b: GpuRef<T>,
324    n: i32,
325    nrhs: i32,
326    trans: bool,
327    reply: oneshot::Sender<Result<(), GpuError>>,
328) {
329    let SolverCells {
330        handle,
331        stream,
332        completion,
333        info,
334        ..
335    } = cells;
336
337    let lu_slice = match lu.access() {
338        Ok(s) => s.clone(),
339        Err(e) => {
340            let _ = reply.send(Err(e));
341            return;
342        }
343    };
344    let ipiv_slice = match ipiv.access() {
345        Ok(s) => s.clone(),
346        Err(e) => {
347            let _ = reply.send(Err(e));
348            return;
349        }
350    };
351    let b_slice = match b.access() {
352        Ok(s) => s.clone(),
353        Err(e) => {
354            let _ = reply.send(Err(e));
355            return;
356        }
357    };
358    let mut b_owned = match Arc::try_unwrap(b_slice) {
359        Ok(s) => s,
360        Err(_) => {
361            let _ = reply.send(Err(GpuError::Unrecoverable(
362                "LU b has multiple live references".into(),
363            )));
364            return;
365        }
366    };
367    let trans_op = if trans {
368        cs::cublasOperation_t::CUBLAS_OP_T
369    } else {
370        cs::cublasOperation_t::CUBLAS_OP_N
371    };
372    b.record_write(stream);
373
374    let stream_for_check = stream.clone();
375    envelope::run_kernel(LIB, stream, completion, (), reply, move || {
376        let h = handle.lock();
377        let mut info_lock = info.lock();
378        let (lu_ptr, _g1) = lu_slice.device_ptr(&stream_for_check);
379        let (ipiv_ptr, _g2) = ipiv_slice.device_ptr(&stream_for_check);
380        let (b_ptr, _g3) = b_owned.device_ptr_mut(&stream_for_check);
381        let (info_ptr, _g4) = info_lock.device_ptr_mut(&stream_for_check);
382        let status = unsafe {
383            T::getrs(
384                h.0.cu(),
385                trans_op,
386                n,
387                nrhs,
388                lu_ptr as *const T,
389                n,
390                ipiv_ptr as *const i32,
391                b_ptr as *mut T,
392                n,
393                info_ptr as *mut i32,
394            )
395        };
396        drop((_g1, _g2, _g3, _g4));
397        status_to_result(status, "getrs")?;
398        check_info(info, &stream_for_check, "getrs")?;
399        Ok((lu_slice, ipiv_slice, b_owned))
400    });
401}
402
403// =====================================================================
404// Cholesky
405// =====================================================================
406
/// Request for a Cholesky factorisation (`potrf`), generic over the scalar
/// type.
///
/// `SolverDispatch::dispatch` forwards every field unchanged to
/// `run_cholesky`.
pub struct CholeskyRequest<T: SolverSupported> {
    /// Device buffer holding the matrix; passed mutably to `potrf` with
    /// `n` as its leading dimension.
    pub a: GpuRef<T>,
    /// Matrix order passed to `potrf`; also used as the leading dimension.
    pub n: i32,
    /// Triangle selector, converted with `Uplo::as_cusolver_fill` before
    /// the call.
    pub uplo: Uplo,
    /// One-shot channel that receives the outcome of the request.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
}
413
414impl<T> SolverDispatch for CholeskyRequest<T>
415where
416    T: SolverSupported + SolverScalar,
417{
418    fn dispatch(self: Box<Self>, cells: SolverCells<'_>) {
419        let CholeskyRequest { a, n, uplo, reply } = *self;
420        run_cholesky::<T>(cells, a, n, uplo, reply);
421    }
422
423    fn dispatch_mock(self: Box<Self>) {
424        let _ = self.reply.send(Err(GpuError::Unrecoverable(
425            "SolverActor in mock mode".into(),
426        )));
427    }
428}
429
430fn run_cholesky<T: SolverScalar>(
431    cells: SolverCells<'_>,
432    a: GpuRef<T>,
433    n: i32,
434    uplo: Uplo,
435    reply: oneshot::Sender<Result<(), GpuError>>,
436) {
437    let SolverCells {
438        handle,
439        stream,
440        completion,
441        workspace,
442        info,
443        ..
444    } = cells;
445
446    let a_slice = match a.access() {
447        Ok(s) => s.clone(),
448        Err(e) => {
449            let _ = reply.send(Err(e));
450            return;
451        }
452    };
453    let mut a_owned = match Arc::try_unwrap(a_slice) {
454        Ok(s) => s,
455        Err(_) => {
456            let _ = reply.send(Err(GpuError::Unrecoverable(
457                "Cholesky a has multiple live references".into(),
458            )));
459            return;
460        }
461    };
462    let fill = uplo.as_cusolver_fill();
463
464    let mut lwork = 0i32;
465    {
466        let h = handle.lock();
467        let (a_ptr, _g) = a_owned.device_ptr_mut(stream);
468        let status = unsafe {
469            T::potrf_buffer_size(h.0.cu(), fill, n, a_ptr as *mut T, n, &mut lwork as *mut _)
470        };
471        drop(_g);
472        if let Err(e) = status_to_result(status, "potrf_bufferSize") {
473            let _ = reply.send(Err(e));
474            return;
475        }
476    }
477    if let Err(e) = ensure_workspace_bytes(workspace, stream, lwork_bytes::<T>(lwork)) {
478        let _ = reply.send(Err(e));
479        return;
480    }
481    a.record_write(stream);
482
483    let stream_for_check = stream.clone();
484    envelope::run_kernel(LIB, stream, completion, (), reply, move || {
485        let h = handle.lock();
486        let mut ws = workspace.lock();
487        let mut info_lock = info.lock();
488        let (a_ptr, _g1) = a_owned.device_ptr_mut(&stream_for_check);
489        let ws_slice = ws.as_mut().expect("workspace ensured");
490        let (ws_ptr, _g2) = ws_slice.device_ptr_mut(&stream_for_check);
491        let (info_ptr, _g3) = info_lock.device_ptr_mut(&stream_for_check);
492        let status = unsafe {
493            T::potrf(
494                h.0.cu(),
495                fill,
496                n,
497                a_ptr as *mut T,
498                n,
499                ws_ptr as *mut T,
500                lwork,
501                info_ptr as *mut i32,
502            )
503        };
504        drop((_g1, _g2, _g3));
505        status_to_result(status, "potrf")?;
506        check_info(info, &stream_for_check, "potrf")?;
507        Ok((a_owned,))
508    });
509}
510
511// =====================================================================
512// SVD
513// =====================================================================
514
/// Request for a singular value decomposition (`gesvd`), generic over the
/// scalar type.
///
/// `SolverDispatch::dispatch` forwards every field unchanged to `run_svd`.
pub struct SvdRequest<T: SolverSupported> {
    /// Device buffer holding the matrix; passed mutably to `gesvd` with
    /// `m` as its leading dimension.
    pub a: GpuRef<T>,
    /// First dimension passed to `gesvd`; also used as the leading
    /// dimension of `a` and of `u`.
    pub m: i32,
    /// Second dimension passed to `gesvd`; also used as the leading
    /// dimension of `vt`.
    pub n: i32,
    /// Device buffer passed mutably to `gesvd` as its `s` argument.
    pub s: GpuRef<T>,
    /// Optional `u` output. `Some` sets `jobu` to `'A'`; `None` sets it
    /// to `'N'` and a null pointer is passed.
    pub u: Option<GpuRef<T>>,
    /// Optional `vt` output. `Some` sets `jobvt` to `'A'`; `None` sets it
    /// to `'N'` and a null pointer is passed.
    pub vt: Option<GpuRef<T>>,
    /// One-shot channel that receives the outcome of the request.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
}
524
525impl<T> SolverDispatch for SvdRequest<T>
526where
527    T: SolverSupported + SolverScalar,
528{
529    fn dispatch(self: Box<Self>, cells: SolverCells<'_>) {
530        let SvdRequest {
531            a,
532            m,
533            n,
534            s,
535            u,
536            vt,
537            reply,
538        } = *self;
539        run_svd::<T>(cells, a, m, n, s, u, vt, reply);
540    }
541
542    fn dispatch_mock(self: Box<Self>) {
543        let _ = self.reply.send(Err(GpuError::Unrecoverable(
544            "SolverActor in mock mode".into(),
545        )));
546    }
547}
548
549fn run_svd<T: SolverScalar>(
550    cells: SolverCells<'_>,
551    a: GpuRef<T>,
552    m: i32,
553    n: i32,
554    s: GpuRef<T>,
555    u: Option<GpuRef<T>>,
556    vt: Option<GpuRef<T>>,
557    reply: oneshot::Sender<Result<(), GpuError>>,
558) {
559    let SolverCells {
560        handle,
561        stream,
562        completion,
563        workspace,
564        info,
565        ..
566    } = cells;
567
568    let a_slice = match a.access() {
569        Ok(sl) => sl.clone(),
570        Err(e) => {
571            let _ = reply.send(Err(e));
572            return;
573        }
574    };
575    let s_slice = match s.access() {
576        Ok(sl) => sl.clone(),
577        Err(e) => {
578            let _ = reply.send(Err(e));
579            return;
580        }
581    };
582    let mut a_owned = match Arc::try_unwrap(a_slice) {
583        Ok(sl) => sl,
584        Err(_) => {
585            let _ = reply.send(Err(GpuError::Unrecoverable(
586                "SVD a has multiple live references".into(),
587            )));
588            return;
589        }
590    };
591    let mut s_owned = match Arc::try_unwrap(s_slice) {
592        Ok(sl) => sl,
593        Err(_) => {
594            let _ = reply.send(Err(GpuError::Unrecoverable(
595                "SVD s has multiple live references".into(),
596            )));
597            return;
598        }
599    };
600    let u_slice = match u.as_ref().map(|g| g.access().map(|sl| sl.clone())) {
601        Some(Ok(sl)) => Some(sl),
602        Some(Err(e)) => {
603            let _ = reply.send(Err(e));
604            return;
605        }
606        None => None,
607    };
608    let vt_slice = match vt.as_ref().map(|g| g.access().map(|sl| sl.clone())) {
609        Some(Ok(sl)) => Some(sl),
610        Some(Err(e)) => {
611            let _ = reply.send(Err(e));
612            return;
613        }
614        None => None,
615    };
616    let mut u_owned = match u_slice {
617        Some(sl) => match Arc::try_unwrap(sl) {
618            Ok(o) => Some(o),
619            Err(_) => {
620                let _ = reply.send(Err(GpuError::Unrecoverable(
621                    "SVD u has multiple live references".into(),
622                )));
623                return;
624            }
625        },
626        None => None,
627    };
628    let mut vt_owned = match vt_slice {
629        Some(sl) => match Arc::try_unwrap(sl) {
630            Ok(o) => Some(o),
631            Err(_) => {
632                let _ = reply.send(Err(GpuError::Unrecoverable(
633                    "SVD vt has multiple live references".into(),
634                )));
635                return;
636            }
637        },
638        None => None,
639    };
640
641    let mut lwork = 0i32;
642    {
643        let h = handle.lock();
644        let status = unsafe { T::gesvd_buffer_size(h.0.cu(), m, n, &mut lwork as *mut _) };
645        if let Err(e) = status_to_result(status, "gesvd_bufferSize") {
646            let _ = reply.send(Err(e));
647            return;
648        }
649    }
650    if let Err(e) = ensure_workspace_bytes(workspace, stream, lwork_bytes::<T>(lwork)) {
651        let _ = reply.send(Err(e));
652        return;
653    }
654
655    a.record_write(stream);
656    s.record_write(stream);
657    if let Some(g) = &u {
658        g.record_write(stream);
659    }
660    if let Some(g) = &vt {
661        g.record_write(stream);
662    }
663
664    let jobu = if u_owned.is_some() {
665        b'A' as i8
666    } else {
667        b'N' as i8
668    };
669    let jobvt = if vt_owned.is_some() {
670        b'A' as i8
671    } else {
672        b'N' as i8
673    };
674    let stream_for_check = stream.clone();
675
676    envelope::run_kernel(LIB, stream, completion, (), reply, move || {
677        let h = handle.lock();
678        let mut ws = workspace.lock();
679        let mut info_lock = info.lock();
680        let (a_ptr, _g1) = a_owned.device_ptr_mut(&stream_for_check);
681        let (s_ptr, _g2) = s_owned.device_ptr_mut(&stream_for_check);
682        let (u_ptr, _gu_opt): (*mut T, _) = match u_owned.as_mut() {
683            Some(o) => {
684                let (p, g) = o.device_ptr_mut(&stream_for_check);
685                (p as *mut T, Some(g))
686            }
687            None => (std::ptr::null_mut(), None),
688        };
689        let (vt_ptr, _gvt_opt): (*mut T, _) = match vt_owned.as_mut() {
690            Some(o) => {
691                let (p, g) = o.device_ptr_mut(&stream_for_check);
692                (p as *mut T, Some(g))
693            }
694            None => (std::ptr::null_mut(), None),
695        };
696        let ws_slice = ws.as_mut().expect("workspace ensured");
697        let (ws_ptr, _g5) = ws_slice.device_ptr_mut(&stream_for_check);
698        let (info_ptr, _g6) = info_lock.device_ptr_mut(&stream_for_check);
699        let ldu = m;
700        let ldvt = n;
701        let status = unsafe {
702            T::gesvd(
703                h.0.cu(),
704                jobu,
705                jobvt,
706                m,
707                n,
708                a_ptr as *mut T,
709                m,
710                s_ptr as *mut T,
711                u_ptr,
712                ldu,
713                vt_ptr,
714                ldvt,
715                ws_ptr as *mut T,
716                lwork,
717                std::ptr::null_mut(),
718                info_ptr as *mut i32,
719            )
720        };
721        drop((_g1, _g2, _g5, _g6, _gu_opt, _gvt_opt));
722        status_to_result(status, "gesvd")?;
723        check_info(info, &stream_for_check, "gesvd")?;
724        Ok((a_owned, s_owned, u_owned, vt_owned))
725    });
726}
727
728// =====================================================================
729// Symmetric eigendecomposition
730// =====================================================================
731
/// Request for a symmetric eigendecomposition (`syevd`), generic over the
/// scalar type.
///
/// `SolverDispatch::dispatch` forwards every field unchanged to `run_syevd`.
pub struct SyevdRequest<T: SolverSupported> {
    /// Device buffer holding the matrix; passed mutably to `syevd` with
    /// `n` as its leading dimension.
    pub a: GpuRef<T>,
    /// Matrix order passed to `syevd`; also used as the leading dimension.
    pub n: i32,
    /// Triangle selector, converted with `Uplo::as_cusolver_fill` before
    /// the call.
    pub uplo: Uplo,
    /// Device buffer passed mutably to `syevd` as its `w` argument.
    pub w: GpuRef<T>,
    /// `true` selects `CUSOLVER_EIG_MODE_VECTOR`, `false` selects
    /// `CUSOLVER_EIG_MODE_NOVECTOR`.
    pub compute_vectors: bool,
    /// One-shot channel that receives the outcome of the request.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
}
740
741impl<T> SolverDispatch for SyevdRequest<T>
742where
743    T: SolverSupported + SolverScalar,
744{
745    fn dispatch(self: Box<Self>, cells: SolverCells<'_>) {
746        let SyevdRequest {
747            a,
748            n,
749            uplo,
750            w,
751            compute_vectors,
752            reply,
753        } = *self;
754        run_syevd::<T>(cells, a, n, uplo, w, compute_vectors, reply);
755    }
756
757    fn dispatch_mock(self: Box<Self>) {
758        let _ = self.reply.send(Err(GpuError::Unrecoverable(
759            "SolverActor in mock mode".into(),
760        )));
761    }
762}
763
764fn run_syevd<T: SolverScalar>(
765    cells: SolverCells<'_>,
766    a: GpuRef<T>,
767    n: i32,
768    uplo: Uplo,
769    w: GpuRef<T>,
770    compute_vectors: bool,
771    reply: oneshot::Sender<Result<(), GpuError>>,
772) {
773    let SolverCells {
774        handle,
775        stream,
776        completion,
777        workspace,
778        info,
779        ..
780    } = cells;
781
782    let a_slice = match a.access() {
783        Ok(sl) => sl.clone(),
784        Err(e) => {
785            let _ = reply.send(Err(e));
786            return;
787        }
788    };
789    let w_slice = match w.access() {
790        Ok(sl) => sl.clone(),
791        Err(e) => {
792            let _ = reply.send(Err(e));
793            return;
794        }
795    };
796    let mut a_owned = match Arc::try_unwrap(a_slice) {
797        Ok(sl) => sl,
798        Err(_) => {
799            let _ = reply.send(Err(GpuError::Unrecoverable(
800                "Syevd a has multiple live references".into(),
801            )));
802            return;
803        }
804    };
805    let mut w_owned = match Arc::try_unwrap(w_slice) {
806        Ok(sl) => sl,
807        Err(_) => {
808            let _ = reply.send(Err(GpuError::Unrecoverable(
809                "Syevd w has multiple live references".into(),
810            )));
811            return;
812        }
813    };
814    let fill = uplo.as_cusolver_fill();
815    let jobz = if compute_vectors {
816        cs::cusolverEigMode_t::CUSOLVER_EIG_MODE_VECTOR
817    } else {
818        cs::cusolverEigMode_t::CUSOLVER_EIG_MODE_NOVECTOR
819    };
820
821    let mut lwork = 0i32;
822    {
823        let h = handle.lock();
824        let (a_ptr, _ga) = a_owned.device_ptr_mut(stream);
825        let (w_ptr, _gw) = w_owned.device_ptr_mut(stream);
826        let status = unsafe {
827            T::syevd_buffer_size(
828                h.0.cu(),
829                jobz,
830                fill,
831                n,
832                a_ptr as *const T,
833                n,
834                w_ptr as *const T,
835                &mut lwork as *mut _,
836            )
837        };
838        drop((_ga, _gw));
839        if let Err(e) = status_to_result(status, "syevd_bufferSize") {
840            let _ = reply.send(Err(e));
841            return;
842        }
843    }
844    if let Err(e) = ensure_workspace_bytes(workspace, stream, lwork_bytes::<T>(lwork)) {
845        let _ = reply.send(Err(e));
846        return;
847    }
848
849    a.record_write(stream);
850    w.record_write(stream);
851
852    let stream_for_check = stream.clone();
853    envelope::run_kernel(LIB, stream, completion, (), reply, move || {
854        let h = handle.lock();
855        let mut ws = workspace.lock();
856        let mut info_lock = info.lock();
857        let (a_ptr, _g1) = a_owned.device_ptr_mut(&stream_for_check);
858        let (w_ptr, _g2) = w_owned.device_ptr_mut(&stream_for_check);
859        let ws_slice = ws.as_mut().expect("workspace ensured");
860        let (ws_ptr, _g3) = ws_slice.device_ptr_mut(&stream_for_check);
861        let (info_ptr, _g4) = info_lock.device_ptr_mut(&stream_for_check);
862        let status = unsafe {
863            T::syevd(
864                h.0.cu(),
865                jobz,
866                fill,
867                n,
868                a_ptr as *mut T,
869                n,
870                w_ptr as *mut T,
871                ws_ptr as *mut T,
872                lwork,
873                info_ptr as *mut i32,
874            )
875        };
876        drop((_g1, _g2, _g3, _g4));
877        status_to_result(status, "syevd")?;
878        check_info(info, &stream_for_check, "syevd")?;
879        Ok((a_owned, w_owned))
880    });
881}
882
883// Suppress unused import warnings for the f64-only typed alias below.
884#[allow(dead_code)]
885fn _phantom<T: SolverSupported>() -> PhantomData<T> {
886    PhantomData
887}
888
889// =====================================================================
890// Tests
891// =====================================================================
892
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that every request type implements
    /// `SolverDispatch` for both `f32` and `f64`.
    ///
    /// Nothing is constructed or executed: building a request needs a real
    /// `GpuRef`, and running a kernel needs a GPU. The test passes as soon
    /// as all of the trait bounds below resolve.
    #[test]
    fn qr_lu_cholesky_svd_syevd_round_trip_f32_f64() {
        fn assert_dispatch<R: SolverDispatch>() {}

        // Expands to one `assert_dispatch` call per (request type, dtype).
        macro_rules! assert_dispatch_f32_f64 {
            ($($request:ident),+ $(,)?) => {
                $(
                    assert_dispatch::<$request<f32>>();
                    assert_dispatch::<$request<f64>>();
                )+
            };
        }

        assert_dispatch_f32_f64!(
            QrRequest,
            LuRequest,
            LuSolveRequest,
            CholeskyRequest,
            SvdRequest,
            SyevdRequest,
        );
    }
}