atomr_accel_cuda/kernel/solver/
generalized.rs1use std::sync::Arc;
11
12use cudarc::cusolver::sys as cs;
13use cudarc::driver::DevicePtrMut;
14use tokio::sync::oneshot;
15
16use crate::dtype::SolverSupported;
17use crate::error::GpuError;
18use crate::gpu_ref::GpuRef;
19use crate::kernel::envelope;
20use crate::sys::cusolver::{status_to_result, SolverScalar, LIB};
21
22use super::workspace::{check_info, ensure_workspace_bytes, lwork_bytes};
23use super::{SolverCells, SolverDispatch, Uplo};
24
/// Problem formulation selector for the generalized eigen-solvers.
///
/// Each variant maps one-to-one onto a `cusolverEigType_t` constant (see
/// `EigType::as_cusolver`). Per the cuSOLVER documentation the three
/// formulations are:
///
/// * `Type1`: `A * x = lambda * B * x`
/// * `Type2`: `A * B * x = lambda * x`
/// * `Type3`: `B * A * x = lambda * x`
///
/// The type is a plain fieldless enum, so it derives equality and hashing in
/// addition to `Debug`/`Clone`/`Copy`; this lets callers compare selectors or
/// use them as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EigType {
    /// Maps to `CUSOLVER_EIG_TYPE_1`.
    Type1,
    /// Maps to `CUSOLVER_EIG_TYPE_2`.
    Type2,
    /// Maps to `CUSOLVER_EIG_TYPE_3`.
    Type3,
}
33
34impl EigType {
35 fn as_cusolver(self) -> cs::cusolverEigType_t {
36 match self {
37 EigType::Type1 => cs::cusolverEigType_t::CUSOLVER_EIG_TYPE_1,
38 EigType::Type2 => cs::cusolverEigType_t::CUSOLVER_EIG_TYPE_2,
39 EigType::Type3 => cs::cusolverEigType_t::CUSOLVER_EIG_TYPE_3,
40 }
41 }
42}
43
44pub struct SygvdRequest<T: SolverSupported> {
45 pub a: GpuRef<T>,
46 pub b: GpuRef<T>,
47 pub n: i32,
48 pub itype: EigType,
49 pub uplo: Uplo,
50 pub w: GpuRef<T>,
51 pub compute_vectors: bool,
52 pub reply: oneshot::Sender<Result<(), GpuError>>,
53}
54
55pub struct HegvdRequest<T: SolverSupported> {
61 pub a: GpuRef<T>,
62 pub b: GpuRef<T>,
63 pub n: i32,
64 pub itype: EigType,
65 pub uplo: Uplo,
66 pub w: GpuRef<T>,
67 pub compute_vectors: bool,
68 pub reply: oneshot::Sender<Result<(), GpuError>>,
69}
70
71impl<T> SolverDispatch for SygvdRequest<T>
72where
73 T: SolverSupported + SolverScalar,
74{
75 fn dispatch(self: Box<Self>, cells: SolverCells<'_>) {
76 let SygvdRequest {
77 a,
78 b,
79 n,
80 itype,
81 uplo,
82 w,
83 compute_vectors,
84 reply,
85 } = *self;
86 run_sygvd::<T>(cells, a, b, n, itype, uplo, w, compute_vectors, reply);
87 }
88
89 fn dispatch_mock(self: Box<Self>) {
90 let _ = self.reply.send(Err(GpuError::Unrecoverable(
91 "SolverActor in mock mode".into(),
92 )));
93 }
94}
95
96impl<T> SolverDispatch for HegvdRequest<T>
97where
98 T: SolverSupported + SolverScalar,
99{
100 fn dispatch(self: Box<Self>, cells: SolverCells<'_>) {
101 let HegvdRequest {
102 a,
103 b,
104 n,
105 itype,
106 uplo,
107 w,
108 compute_vectors,
109 reply,
110 } = *self;
111 run_sygvd::<T>(cells, a, b, n, itype, uplo, w, compute_vectors, reply);
112 }
113
114 fn dispatch_mock(self: Box<Self>) {
115 let _ = self.reply.send(Err(GpuError::Unrecoverable(
116 "SolverActor in mock mode".into(),
117 )));
118 }
119}
120
121fn run_sygvd<T: SolverScalar>(
122 cells: SolverCells<'_>,
123 a: GpuRef<T>,
124 b: GpuRef<T>,
125 n: i32,
126 itype: EigType,
127 uplo: Uplo,
128 w: GpuRef<T>,
129 compute_vectors: bool,
130 reply: oneshot::Sender<Result<(), GpuError>>,
131) {
132 let SolverCells {
133 handle,
134 stream,
135 completion,
136 workspace,
137 info,
138 ..
139 } = cells;
140
141 let a_slice = match a.access() {
142 Ok(s) => s.clone(),
143 Err(e) => {
144 let _ = reply.send(Err(e));
145 return;
146 }
147 };
148 let b_slice = match b.access() {
149 Ok(s) => s.clone(),
150 Err(e) => {
151 let _ = reply.send(Err(e));
152 return;
153 }
154 };
155 let w_slice = match w.access() {
156 Ok(s) => s.clone(),
157 Err(e) => {
158 let _ = reply.send(Err(e));
159 return;
160 }
161 };
162 let mut a_owned = match Arc::try_unwrap(a_slice) {
163 Ok(s) => s,
164 Err(_) => {
165 let _ = reply.send(Err(GpuError::Unrecoverable(
166 "Sygvd a has multiple live references".into(),
167 )));
168 return;
169 }
170 };
171 let mut b_owned = match Arc::try_unwrap(b_slice) {
172 Ok(s) => s,
173 Err(_) => {
174 let _ = reply.send(Err(GpuError::Unrecoverable(
175 "Sygvd b has multiple live references".into(),
176 )));
177 return;
178 }
179 };
180 let mut w_owned = match Arc::try_unwrap(w_slice) {
181 Ok(s) => s,
182 Err(_) => {
183 let _ = reply.send(Err(GpuError::Unrecoverable(
184 "Sygvd w has multiple live references".into(),
185 )));
186 return;
187 }
188 };
189
190 let fill = uplo.as_cusolver_fill();
191 let jobz = if compute_vectors {
192 cs::cusolverEigMode_t::CUSOLVER_EIG_MODE_VECTOR
193 } else {
194 cs::cusolverEigMode_t::CUSOLVER_EIG_MODE_NOVECTOR
195 };
196 let itype_cs = itype.as_cusolver();
197
198 let mut lwork = 0i32;
199 {
200 let h = handle.lock();
201 let (a_ptr, _ga) = a_owned.device_ptr_mut(stream);
202 let (b_ptr, _gb) = b_owned.device_ptr_mut(stream);
203 let (w_ptr, _gw) = w_owned.device_ptr_mut(stream);
204 let status = unsafe {
205 T::sygvd_buffer_size(
206 h.0.cu(),
207 itype_cs,
208 jobz,
209 fill,
210 n,
211 a_ptr as *const T,
212 n,
213 b_ptr as *const T,
214 n,
215 w_ptr as *const T,
216 &mut lwork as *mut _,
217 )
218 };
219 drop((_ga, _gb, _gw));
220 if let Err(e) = status_to_result(status, "sygvd_bufferSize") {
221 let _ = reply.send(Err(e));
222 return;
223 }
224 }
225 if let Err(e) = ensure_workspace_bytes(workspace, stream, lwork_bytes::<T>(lwork)) {
226 let _ = reply.send(Err(e));
227 return;
228 }
229
230 a.record_write(stream);
231 b.record_write(stream);
232 w.record_write(stream);
233
234 let stream_for_check = stream.clone();
235 envelope::run_kernel(LIB, stream, completion, (), reply, move || {
236 let h = handle.lock();
237 let mut ws = workspace.lock();
238 let mut info_lock = info.lock();
239 let (a_ptr, _g1) = a_owned.device_ptr_mut(&stream_for_check);
240 let (b_ptr, _g2) = b_owned.device_ptr_mut(&stream_for_check);
241 let (w_ptr, _g3) = w_owned.device_ptr_mut(&stream_for_check);
242 let ws_slice = ws.as_mut().expect("workspace ensured");
243 let (ws_ptr, _g4) = ws_slice.device_ptr_mut(&stream_for_check);
244 let (info_ptr, _g5) = info_lock.device_ptr_mut(&stream_for_check);
245 let status = unsafe {
246 T::sygvd(
247 h.0.cu(),
248 itype_cs,
249 jobz,
250 fill,
251 n,
252 a_ptr as *mut T,
253 n,
254 b_ptr as *mut T,
255 n,
256 w_ptr as *mut T,
257 ws_ptr as *mut T,
258 lwork,
259 info_ptr as *mut i32,
260 )
261 };
262 drop((_g1, _g2, _g3, _g4, _g5));
263 status_to_result(status, "sygvd")?;
264 check_info(info, &stream_for_check, "sygvd")?;
265 Ok((a_owned, b_owned, w_owned))
266 });
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that both request types implement `SolverDispatch`
    /// for `f32` and `f64`; nothing is executed on a device.
    #[test]
    fn sygvd_request_round_trip() {
        fn requires_dispatch<R>()
        where
            R: SolverDispatch,
        {
        }

        requires_dispatch::<SygvdRequest<f32>>();
        requires_dispatch::<SygvdRequest<f64>>();
        requires_dispatch::<HegvdRequest<f32>>();
        requires_dispatch::<HegvdRequest<f64>>();
    }
}