Skip to main content

atomr_accel_cuda/kernel/solver/
generalized.rs

//! Generalized symmetric / Hermitian eigenvalue problems.
//!
//! Solves `A x = λ B x` (itype=1), `A B x = λ x` (itype=2), or
//! `B A x = λ x` (itype=3) for an `n × n` symmetric `A` and SPD
//! `B`. f32/f64 dispatch through `cusolverDn[SD]sygvd`; the
//! Hermitian-complex `Hegvd` request is its own struct but is routed
//! through the same launch path, so adding c32/c64 in a future phase
//! only needs `SolverScalar` impls — no new request type.
9
10use std::sync::Arc;
11
12use cudarc::cusolver::sys as cs;
13use cudarc::driver::DevicePtrMut;
14use tokio::sync::oneshot;
15
16use crate::dtype::SolverSupported;
17use crate::error::GpuError;
18use crate::gpu_ref::GpuRef;
19use crate::kernel::envelope;
20use crate::sys::cusolver::{status_to_result, SolverScalar, LIB};
21
22use super::workspace::{check_info, ensure_workspace_bytes, lwork_bytes};
23use super::{SolverCells, SolverDispatch, Uplo};
24
/// Eigenproblem type as defined by cuSOLVER's `cusolverEigType_t`:
/// `1: A x = λ B x`, `2: A B x = λ x`, `3: B A x = λ x`.
///
/// This is a payload-free tag, so besides being `Copy` it can be
/// compared for equality and used as a hash key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EigType {
    /// `A x = λ B x`.
    Type1,
    /// `A B x = λ x`.
    Type2,
    /// `B A x = λ x`.
    Type3,
}
33
34impl EigType {
35    fn as_cusolver(self) -> cs::cusolverEigType_t {
36        match self {
37            EigType::Type1 => cs::cusolverEigType_t::CUSOLVER_EIG_TYPE_1,
38            EigType::Type2 => cs::cusolverEigType_t::CUSOLVER_EIG_TYPE_2,
39            EigType::Type3 => cs::cusolverEigType_t::CUSOLVER_EIG_TYPE_3,
40        }
41    }
42}
43
44pub struct SygvdRequest<T: SolverSupported> {
45    pub a: GpuRef<T>,
46    pub b: GpuRef<T>,
47    pub n: i32,
48    pub itype: EigType,
49    pub uplo: Uplo,
50    pub w: GpuRef<T>,
51    pub compute_vectors: bool,
52    pub reply: oneshot::Sender<Result<(), GpuError>>,
53}
54
55/// `Hegvd` (Hermitian) is the complex sibling of `Sygvd`. Phase 1
56/// supports only real dtypes, so this request currently shares the
57/// same launch path as `SygvdRequest`. Promoting to a distinct
58/// surface lets callers express intent today and lets us add c32/c64
59/// without a SemVer break later.
60pub struct HegvdRequest<T: SolverSupported> {
61    pub a: GpuRef<T>,
62    pub b: GpuRef<T>,
63    pub n: i32,
64    pub itype: EigType,
65    pub uplo: Uplo,
66    pub w: GpuRef<T>,
67    pub compute_vectors: bool,
68    pub reply: oneshot::Sender<Result<(), GpuError>>,
69}
70
71impl<T> SolverDispatch for SygvdRequest<T>
72where
73    T: SolverSupported + SolverScalar,
74{
75    fn dispatch(self: Box<Self>, cells: SolverCells<'_>) {
76        let SygvdRequest {
77            a,
78            b,
79            n,
80            itype,
81            uplo,
82            w,
83            compute_vectors,
84            reply,
85        } = *self;
86        run_sygvd::<T>(cells, a, b, n, itype, uplo, w, compute_vectors, reply);
87    }
88
89    fn dispatch_mock(self: Box<Self>) {
90        let _ = self.reply.send(Err(GpuError::Unrecoverable(
91            "SolverActor in mock mode".into(),
92        )));
93    }
94}
95
96impl<T> SolverDispatch for HegvdRequest<T>
97where
98    T: SolverSupported + SolverScalar,
99{
100    fn dispatch(self: Box<Self>, cells: SolverCells<'_>) {
101        let HegvdRequest {
102            a,
103            b,
104            n,
105            itype,
106            uplo,
107            w,
108            compute_vectors,
109            reply,
110        } = *self;
111        run_sygvd::<T>(cells, a, b, n, itype, uplo, w, compute_vectors, reply);
112    }
113
114    fn dispatch_mock(self: Box<Self>) {
115        let _ = self.reply.send(Err(GpuError::Unrecoverable(
116            "SolverActor in mock mode".into(),
117        )));
118    }
119}
120
121fn run_sygvd<T: SolverScalar>(
122    cells: SolverCells<'_>,
123    a: GpuRef<T>,
124    b: GpuRef<T>,
125    n: i32,
126    itype: EigType,
127    uplo: Uplo,
128    w: GpuRef<T>,
129    compute_vectors: bool,
130    reply: oneshot::Sender<Result<(), GpuError>>,
131) {
132    let SolverCells {
133        handle,
134        stream,
135        completion,
136        workspace,
137        info,
138        ..
139    } = cells;
140
141    let a_slice = match a.access() {
142        Ok(s) => s.clone(),
143        Err(e) => {
144            let _ = reply.send(Err(e));
145            return;
146        }
147    };
148    let b_slice = match b.access() {
149        Ok(s) => s.clone(),
150        Err(e) => {
151            let _ = reply.send(Err(e));
152            return;
153        }
154    };
155    let w_slice = match w.access() {
156        Ok(s) => s.clone(),
157        Err(e) => {
158            let _ = reply.send(Err(e));
159            return;
160        }
161    };
162    let mut a_owned = match Arc::try_unwrap(a_slice) {
163        Ok(s) => s,
164        Err(_) => {
165            let _ = reply.send(Err(GpuError::Unrecoverable(
166                "Sygvd a has multiple live references".into(),
167            )));
168            return;
169        }
170    };
171    let mut b_owned = match Arc::try_unwrap(b_slice) {
172        Ok(s) => s,
173        Err(_) => {
174            let _ = reply.send(Err(GpuError::Unrecoverable(
175                "Sygvd b has multiple live references".into(),
176            )));
177            return;
178        }
179    };
180    let mut w_owned = match Arc::try_unwrap(w_slice) {
181        Ok(s) => s,
182        Err(_) => {
183            let _ = reply.send(Err(GpuError::Unrecoverable(
184                "Sygvd w has multiple live references".into(),
185            )));
186            return;
187        }
188    };
189
190    let fill = uplo.as_cusolver_fill();
191    let jobz = if compute_vectors {
192        cs::cusolverEigMode_t::CUSOLVER_EIG_MODE_VECTOR
193    } else {
194        cs::cusolverEigMode_t::CUSOLVER_EIG_MODE_NOVECTOR
195    };
196    let itype_cs = itype.as_cusolver();
197
198    let mut lwork = 0i32;
199    {
200        let h = handle.lock();
201        let (a_ptr, _ga) = a_owned.device_ptr_mut(stream);
202        let (b_ptr, _gb) = b_owned.device_ptr_mut(stream);
203        let (w_ptr, _gw) = w_owned.device_ptr_mut(stream);
204        let status = unsafe {
205            T::sygvd_buffer_size(
206                h.0.cu(),
207                itype_cs,
208                jobz,
209                fill,
210                n,
211                a_ptr as *const T,
212                n,
213                b_ptr as *const T,
214                n,
215                w_ptr as *const T,
216                &mut lwork as *mut _,
217            )
218        };
219        drop((_ga, _gb, _gw));
220        if let Err(e) = status_to_result(status, "sygvd_bufferSize") {
221            let _ = reply.send(Err(e));
222            return;
223        }
224    }
225    if let Err(e) = ensure_workspace_bytes(workspace, stream, lwork_bytes::<T>(lwork)) {
226        let _ = reply.send(Err(e));
227        return;
228    }
229
230    a.record_write(stream);
231    b.record_write(stream);
232    w.record_write(stream);
233
234    let stream_for_check = stream.clone();
235    envelope::run_kernel(LIB, stream, completion, (), reply, move || {
236        let h = handle.lock();
237        let mut ws = workspace.lock();
238        let mut info_lock = info.lock();
239        let (a_ptr, _g1) = a_owned.device_ptr_mut(&stream_for_check);
240        let (b_ptr, _g2) = b_owned.device_ptr_mut(&stream_for_check);
241        let (w_ptr, _g3) = w_owned.device_ptr_mut(&stream_for_check);
242        let ws_slice = ws.as_mut().expect("workspace ensured");
243        let (ws_ptr, _g4) = ws_slice.device_ptr_mut(&stream_for_check);
244        let (info_ptr, _g5) = info_lock.device_ptr_mut(&stream_for_check);
245        let status = unsafe {
246            T::sygvd(
247                h.0.cu(),
248                itype_cs,
249                jobz,
250                fill,
251                n,
252                a_ptr as *mut T,
253                n,
254                b_ptr as *mut T,
255                n,
256                w_ptr as *mut T,
257                ws_ptr as *mut T,
258                lwork,
259                info_ptr as *mut i32,
260            )
261        };
262        drop((_g1, _g2, _g3, _g4, _g5));
263        status_to_result(status, "sygvd")?;
264        check_info(info, &stream_for_check, "sygvd")?;
265        Ok((a_owned, b_owned, w_owned))
266    });
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check only: both request types must implement
    /// `SolverDispatch` for `f32` and `f64`. Nothing runs on a device.
    #[test]
    fn sygvd_request_round_trip() {
        fn requires_dispatch<R>()
        where
            R: SolverDispatch,
        {
        }

        requires_dispatch::<SygvdRequest<f32>>();
        requires_dispatch::<SygvdRequest<f64>>();
        requires_dispatch::<HegvdRequest<f32>>();
        requires_dispatch::<HegvdRequest<f64>>();
    }
}