1use std::sync::Arc;
32
33use async_trait::async_trait;
34use atomr_core::actor::{Actor, Context, Props};
35use cudarc::cusolver::DnHandle;
36use cudarc::driver::CudaSlice;
37use parking_lot::Mutex;
38use tokio::sync::oneshot;
39
40use crate::completion::CompletionStrategy;
41use crate::device::DeviceState;
42use crate::error::GpuError;
43use crate::gpu_ref::GpuRef;
44use crate::stream::StreamAllocator;
45
46pub mod batched;
47pub mod dense;
48pub mod generalized;
49#[cfg(feature = "cusolver-sp")]
50pub mod sparse;
51mod workspace;
52
53pub use batched::{GesvdjBatchedRequest, GetrfBatchedRequest, PotrfBatchedRequest};
54pub use dense::{CholeskyRequest, LuRequest, LuSolveRequest, QrRequest, SvdRequest, SyevdRequest};
55pub use generalized::{HegvdRequest, SygvdRequest};
56#[cfg(feature = "cusolver-sp")]
57pub use sparse::{SparseCholeskyRequest, SparseLuRequest, SparseQrRequest};
58
/// Selects which triangle of a matrix an operation refers to.
///
/// Converted to the cuBLAS fill-mode constant by `Uplo::as_cusolver_fill`.
/// The type is a plain two-variant tag, so it is cheap to copy and can be
/// compared or used as a hash key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Uplo {
    /// The upper triangle.
    Upper,
    /// The lower triangle.
    Lower,
}
65
66impl Uplo {
67 pub(crate) fn as_cusolver_fill(self) -> cudarc::cusolver::sys::cublasFillMode_t {
68 use cudarc::cusolver::sys::cublasFillMode_t;
69 match self {
70 Uplo::Upper => cublasFillMode_t::CUBLAS_FILL_MODE_UPPER,
71 Uplo::Lower => cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
72 }
73 }
74}
75
76pub struct SolverCells<'a> {
88 pub(crate) handle: &'a Mutex<SendDn>,
89 pub(crate) stream: &'a Arc<cudarc::driver::CudaStream>,
90 pub(crate) completion: &'a Arc<dyn CompletionStrategy>,
91 pub(crate) workspace: &'a Mutex<Option<CudaSlice<u8>>>,
92 pub(crate) info: &'a Mutex<CudaSlice<i32>>,
93 #[cfg(feature = "cusolver-sp")]
94 pub(crate) sp_handle: &'a Mutex<Option<SendSp>>,
95}
96
97pub trait SolverDispatch: Send + 'static {
101 fn dispatch(self: Box<Self>, cells: SolverCells<'_>);
103
104 fn dispatch_mock(self: Box<Self>) {
108 drop(self);
109 }
110}
111
112pub enum SolverMsg {
113 Op(Box<dyn SolverDispatch>),
116
117 #[deprecated(note = "use SolverMsg::Op(Box::new(QrRequest { .. }))")]
120 QrFactorize {
121 a: GpuRef<f32>,
122 m: i32,
123 n: i32,
124 tau: GpuRef<f32>,
125 reply: oneshot::Sender<Result<(), GpuError>>,
126 },
127 #[deprecated(note = "use SolverMsg::Op(Box::new(LuRequest { .. }))")]
129 LuFactorize {
130 a: GpuRef<f32>,
131 m: i32,
132 n: i32,
133 ipiv: GpuRef<i32>,
134 reply: oneshot::Sender<Result<(), GpuError>>,
135 },
136 #[deprecated(note = "use SolverMsg::Op(Box::new(LuSolveRequest { .. }))")]
138 LuSolve {
139 lu: GpuRef<f32>,
140 ipiv: GpuRef<i32>,
141 b: GpuRef<f32>,
142 n: i32,
143 nrhs: i32,
144 trans: bool,
145 reply: oneshot::Sender<Result<(), GpuError>>,
146 },
147 #[deprecated(note = "use SolverMsg::Op(Box::new(CholeskyRequest { .. }))")]
149 Cholesky {
150 a: GpuRef<f32>,
151 n: i32,
152 uplo: Uplo,
153 reply: oneshot::Sender<Result<(), GpuError>>,
154 },
155 #[deprecated(note = "use SolverMsg::Op(Box::new(SvdRequest { .. }))")]
157 Svd {
158 a: GpuRef<f32>,
159 m: i32,
160 n: i32,
161 s: GpuRef<f32>,
162 u: Option<GpuRef<f32>>,
163 vt: Option<GpuRef<f32>>,
164 reply: oneshot::Sender<Result<(), GpuError>>,
165 },
166 #[deprecated(note = "use SolverMsg::Op(Box::new(SyevdRequest { .. }))")]
169 Syevd {
170 a: GpuRef<f32>,
171 n: i32,
172 uplo: Uplo,
173 w: GpuRef<f32>,
174 compute_vectors: bool,
175 reply: oneshot::Sender<Result<(), GpuError>>,
176 },
177}
178
179pub struct SolverActor {
180 inner: SolverInner,
181}
182
183pub(crate) struct SendDn(pub(crate) DnHandle);
184unsafe impl Send for SendDn {}
185unsafe impl Sync for SendDn {}
186
/// Newtype around the sparse cuSOLVER handle so that it can be declared
/// `Send` and `Sync`. Only compiled with the `cusolver-sp` feature.
#[cfg(feature = "cusolver-sp")]
pub(crate) struct SendSp(pub(crate) cudarc::cusolver::SpHandle);

