Skip to main content

atomr_accel_cuda/kernel/solver/
batched.rs

//! Batched cuSOLVER ops.
//!
//! cuSOLVER ships native batched variants for Cholesky
//! (`Spotrf/DpotrfBatched`) and Jacobi SVD
//! (`Sgesvdj/DgesvdjBatched`). The batched LU routines live on the cuBLAS
//! side (`cublasSgetrf/DgetrfBatched`); we adopt them into `SolverActor`
//! for API symmetry. Each request takes a contiguous strided block of
//! `batch_size` matrices. The LU and Cholesky paths additionally stage an
//! array-of-pointers (one device pointer per matrix) on demand.
//!
//! All three requests share the same input layout:
//! - `a: GpuRef<T>`: `batch_size` column-major `m × n` matrices packed
//!   back to back, so matrix `i` starts at element `i * m * n` (no extra
//!   padding; `m == n` for the square LU and Cholesky requests).
//! - `batch_size: i32`: number of independent problems.
//!
//! Per-batch info codes are read back into a host vec; any non-zero
//! entry surfaces as a `LibraryError` identifying the failing batch
//! index.
19
20use std::ffi::c_int;
21use std::sync::Arc;
22
23use cudarc::cusolver::sys as cs;
24use cudarc::driver::{DevicePtr, DevicePtrMut};
25use parking_lot::Mutex;
26use tokio::sync::oneshot;
27
28use crate::dtype::SolverSupported;
29use crate::error::GpuError;
30use crate::gpu_ref::GpuRef;
31use crate::kernel::envelope;
32use crate::sys::cusolver::{status_to_result, SolverScalar, LIB};
33
34use super::workspace::{check_info_array, ensure_workspace_bytes, lwork_bytes};
35use super::{SolverCells, SolverDispatch, Uplo};
36
37// =====================================================================
38// LU batched (`cublasSgetrfBatched` / `cublasDgetrfBatched`)
39// =====================================================================
40
41pub struct GetrfBatchedRequest<T: SolverSupported> {
42    /// Contiguous `batch_size × n × n` column-major buffer.
43    pub a: GpuRef<T>,
44    /// Square problem size.
45    pub n: i32,
46    /// Number of independent problems.
47    pub batch_size: i32,
48    /// Pivot indices: `batch_size * n` `i32` entries.
49    pub ipiv: GpuRef<i32>,
50    pub reply: oneshot::Sender<Result<(), GpuError>>,
51}
52
53/// Per-dtype dispatch into `cublas[SD]getrfBatched`. The cuBLAS
54/// batched LU lives outside cuSOLVER; we keep the FFI thunks
55/// inline (rather than expanding `SolverScalar`) so the cuBLAS
56/// symbol surface stays scoped to this module.
57trait BatchedLu: SolverScalar {
58    unsafe fn getrf_batched(
59        handle: cudarc::cublas::sys::cublasHandle_t,
60        n: c_int,
61        a_array: *const *mut Self,
62        lda: c_int,
63        pivots: *mut c_int,
64        info_array: *mut c_int,
65        batch_size: c_int,
66    ) -> cudarc::cublas::sys::cublasStatus_t;
67}
68
69impl BatchedLu for f32 {
70    unsafe fn getrf_batched(
71        handle: cudarc::cublas::sys::cublasHandle_t,
72        n: c_int,
73        a_array: *const *mut Self,
74        lda: c_int,
75        pivots: *mut c_int,
76        info_array: *mut c_int,
77        batch_size: c_int,
78    ) -> cudarc::cublas::sys::cublasStatus_t {
79        cudarc::cublas::sys::cublasSgetrfBatched(
80            handle, n, a_array, lda, pivots, info_array, batch_size,
81        )
82    }
83}
84
85impl BatchedLu for f64 {
86    unsafe fn getrf_batched(
87        handle: cudarc::cublas::sys::cublasHandle_t,
88        n: c_int,
89        a_array: *const *mut Self,
90        lda: c_int,
91        pivots: *mut c_int,
92        info_array: *mut c_int,
93        batch_size: c_int,
94    ) -> cudarc::cublas::sys::cublasStatus_t {
95        cudarc::cublas::sys::cublasDgetrfBatched(
96            handle, n, a_array, lda, pivots, info_array, batch_size,
97        )
98    }
99}
100
101impl<T> SolverDispatch for GetrfBatchedRequest<T>
102where
103    T: SolverSupported + SolverScalar + BatchedLu,
104{
105    fn dispatch(self: Box<Self>, cells: SolverCells<'_>) {
106        let GetrfBatchedRequest {
107            a,
108            n,
109            batch_size,
110            ipiv,
111            reply,
112        } = *self;
113        run_getrf_batched::<T>(cells, a, n, batch_size, ipiv, reply);
114    }
115
116    fn dispatch_mock(self: Box<Self>) {
117        let _ = self.reply.send(Err(GpuError::Unrecoverable(
118            "SolverActor in mock mode".into(),
119        )));
120    }
121}
122
123/// Build a contiguous `Vec<*mut T>` of per-batch starting pointers and
124/// upload it as a `CudaSlice<u64>` (raw pointer values). We use the
125/// stream's `memcpy_htod` since the device pointer table is unique to
126/// this launch and short-lived.
127fn upload_pointer_table<T>(
128    stream: &Arc<cudarc::driver::CudaStream>,
129    base: *mut T,
130    batch: i32,
131    n: i32,
132) -> Result<cudarc::driver::CudaSlice<u64>, GpuError> {
133    let count = batch.max(0) as usize;
134    let stride_bytes = (n.max(0) as usize) * (n.max(0) as usize) * std::mem::size_of::<T>();
135    let mut ptrs = Vec::with_capacity(count);
136    for i in 0..count {
137        let p = (base as usize).saturating_add(i * stride_bytes);
138        ptrs.push(p as u64);
139    }
140    let mut buf = stream
141        .alloc_zeros::<u64>(count.max(1))
142        .map_err(|e| GpuError::OutOfMemory(format!("ptr table ({count}): {e}")))?;
143    stream
144        .memcpy_htod(&ptrs, &mut buf)
145        .map_err(|e| GpuError::lib(LIB, format!("upload ptr table: {e}")))?;
146    Ok(buf)
147}
148
149fn run_getrf_batched<T: SolverScalar + BatchedLu>(
150    cells: SolverCells<'_>,
151    a: GpuRef<T>,
152    n: i32,
153    batch_size: i32,
154    ipiv: GpuRef<i32>,
155    reply: oneshot::Sender<Result<(), GpuError>>,
156) {
157    let SolverCells {
158        stream, completion, ..
159    } = cells;
160
161    let a_slice = match a.access() {
162        Ok(s) => s.clone(),
163        Err(e) => {
164            let _ = reply.send(Err(e));
165            return;
166        }
167    };
168    let ipiv_slice = match ipiv.access() {
169        Ok(s) => s.clone(),
170        Err(e) => {
171            let _ = reply.send(Err(e));
172            return;
173        }
174    };
175    let mut a_owned = match Arc::try_unwrap(a_slice) {
176        Ok(s) => s,
177        Err(_) => {
178            let _ = reply.send(Err(GpuError::Unrecoverable(
179                "GetrfBatched a has multiple live references".into(),
180            )));
181            return;
182        }
183    };
184    let mut ipiv_owned = match Arc::try_unwrap(ipiv_slice) {
185        Ok(s) => s,
186        Err(_) => {
187            let _ = reply.send(Err(GpuError::Unrecoverable(
188                "GetrfBatched ipiv has multiple live references".into(),
189            )));
190            return;
191        }
192    };
193
194    // Lazy cuBLAS handle: cuSOLVER doesn't ship batched LU, so we
195    // bind cuBLAS into the cuSOLVER actor's stream just for this
196    // op. Created/destroyed locally so it doesn't outlive the
197    // launch. cudarc 0.19's `CudaBlas::new` pins the handle to the
198    // supplied stream — exactly what we need.
199    let blas = match cudarc::cublas::CudaBlas::new(stream.clone()) {
200        Ok(b) => b,
201        Err(e) => {
202            let _ = reply.send(Err(GpuError::lib(LIB, format!("CudaBlas::new: {e}"))));
203            return;
204        }
205    };
206    let blas_handle = *blas.handle();
207
208    // Build pointer table.
209    let (a_base_ptr, _g_base) = a_owned.device_ptr_mut(stream);
210    let ptr_table = match upload_pointer_table::<T>(stream, a_base_ptr as *mut T, batch_size, n) {
211        Ok(t) => t,
212        Err(e) => {
213            let _ = reply.send(Err(e));
214            return;
215        }
216    };
217    drop(_g_base);
218
219    // info array: one int per batch entry, allocated fresh.
220    let info_array = match stream.alloc_zeros::<i32>(batch_size.max(1) as usize) {
221        Ok(b) => b,
222        Err(e) => {
223            let _ = reply.send(Err(GpuError::OutOfMemory(format!(
224                "GetrfBatched info: {e}"
225            ))));
226            return;
227        }
228    };
229
230    a.record_write(stream);
231    ipiv.record_write(stream);
232
233    let stream_for_check = stream.clone();
234    envelope::run_kernel(LIB, stream, completion, (), reply, move || {
235        let (ptrs_dev, _gp) = ptr_table.device_ptr(&stream_for_check);
236        let (ipiv_ptr, _gpiv) = ipiv_owned.device_ptr_mut(&stream_for_check);
237        let (info_ptr, _ginfo) = info_array.device_ptr(&stream_for_check);
238        let status = unsafe {
239            T::getrf_batched(
240                blas_handle,
241                n,
242                ptrs_dev as *const *mut T,
243                n,
244                ipiv_ptr as *mut c_int,
245                info_ptr as *mut c_int,
246                batch_size,
247            )
248        };
249        drop((_gp, _gpiv, _ginfo));
250        if status != cudarc::cublas::sys::cublasStatus_t::CUBLAS_STATUS_SUCCESS {
251            return Err(GpuError::lib(LIB, format!("getrfBatched: {status:?}")));
252        }
253        check_info_array(
254            &info_array,
255            &stream_for_check,
256            "getrfBatched",
257            batch_size.max(0) as usize,
258        )?;
259        // Keep blas + table alive until completion.
260        Ok((a_owned, ipiv_owned, ptr_table, info_array, blas))
261    });
262}
263
264// =====================================================================
265// Cholesky batched (`cusolverDn[SD]potrfBatched`)
266// =====================================================================
267
268pub struct PotrfBatchedRequest<T: SolverSupported> {
269    pub a: GpuRef<T>,
270    pub n: i32,
271    pub batch_size: i32,
272    pub uplo: Uplo,
273    pub reply: oneshot::Sender<Result<(), GpuError>>,
274}
275
276impl<T> SolverDispatch for PotrfBatchedRequest<T>
277where
278    T: SolverSupported + SolverScalar,
279{
280    fn dispatch(self: Box<Self>, cells: SolverCells<'_>) {
281        let PotrfBatchedRequest {
282            a,
283            n,
284            batch_size,
285            uplo,
286            reply,
287        } = *self;
288        run_potrf_batched::<T>(cells, a, n, batch_size, uplo, reply);
289    }
290
291    fn dispatch_mock(self: Box<Self>) {
292        let _ = self.reply.send(Err(GpuError::Unrecoverable(
293            "SolverActor in mock mode".into(),
294        )));
295    }
296}
297
298fn run_potrf_batched<T: SolverScalar>(
299    cells: SolverCells<'_>,
300    a: GpuRef<T>,
301    n: i32,
302    batch_size: i32,
303    uplo: Uplo,
304    reply: oneshot::Sender<Result<(), GpuError>>,
305) {
306    let SolverCells {
307        handle,
308        stream,
309        completion,
310        ..
311    } = cells;
312
313    let a_slice = match a.access() {
314        Ok(s) => s.clone(),
315        Err(e) => {
316            let _ = reply.send(Err(e));
317            return;
318        }
319    };
320    let mut a_owned = match Arc::try_unwrap(a_slice) {
321        Ok(s) => s,
322        Err(_) => {
323            let _ = reply.send(Err(GpuError::Unrecoverable(
324                "PotrfBatched a has multiple live references".into(),
325            )));
326            return;
327        }
328    };
329
330    let (a_base_ptr, _g_base) = a_owned.device_ptr_mut(stream);
331    let ptr_table = match upload_pointer_table::<T>(stream, a_base_ptr as *mut T, batch_size, n) {
332        Ok(t) => t,
333        Err(e) => {
334            let _ = reply.send(Err(e));
335            return;
336        }
337    };
338    drop(_g_base);
339
340    let info_array = match stream.alloc_zeros::<i32>(batch_size.max(1) as usize) {
341        Ok(b) => b,
342        Err(e) => {
343            let _ = reply.send(Err(GpuError::OutOfMemory(format!(
344                "PotrfBatched info: {e}"
345            ))));
346            return;
347        }
348    };
349
350    a.record_write(stream);
351    let fill = uplo.as_cusolver_fill();
352    let stream_for_check = stream.clone();
353
354    envelope::run_kernel(LIB, stream, completion, (), reply, move || {
355        let h = handle.lock();
356        let (ptrs_dev, _gp) = ptr_table.device_ptr(&stream_for_check);
357        let (info_ptr, _ginfo) = info_array.device_ptr(&stream_for_check);
358        let status = unsafe {
359            T::potrf_batched(
360                h.0.cu(),
361                fill,
362                n,
363                ptrs_dev as *mut *mut T,
364                n,
365                info_ptr as *mut i32,
366                batch_size,
367            )
368        };
369        drop((_gp, _ginfo));
370        status_to_result(status, "potrfBatched")?;
371        check_info_array(
372            &info_array,
373            &stream_for_check,
374            "potrfBatched",
375            batch_size.max(0) as usize,
376        )?;
377        Ok((a_owned, ptr_table, info_array))
378    });
379}
380
381// =====================================================================
382// Batched Jacobi SVD (`cusolverDn[SD]gesvdjBatched`)
383// =====================================================================
384
385pub struct GesvdjBatchedRequest<T: SolverSupported> {
386    /// Contiguous `batch_size × m × n` column-major buffer.
387    pub a: GpuRef<T>,
388    pub m: i32,
389    pub n: i32,
390    pub batch_size: i32,
391    /// Singular values: `batch_size * min(m, n)` entries.
392    pub s: GpuRef<T>,
393    /// Left singular vectors (`batch_size × m × m`). When `None`,
394    /// `jobz = NOVECTOR`.
395    pub u: Option<GpuRef<T>>,
396    /// Right singular vectors (`batch_size × n × n`). When `None`,
397    /// `jobz = NOVECTOR`.
398    pub v: Option<GpuRef<T>>,
399    pub reply: oneshot::Sender<Result<(), GpuError>>,
400}
401
402/// `gesvdjInfo_t` is non-Send by default (raw pointer); we wrap so
403/// the keep-alive can ride the post-launch task.
404struct GesvdjParams(cs::gesvdjInfo_t);
405unsafe impl Send for GesvdjParams {}
406impl Drop for GesvdjParams {
407    fn drop(&mut self) {
408        unsafe {
409            let _ = cs::cusolverDnDestroyGesvdjInfo(self.0);
410        }
411    }
412}
413
414impl<T> SolverDispatch for GesvdjBatchedRequest<T>
415where
416    T: SolverSupported + SolverScalar,
417{
418    fn dispatch(self: Box<Self>, cells: SolverCells<'_>) {
419        let GesvdjBatchedRequest {
420            a,
421            m,
422            n,
423            batch_size,
424            s,
425            u,
426            v,
427            reply,
428        } = *self;
429        run_gesvdj_batched::<T>(cells, a, m, n, batch_size, s, u, v, reply);
430    }
431
432    fn dispatch_mock(self: Box<Self>) {
433        let _ = self.reply.send(Err(GpuError::Unrecoverable(
434            "SolverActor in mock mode".into(),
435        )));
436    }
437}
438
439fn run_gesvdj_batched<T: SolverScalar>(
440    cells: SolverCells<'_>,
441    a: GpuRef<T>,
442    m: i32,
443    n: i32,
444    batch_size: i32,
445    s: GpuRef<T>,
446    u: Option<GpuRef<T>>,
447    v: Option<GpuRef<T>>,
448    reply: oneshot::Sender<Result<(), GpuError>>,
449) {
450    let SolverCells {
451        handle,
452        stream,
453        completion,
454        workspace,
455        ..
456    } = cells;
457
458    let a_slice = match a.access() {
459        Ok(sl) => sl.clone(),
460        Err(e) => {
461            let _ = reply.send(Err(e));
462            return;
463        }
464    };
465    let s_slice = match s.access() {
466        Ok(sl) => sl.clone(),
467        Err(e) => {
468            let _ = reply.send(Err(e));
469            return;
470        }
471    };
472    let mut a_owned = match Arc::try_unwrap(a_slice) {
473        Ok(sl) => sl,
474        Err(_) => {
475            let _ = reply.send(Err(GpuError::Unrecoverable(
476                "GesvdjBatched a has multiple live references".into(),
477            )));
478            return;
479        }
480    };
481    let mut s_owned = match Arc::try_unwrap(s_slice) {
482        Ok(sl) => sl,
483        Err(_) => {
484            let _ = reply.send(Err(GpuError::Unrecoverable(
485                "GesvdjBatched s has multiple live references".into(),
486            )));
487            return;
488        }
489    };
490
491    let mut u_owned = match u.as_ref().map(|g| g.access().map(|sl| sl.clone())) {
492        Some(Ok(sl)) => match Arc::try_unwrap(sl) {
493            Ok(o) => Some(o),
494            Err(_) => {
495                let _ = reply.send(Err(GpuError::Unrecoverable(
496                    "GesvdjBatched u has multiple live references".into(),
497                )));
498                return;
499            }
500        },
501        Some(Err(e)) => {
502            let _ = reply.send(Err(e));
503            return;
504        }
505        None => None,
506    };
507    let mut v_owned = match v.as_ref().map(|g| g.access().map(|sl| sl.clone())) {
508        Some(Ok(sl)) => match Arc::try_unwrap(sl) {
509            Ok(o) => Some(o),
510            Err(_) => {
511                let _ = reply.send(Err(GpuError::Unrecoverable(
512                    "GesvdjBatched v has multiple live references".into(),
513                )));
514                return;
515            }
516        },
517        Some(Err(e)) => {
518            let _ = reply.send(Err(e));
519            return;
520        }
521        None => None,
522    };
523
524    // Create gesvdj params. Defaults are fine for Phase 1.
525    let mut info_handle: cs::gesvdjInfo_t = std::ptr::null_mut();
526    let st = unsafe { cs::cusolverDnCreateGesvdjInfo(&mut info_handle as *mut _) };
527    if let Err(e) = status_to_result(st, "CreateGesvdjInfo") {
528        let _ = reply.send(Err(e));
529        return;
530    }
531    let params = GesvdjParams(info_handle);
532
533    // info array: batch_size + 1 (matches cuSOLVER reference; we
534    // size it to batch_size for the basic check).
535    let info_array = match stream.alloc_zeros::<i32>(batch_size.max(1) as usize) {
536        Ok(b) => b,
537        Err(e) => {
538            let _ = reply.send(Err(GpuError::OutOfMemory(format!(
539                "GesvdjBatched info: {e}"
540            ))));
541            return;
542        }
543    };
544
545    let jobz = if u_owned.is_some() && v_owned.is_some() {
546        cs::cusolverEigMode_t::CUSOLVER_EIG_MODE_VECTOR
547    } else {
548        cs::cusolverEigMode_t::CUSOLVER_EIG_MODE_NOVECTOR
549    };
550
551    // Workspace query.
552    let ldu = m;
553    let ldv = n;
554    let mut lwork = 0i32;
555    {
556        let h = handle.lock();
557        let (a_ptr, _ga) = a_owned.device_ptr(stream);
558        let (s_ptr, _gs) = s_owned.device_ptr(stream);
559        let u_ptr: *const T = match u_owned.as_ref() {
560            Some(o) => {
561                let (p, _g) = o.device_ptr(stream);
562                p as *const T
563            }
564            None => std::ptr::null(),
565        };
566        let v_ptr: *const T = match v_owned.as_ref() {
567            Some(o) => {
568                let (p, _g) = o.device_ptr(stream);
569                p as *const T
570            }
571            None => std::ptr::null(),
572        };
573        let status = unsafe {
574            T::gesvdj_batched_buffer_size(
575                h.0.cu(),
576                jobz,
577                m,
578                n,
579                a_ptr as *const T,
580                m,
581                s_ptr as *const T,
582                u_ptr,
583                ldu,
584                v_ptr,
585                ldv,
586                &mut lwork as *mut _,
587                params.0,
588                batch_size,
589            )
590        };
591        drop((_ga, _gs));
592        if let Err(e) = status_to_result(status, "gesvdjBatched_bufferSize") {
593            let _ = reply.send(Err(e));
594            return;
595        }
596    }
597    if let Err(e) = ensure_workspace_bytes(workspace, stream, lwork_bytes::<T>(lwork)) {
598        let _ = reply.send(Err(e));
599        return;
600    }
601
602    a.record_write(stream);
603    s.record_write(stream);
604    if let Some(g) = &u {
605        g.record_write(stream);
606    }
607    if let Some(g) = &v {
608        g.record_write(stream);
609    }
610
611    let stream_for_check = stream.clone();
612    let workspace_ref: &Mutex<Option<cudarc::driver::CudaSlice<u8>>> = workspace;
613
614    envelope::run_kernel(LIB, stream, completion, (), reply, move || {
615        let h = handle.lock();
616        let mut ws = workspace_ref.lock();
617        let (a_ptr, _g1) = a_owned.device_ptr_mut(&stream_for_check);
618        let (s_ptr, _g2) = s_owned.device_ptr_mut(&stream_for_check);
619        let (u_ptr, _gu_opt): (*mut T, _) = match u_owned.as_mut() {
620            Some(o) => {
621                let (p, g) = o.device_ptr_mut(&stream_for_check);
622                (p as *mut T, Some(g))
623            }
624            None => (std::ptr::null_mut(), None),
625        };
626        let (v_ptr, _gv_opt): (*mut T, _) = match v_owned.as_mut() {
627            Some(o) => {
628                let (p, g) = o.device_ptr_mut(&stream_for_check);
629                (p as *mut T, Some(g))
630            }
631            None => (std::ptr::null_mut(), None),
632        };
633        let ws_slice = ws.as_mut().expect("workspace ensured");
634        let (ws_ptr, _g5) = ws_slice.device_ptr_mut(&stream_for_check);
635        let (info_ptr, _ginfo) = info_array.device_ptr(&stream_for_check);
636        let status = unsafe {
637            T::gesvdj_batched(
638                h.0.cu(),
639                jobz,
640                m,
641                n,
642                a_ptr as *mut T,
643                m,
644                s_ptr as *mut T,
645                u_ptr,
646                ldu,
647                v_ptr,
648                ldv,
649                ws_ptr as *mut T,
650                lwork,
651                info_ptr as *mut i32,
652                params.0,
653                batch_size,
654            )
655        };
656        drop((_g1, _g2, _g5, _ginfo, _gu_opt, _gv_opt));
657        status_to_result(status, "gesvdjBatched")?;
658        check_info_array(
659            &info_array,
660            &stream_for_check,
661            "gesvdjBatched",
662            batch_size.max(0) as usize,
663        )?;
664        Ok((a_owned, s_owned, u_owned, v_owned, info_array, params))
665    });
666}
667
668// =====================================================================
669// Tests
670// =====================================================================
671
672#[cfg(test)]
673mod tests {
674    use super::*;
675
676    #[test]
677    fn batched_request_round_trip() {
678        fn assert_dispatch<R: SolverDispatch>() {}
679        assert_dispatch::<GetrfBatchedRequest<f32>>();
680        assert_dispatch::<GetrfBatchedRequest<f64>>();
681        assert_dispatch::<PotrfBatchedRequest<f32>>();
682        assert_dispatch::<PotrfBatchedRequest<f64>>();
683        assert_dispatch::<GesvdjBatchedRequest<f32>>();
684        assert_dispatch::<GesvdjBatchedRequest<f64>>();
685    }
686}