Skip to main content

atomr_accel_cuda/kernel/solver/
mod.rs

1//! `SolverActor` — wraps a [`cudarc::cusolver::DnHandle`] for dense
2//! linear algebra and a [`cudarc::cusolver::SpHandle`] for sparse
3//! solves (gated `cusolver-sp`).
4//!
5//! Phase 1 cuSOLVER scope:
6//! - Dense: `Qr`, `Lu` (factorize / solve), `Cholesky`, `Svd`,
7//!   `Syevd` for f32 and f64 (see [`dense`]).
8//! - Batched: `getrfBatched` (cuBLAS-side LU, lifted into this
9//!   actor for symmetry), `potrfBatched`, `gesvdjBatched` (see
10//!   [`batched`]).
11//! - Generalized symmetric eigenvalue: `Sygvd` / `Hegvd` (real
12//!   variants today; complex Hermitian deferred — see
13//!   [`generalized`]).
//! - Sparse: `cusolverSp` Cholesky / LU / QR solves over CSR matrices
//!   (gated `cusolver-sp`, see [`sparse`]).
16//!
17//! Implementation notes:
18//! - cudarc 0.19's safe layer exposes only handle management; per-op
19//!   entry points live in `cusolver::sys` and are wired through
20//!   [`crate::sys::cusolver::SolverScalar`] into a dtype-generic
21//!   surface (see `crate/src/sys/cusolver.rs`).
22//! - Each op queries the cuSOLVER workspace size, grows our on-demand
23//!   `CudaSlice<u8>` workspace, then dispatches the factorisation. The
24//!   1-element `info` buffer is read back to detect failures (singular
25//!   matrix, illegal arg, etc.).
26//! - `SolverMsg::Op(Box<dyn SolverDispatch>)` is the canonical
27//!   surface; the typed `Qr` / `Lu` / `Cholesky` / `Svd` / `Syevd`
28//!   variants are kept as `#[deprecated]` aliases for backward
29//!   compatibility with the f32-only Phase 0 layout.
30
31use std::sync::Arc;
32
33use async_trait::async_trait;
34use atomr_core::actor::{Actor, Context, Props};
35use cudarc::cusolver::DnHandle;
36use cudarc::driver::CudaSlice;
37use parking_lot::Mutex;
38use tokio::sync::oneshot;
39
40use crate::completion::CompletionStrategy;
41use crate::device::DeviceState;
42use crate::error::GpuError;
43use crate::gpu_ref::GpuRef;
44use crate::stream::StreamAllocator;
45
46pub mod batched;
47pub mod dense;
48pub mod generalized;
49#[cfg(feature = "cusolver-sp")]
50pub mod sparse;
51mod workspace;
52
53pub use batched::{GesvdjBatchedRequest, GetrfBatchedRequest, PotrfBatchedRequest};
54pub use dense::{CholeskyRequest, LuRequest, LuSolveRequest, QrRequest, SvdRequest, SyevdRequest};
55pub use generalized::{HegvdRequest, SygvdRequest};
56#[cfg(feature = "cusolver-sp")]
57pub use sparse::{SparseCholeskyRequest, SparseLuRequest, SparseQrRequest};
58
59/// Storage triangle for symmetric / Hermitian / triangular factorisations.
60#[derive(Debug, Clone, Copy)]
61pub enum Uplo {
62    Upper,
63    Lower,
64}
65
66impl Uplo {
67    pub(crate) fn as_cusolver_fill(self) -> cudarc::cusolver::sys::cublasFillMode_t {
68        use cudarc::cusolver::sys::cublasFillMode_t;
69        match self {
70            Uplo::Upper => cublasFillMode_t::CUBLAS_FILL_MODE_UPPER,
71            Uplo::Lower => cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
72        }
73    }
74}
75
76/// Crate-private cells the dispatch traits operate against. Passed
77/// to [`SolverDispatch::dispatch`] as a single bundle so each op
78/// implementation only depends on what it actually uses.
79///
80/// The struct itself is publicly visible because [`SolverDispatch`]
81/// is a public trait whose method takes a `SolverCells<'_>`, but
82/// every field is `pub(crate)` since the `SendDn` / `SendSp`
83/// newtypes leak FFI handles that have no stable external
84/// representation. External code wires custom solver ops by
85/// implementing `SolverDispatch` only when also living inside this
86/// crate.
87pub struct SolverCells<'a> {
88    pub(crate) handle: &'a Mutex<SendDn>,
89    pub(crate) stream: &'a Arc<cudarc::driver::CudaStream>,
90    pub(crate) completion: &'a Arc<dyn CompletionStrategy>,
91    pub(crate) workspace: &'a Mutex<Option<CudaSlice<u8>>>,
92    pub(crate) info: &'a Mutex<CudaSlice<i32>>,
93    #[cfg(feature = "cusolver-sp")]
94    pub(crate) sp_handle: &'a Mutex<Option<SendSp>>,
95}
96
97/// Trait implemented by every solver request. The actor turns the
98/// boxed trait object back into a typed launcher and forwards the
99/// runtime cells.
100pub trait SolverDispatch: Send + 'static {
101    /// Execute the request against a real cuSOLVER handle.
102    fn dispatch(self: Box<Self>, cells: SolverCells<'_>);
103
104    /// Reply with a "mock mode" error without touching the GPU.
105    /// Default impl drops `self` so the caller's oneshot closes;
106    /// per-request impls override to send a typed `Err`.
107    fn dispatch_mock(self: Box<Self>) {
108        drop(self);
109    }
110}
111
112pub enum SolverMsg {
113    /// Canonical, dtype-generic surface. New code should prefer this
114    /// over the legacy enum variants.
115    Op(Box<dyn SolverDispatch>),
116
117    /// Legacy QR factorize. Use [`QrRequest`] via [`SolverMsg::Op`]
118    /// instead.
119    #[deprecated(note = "use SolverMsg::Op(Box::new(QrRequest { .. }))")]
120    QrFactorize {
121        a: GpuRef<f32>,
122        m: i32,
123        n: i32,
124        tau: GpuRef<f32>,
125        reply: oneshot::Sender<Result<(), GpuError>>,
126    },
127    /// Legacy LU factorize. Use [`LuRequest`] via [`SolverMsg::Op`].
128    #[deprecated(note = "use SolverMsg::Op(Box::new(LuRequest { .. }))")]
129    LuFactorize {
130        a: GpuRef<f32>,
131        m: i32,
132        n: i32,
133        ipiv: GpuRef<i32>,
134        reply: oneshot::Sender<Result<(), GpuError>>,
135    },
136    /// Legacy LU solve. Use [`LuSolveRequest`] via [`SolverMsg::Op`].
137    #[deprecated(note = "use SolverMsg::Op(Box::new(LuSolveRequest { .. }))")]
138    LuSolve {
139        lu: GpuRef<f32>,
140        ipiv: GpuRef<i32>,
141        b: GpuRef<f32>,
142        n: i32,
143        nrhs: i32,
144        trans: bool,
145        reply: oneshot::Sender<Result<(), GpuError>>,
146    },
147    /// Legacy Cholesky. Use [`CholeskyRequest`] via [`SolverMsg::Op`].
148    #[deprecated(note = "use SolverMsg::Op(Box::new(CholeskyRequest { .. }))")]
149    Cholesky {
150        a: GpuRef<f32>,
151        n: i32,
152        uplo: Uplo,
153        reply: oneshot::Sender<Result<(), GpuError>>,
154    },
155    /// Legacy SVD. Use [`SvdRequest`] via [`SolverMsg::Op`].
156    #[deprecated(note = "use SolverMsg::Op(Box::new(SvdRequest { .. }))")]
157    Svd {
158        a: GpuRef<f32>,
159        m: i32,
160        n: i32,
161        s: GpuRef<f32>,
162        u: Option<GpuRef<f32>>,
163        vt: Option<GpuRef<f32>>,
164        reply: oneshot::Sender<Result<(), GpuError>>,
165    },
166    /// Legacy symmetric eigendecomposition. Use [`SyevdRequest`] via
167    /// [`SolverMsg::Op`].
168    #[deprecated(note = "use SolverMsg::Op(Box::new(SyevdRequest { .. }))")]
169    Syevd {
170        a: GpuRef<f32>,
171        n: i32,
172        uplo: Uplo,
173        w: GpuRef<f32>,
174        compute_vectors: bool,
175        reply: oneshot::Sender<Result<(), GpuError>>,
176    },
177}
178
179pub struct SolverActor {
180    inner: SolverInner,
181}
182
183pub(crate) struct SendDn(pub(crate) DnHandle);
184unsafe impl Send for SendDn {}
185unsafe impl Sync for SendDn {}
186
187#[cfg(feature = "cusolver-sp")]
188pub(crate) struct SendSp(pub(crate) cudarc::cusolver::SpHandle);
189#[cfg(feature = "cusolver-sp")]
190unsafe impl Send for SendSp {}
191#[cfg(feature = "cusolver-sp")]
192unsafe impl Sync for SendSp {}
193
194#[allow(dead_code)]
195enum SolverInner {
196    Real {
197        handle: Mutex<SendDn>,
198        stream: Arc<cudarc::driver::CudaStream>,
199        completion: Arc<dyn CompletionStrategy>,
200        state: Arc<DeviceState>,
201        /// On-demand-grown scratch buffer (in bytes; we widen from
202        /// the per-op f32/f64 workspaces by multiplying out
203        /// `lwork * size_of::<T>()`). Never shrunk; rebuilt fresh on
204        /// context restart.
205        workspace: Mutex<Option<CudaSlice<u8>>>,
206        /// 1-element `i32` info buffer reused across calls.
207        info: Mutex<CudaSlice<i32>>,
208        /// Lazy `cusolverSp` handle; created on first sparse op.
209        #[cfg(feature = "cusolver-sp")]
210        sp_handle: Mutex<Option<SendSp>>,
211    },
212    Mock,
213}
214
215impl SolverActor {
216    pub fn props(
217        stream: Arc<cudarc::driver::CudaStream>,
218        _allocator: Arc<dyn StreamAllocator>,
219        completion: Arc<dyn CompletionStrategy>,
220        state: Arc<DeviceState>,
221    ) -> Props<Self> {
222        Props::create(move || {
223            let handle = match DnHandle::new(stream.clone()) {
224                Ok(h) => h,
225                Err(e) => panic!("ContextPoisoned: DnHandle::new failed: {e}"),
226            };
227            let info = stream
228                .alloc_zeros::<i32>(1)
229                .unwrap_or_else(|e| panic!("ContextPoisoned: alloc info: {e}"));
230            SolverActor {
231                inner: SolverInner::Real {
232                    handle: Mutex::new(SendDn(handle)),
233                    stream: stream.clone(),
234                    completion: completion.clone(),
235                    state: state.clone(),
236                    workspace: Mutex::new(None),
237                    info: Mutex::new(info),
238                    #[cfg(feature = "cusolver-sp")]
239                    sp_handle: Mutex::new(None),
240                },
241            }
242        })
243    }
244
245    pub fn mock_props() -> Props<Self> {
246        Props::create(|| SolverActor {
247            inner: SolverInner::Mock,
248        })
249    }
250}
251
252#[async_trait]
253impl Actor for SolverActor {
254    type Msg = SolverMsg;
255
256    async fn handle(&mut self, _ctx: &mut Context<Self>, msg: SolverMsg) {
257        match &self.inner {
258            SolverInner::Mock => mock_reply(msg),
259            SolverInner::Real {
260                handle,
261                stream,
262                completion,
263                workspace,
264                info,
265                #[cfg(feature = "cusolver-sp")]
266                sp_handle,
267                ..
268            } => {
269                let cells = SolverCells {
270                    handle,
271                    stream,
272                    completion,
273                    workspace,
274                    info,
275                    #[cfg(feature = "cusolver-sp")]
276                    sp_handle,
277                };
278                dispatch_msg(msg, cells);
279            }
280        }
281    }
282}
283
284#[allow(deprecated)]
285fn dispatch_msg(msg: SolverMsg, cells: SolverCells<'_>) {
286    match msg {
287        SolverMsg::Op(op) => op.dispatch(cells),
288        SolverMsg::QrFactorize {
289            a,
290            m,
291            n,
292            tau,
293            reply,
294        } => Box::new(QrRequest::<f32> {
295            a,
296            m,
297            n,
298            tau,
299            reply,
300        })
301        .dispatch(cells),
302        SolverMsg::LuFactorize {
303            a,
304            m,
305            n,
306            ipiv,
307            reply,
308        } => Box::new(LuRequest::<f32> {
309            a,
310            m,
311            n,
312            ipiv,
313            reply,
314        })
315        .dispatch(cells),
316        SolverMsg::LuSolve {
317            lu,
318            ipiv,
319            b,
320            n,
321            nrhs,
322            trans,
323            reply,
324        } => Box::new(LuSolveRequest::<f32> {
325            lu,
326            ipiv,
327            b,
328            n,
329            nrhs,
330            trans,
331            reply,
332        })
333        .dispatch(cells),
334        SolverMsg::Cholesky { a, n, uplo, reply } => {
335            Box::new(CholeskyRequest::<f32> { a, n, uplo, reply }).dispatch(cells)
336        }
337        SolverMsg::Svd {
338            a,
339            m,
340            n,
341            s,
342            u,
343            vt,
344            reply,
345        } => Box::new(SvdRequest::<f32> {
346            a,
347            m,
348            n,
349            s,
350            u,
351            vt,
352            reply,
353        })
354        .dispatch(cells),
355        SolverMsg::Syevd {
356            a,
357            n,
358            uplo,
359            w,
360            compute_vectors,
361            reply,
362        } => Box::new(SyevdRequest::<f32> {
363            a,
364            n,
365            uplo,
366            w,
367            compute_vectors,
368            reply,
369        })
370        .dispatch(cells),
371    }
372}
373
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::oneshot;

    /// Signature of a builder for the deprecated `QrFactorize` variant.
    type LegacyQrBuilder =
        fn(GpuRef<f32>, GpuRef<f32>, oneshot::Sender<Result<(), GpuError>>) -> SolverMsg;

    /// Builds the deprecated `SolverMsg::QrFactorize` variant.
    ///
    /// Never called at runtime, because a real `GpuRef` needs a device.
    /// It exists so this module stops compiling if the legacy variant is
    /// removed or its field set changes.
    #[allow(deprecated)]
    fn build_legacy_qr(
        a: GpuRef<f32>,
        tau: GpuRef<f32>,
        reply: oneshot::Sender<Result<(), GpuError>>,
    ) -> SolverMsg {
        SolverMsg::QrFactorize {
            a,
            m: 0,
            n: 0,
            tau,
            reply,
        }
    }

    /// The deprecated `SolverMsg::QrFactorize` alias must still
    /// construct, so that existing applications compile during the
    /// Phase 1 transition window.
    #[test]
    fn deprecated_qr_alias_still_constructs() {
        // Coercing to a fn pointer forces the builder to type-check (and
        // counts as a use) without needing a GPU-backed `GpuRef`.
        let _builder: LegacyQrBuilder = build_legacy_qr;
    }

    /// Test double that records which dispatch path the actor took.
    struct ProbeOp {
        seen: oneshot::Sender<&'static str>,
    }

    impl SolverDispatch for ProbeOp {
        fn dispatch(self: Box<Self>, _cells: SolverCells<'_>) {
            let _ = self.seen.send("real");
        }

        fn dispatch_mock(self: Box<Self>) {
            let _ = self.seen.send("mock");
        }
    }

    /// Test double that keeps the default `dispatch_mock`.
    struct SilentOp {
        _reply: oneshot::Sender<()>,
    }

    impl SolverDispatch for SilentOp {
        fn dispatch(self: Box<Self>, _cells: SolverCells<'_>) {}
    }

    /// In mock mode, `SolverMsg::Op` must be routed to `dispatch_mock`,
    /// never to the GPU path.
    #[test]
    fn mock_reply_routes_op_to_dispatch_mock() {
        let (tx, mut rx) = oneshot::channel();
        mock_reply(SolverMsg::Op(Box::new(ProbeOp { seen: tx })));
        assert_eq!(rx.try_recv().ok(), Some("mock"));
    }

    /// The default `dispatch_mock` drops the request, so the caller's
    /// oneshot observes a closed channel instead of hanging.
    #[test]
    fn default_dispatch_mock_closes_reply_channel() {
        let (tx, mut rx) = oneshot::channel::<()>();
        mock_reply(SolverMsg::Op(Box::new(SilentOp { _reply: tx })));
        assert!(matches!(
            rx.try_recv(),
            Err(oneshot::error::TryRecvError::Closed)
        ));
    }
}
417
418#[allow(deprecated)]
419fn mock_reply(msg: SolverMsg) {
420    let err = || GpuError::Unrecoverable("SolverActor in mock mode".into());
421    match msg {
422        SolverMsg::Op(op) => op.dispatch_mock(),
423        SolverMsg::QrFactorize { reply, .. }
424        | SolverMsg::LuFactorize { reply, .. }
425        | SolverMsg::LuSolve { reply, .. }
426        | SolverMsg::Cholesky { reply, .. }
427        | SolverMsg::Svd { reply, .. }
428        | SolverMsg::Syevd { reply, .. } => {
429            let _ = reply.send(Err(err()));
430        }
431    }
432}