// SAFETY: NOTE(review): this is the author's assertion and is not provable from
// this file alone. Within this file the handle is only ever stored behind a
// `Mutex`; confirm that cuSOLVER permits use of a sparse handle from any thread
// when access is serialised that way.
#[cfg(feature = "cusolver-sp")]
unsafe impl Send for SendSp {}
// SAFETY: NOTE(review): same assumption as the `Send` impl directly above.
#[cfg(feature = "cusolver-sp")]
unsafe impl Sync for SendSp {}
193
194#[allow(dead_code)]
195enum SolverInner {
196 Real {
197 handle: Mutex<SendDn>,
198 stream: Arc<cudarc::driver::CudaStream>,
199 completion: Arc<dyn CompletionStrategy>,
200 state: Arc<DeviceState>,
201 workspace: Mutex<Option<CudaSlice<u8>>>,
206 info: Mutex<CudaSlice<i32>>,
208 #[cfg(feature = "cusolver-sp")]
210 sp_handle: Mutex<Option<SendSp>>,
211 },
212 Mock,
213}
214
215impl SolverActor {
216 pub fn props(
217 stream: Arc<cudarc::driver::CudaStream>,
218 _allocator: Arc<dyn StreamAllocator>,
219 completion: Arc<dyn CompletionStrategy>,
220 state: Arc<DeviceState>,
221 ) -> Props<Self> {
222 Props::create(move || {
223 let handle = match DnHandle::new(stream.clone()) {
224 Ok(h) => h,
225 Err(e) => panic!("ContextPoisoned: DnHandle::new failed: {e}"),
226 };
227 let info = stream
228 .alloc_zeros::<i32>(1)
229 .unwrap_or_else(|e| panic!("ContextPoisoned: alloc info: {e}"));
230 SolverActor {
231 inner: SolverInner::Real {
232 handle: Mutex::new(SendDn(handle)),
233 stream: stream.clone(),
234 completion: completion.clone(),
235 state: state.clone(),
236 workspace: Mutex::new(None),
237 info: Mutex::new(info),
238 #[cfg(feature = "cusolver-sp")]
239 sp_handle: Mutex::new(None),
240 },
241 }
242 })
243 }
244
245 pub fn mock_props() -> Props<Self> {
246 Props::create(|| SolverActor {
247 inner: SolverInner::Mock,
248 })
249 }
250}
251
252#[async_trait]
253impl Actor for SolverActor {
254 type Msg = SolverMsg;
255
256 async fn handle(&mut self, _ctx: &mut Context<Self>, msg: SolverMsg) {
257 match &self.inner {
258 SolverInner::Mock => mock_reply(msg),
259 SolverInner::Real {
260 handle,
261 stream,
262 completion,
263 workspace,
264 info,
265 #[cfg(feature = "cusolver-sp")]
266 sp_handle,
267 ..
268 } => {
269 let cells = SolverCells {
270 handle,
271 stream,
272 completion,
273 workspace,
274 info,
275 #[cfg(feature = "cusolver-sp")]
276 sp_handle,
277 };
278 dispatch_msg(msg, cells);
279 }
280 }
281 }
282}
283
284#[allow(deprecated)]
285fn dispatch_msg(msg: SolverMsg, cells: SolverCells<'_>) {
286 match msg {
287 SolverMsg::Op(op) => op.dispatch(cells),
288 SolverMsg::QrFactorize {
289 a,
290 m,
291 n,
292 tau,
293 reply,
294 } => Box::new(QrRequest::<f32> {
295 a,
296 m,
297 n,
298 tau,
299 reply,
300 })
301 .dispatch(cells),
302 SolverMsg::LuFactorize {
303 a,
304 m,
305 n,
306 ipiv,
307 reply,
308 } => Box::new(LuRequest::<f32> {
309 a,
310 m,
311 n,
312 ipiv,
313 reply,
314 })
315 .dispatch(cells),
316 SolverMsg::LuSolve {
317 lu,
318 ipiv,
319 b,
320 n,
321 nrhs,
322 trans,
323 reply,
324 } => Box::new(LuSolveRequest::<f32> {
325 lu,
326 ipiv,
327 b,
328 n,
329 nrhs,
330 trans,
331 reply,
332 })
333 .dispatch(cells),
334 SolverMsg::Cholesky { a, n, uplo, reply } => {
335 Box::new(CholeskyRequest::<f32> { a, n, uplo, reply }).dispatch(cells)
336 }
337 SolverMsg::Svd {
338 a,
339 m,
340 n,
341 s,
342 u,
343 vt,
344 reply,
345 } => Box::new(SvdRequest::<f32> {
346 a,
347 m,
348 n,
349 s,
350 u,
351 vt,
352 reply,
353 })
354 .dispatch(cells),
355 SolverMsg::Syevd {
356 a,
357 n,
358 uplo,
359 w,
360 compute_vectors,
361 reply,
362 } => Box::new(SyevdRequest::<f32> {
363 a,
364 n,
365 uplo,
366 w,
367 compute_vectors,
368 reply,
369 })
370 .dispatch(cells),
371 }
372}
373
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::oneshot;

    /// Compile-time guard: the deprecated `SolverMsg::QrFactorize` variant must
    /// still be constructible with its original field set. No GPU is touched;
    /// at runtime the test only drops a reply sender.
    #[test]
    #[allow(deprecated)]
    fn deprecated_qr_alias_still_constructs() {
        // Never called — it only has to type-check.
        #[allow(dead_code)]
        #[allow(deprecated)]
        fn _check(
            a: GpuRef<f32>,
            tau: GpuRef<f32>,
            reply: oneshot::Sender<Result<(), GpuError>>,
        ) -> SolverMsg {
            SolverMsg::QrFactorize {
                a,
                m: 0,
                n: 0,
                tau,
                reply,
            }
        }

        let (tx, _rx) = oneshot::channel::<Result<(), GpuError>>();
        let make = |reply: oneshot::Sender<Result<(), GpuError>>| -> &'static str {
            drop(reply);
            "ok"
        };
        assert_eq!(make(tx), "ok");
    }
}
417
418#[allow(deprecated)]
419fn mock_reply(msg: SolverMsg) {
420 let err = || GpuError::Unrecoverable("SolverActor in mock mode".into());
421 match msg {
422 SolverMsg::Op(op) => op.dispatch_mock(),
423 SolverMsg::QrFactorize { reply, .. }
424 | SolverMsg::LuFactorize { reply, .. }
425 | SolverMsg::LuSolve { reply, .. }
426 | SolverMsg::Cholesky { reply, .. }
427 | SolverMsg::Svd { reply, .. }
428 | SolverMsg::Syevd { reply, .. } => {
429 let _ = reply.send(Err(err()));
430 }
431 }
432}