Skip to main content

atomr_accel_cuda/kernel/sparse/
mod.rs

1//! `SparseActor` — wraps cuSPARSE for the full op surface.
2//!
3//! Phase 4 expands the F-9 CSR-only SpMv/SpMm shape into a full
4//! generic-API surface (CSR/COO/CSC/Blocked-ELL/BSR × f32/f64/f16/bf16
5//! × i32/i64) plus SpGEMM, SpSV, SDDMM, and dense↔sparse conversion.
6//!
7//! ## Mailbox shape
8//!
9//! ```ignore
10//! pub enum SparseMsg {
11//!     Op(Box<dyn SparseDispatch>),    // canonical
12//!     #[deprecated] SpMv { ... },     // legacy CSR-only f32
13//!     #[deprecated] SpMm { ... },
14//! }
15//! ```
16//!
17//! New callers ship the `Op(Box<…>)` variant produced by an
18//! `SpMvRequest::new(…)` / `SpMmRequest::new(…)` / etc. The deprecated
19//! typed variants route through the original F-9 implementation
20//! directly (CSR-only, f32-only) and remain wire-compatible with the
21//! existing `tests/spmv_e2e.rs` end-to-end test.
22//!
23//! The CudaContext / supervision wiring is unchanged from F-9 — the
24//! actor still owns a `cusparseHandle_t` for the lifetime of the
25//! current `ContextActor` generation, panics with `"ContextPoisoned"`
26//! on init failure, and drops to `Mock` mode when no GPU is present.
27
28pub mod convert;
29pub mod descriptor;
30pub mod dispatch_impls;
31pub mod format;
32pub mod sddmm;
33pub mod spgemm;
34pub mod spmm;
35pub mod spmv;
36pub mod spsv;
37
38use std::sync::Arc;
39
40use async_trait::async_trait;
41use atomr_core::actor::{Actor, Context, Props};
42use cudarc::cusparse::sys as cusparse_sys;
43use cudarc::driver::{CudaSlice, DevicePtr, DevicePtrMut};
44use parking_lot::Mutex;
45use tokio::sync::oneshot;
46
47use crate::completion::CompletionStrategy;
48use crate::device::DeviceState;
49use crate::error::GpuError;
50use crate::gpu_ref::GpuRef;
51use crate::kernel::dispatch::{SendSparseHandle, SparseDispatch, SparseDispatchCtx};
52use crate::kernel::envelope;
53use crate::stream::StreamAllocator;
54
// Library tag used in this module's `GpuError::LibraryError { lib, .. }`
// values and passed as the first argument to `envelope::run_kernel`.
const LIB: &str = "cusparse";
56
/// Legacy CSR sparse matrix in device memory — kept for back-compat with
/// callers built against F-9. Prefer [`SparseMatrix`] for new code.
///
/// The legacy handlers in this module hand these buffers to
/// `cusparseCreateCsr` with 32-bit indices, zero-based indexing and
/// `CUDA_R_32F` values.
#[derive(Clone)]
pub struct CsrMatrix {
    /// Row-offset array, passed as the CSR row-offsets pointer.
    /// (cuSPARSE expects `rows + 1` entries — not checked in this file.)
    pub row_offsets: GpuRef<i32>,
    /// Column index of each stored value, passed as the CSR column-indices
    /// pointer.
    pub col_indices: GpuRef<i32>,
    /// Stored (non-zero) values, passed as the CSR values pointer.
    pub values: GpuRef<f32>,
    /// Number of matrix rows as given to `cusparseCreateCsr`.
    pub rows: i64,
    /// Number of matrix columns as given to `cusparseCreateCsr`.
    pub cols: i64,
    /// Number of stored entries as given to `cusparseCreateCsr`.
    pub nnz: i64,
}
68
/// Public messages for [`SparseActor`].
///
/// New code uses the canonical `Op(Box<dyn SparseDispatch>)` payload.
/// The two deprecated typed variants are aliases retained for
/// back-compat with F-9 callers and the existing `spmv_e2e` integration
/// test.
pub enum SparseMsg {
    /// Canonical Phase-4 dispatch — generic over dtype/format/index
    /// type via the boxed [`SparseDispatch`].
    Op(Box<dyn SparseDispatch>),

    /// Legacy f32 CSR sparse-matrix × dense-vector product, executed with
    /// `cusparseSpMV` (non-transposed, default algorithm).
    #[deprecated(
        note = "use SparseMsg::Op(Box::new(SpMvRequest::new(...))) for the dtype-generic path"
    )]
    SpMv {
        /// Sparse operand.
        csr: CsrMatrix,
        /// Dense input vector; described to cuSPARSE with `csr.cols` elements.
        x: GpuRef<f32>,
        /// Dense output vector; described to cuSPARSE with `csr.rows`
        /// elements. The handler needs exclusive ownership of the
        /// underlying buffer and replies with an error otherwise.
        y: GpuRef<f32>,
        /// Host scalar passed to `cusparseSpMV` as `alpha`.
        alpha: f32,
        /// Host scalar passed to `cusparseSpMV` as `beta`.
        beta: f32,
        /// Receives the outcome of the operation.
        reply: oneshot::Sender<Result<(), GpuError>>,
    },

    /// Legacy f32 CSR sparse-matrix × dense-matrix product, executed with
    /// `cusparseSpMM` (both operands non-transposed, default algorithm).
    #[deprecated(
        note = "use SparseMsg::Op(Box::new(SpMmRequest::new(...))) for the dtype-generic path"
    )]
    SpMm {
        /// Sparse operand.
        csr: CsrMatrix,
        /// Dense input matrix; described to cuSPARSE as column-major,
        /// `csr.cols` × `b_cols`, leading dimension `ldb`.
        b: GpuRef<f32>,
        /// Dense output matrix; described to cuSPARSE as column-major,
        /// `csr.rows` × `b_cols`, leading dimension `ldc`. The handler
        /// needs exclusive ownership of the underlying buffer and replies
        /// with an error otherwise.
        c: GpuRef<f32>,
        /// Column count shared by `b` and `c`.
        b_cols: i64,
        /// Leading dimension of `b`.
        ldb: i64,
        /// Leading dimension of `c`.
        ldc: i64,
        /// Host scalar passed to `cusparseSpMM` as `alpha`.
        alpha: f32,
        /// Host scalar passed to `cusparseSpMM` as `beta`.
        beta: f32,
        /// Receives the outcome of the operation.
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
}
107
/// Actor that services [`SparseMsg`] requests.
///
/// Built through [`SparseActor::props`] (owns a cuSPARSE handle) or
/// [`SparseActor::mock_props`] (no handle; requests are rejected).
pub struct SparseActor {
    /// Backend selected at construction time.
    inner: SparseInner,
}
111
/// Backend state behind a [`SparseActor`].
// `state` is stored but not read by the message handler in this file,
// which is what the lint allowance below covers.
#[allow(dead_code)]
enum SparseInner {
    /// Backed by a live cuSPARSE handle.
    Real {
        /// cuSPARSE handle; destroyed in this enum's `Drop` impl.
        handle: Mutex<SendSparseHandle>,
        /// Stream the handle was bound to via `cusparseSetStream`.
        stream: Arc<cudarc::driver::CudaStream>,
        /// Completion strategy forwarded to `envelope::run_kernel`.
        completion: Arc<dyn CompletionStrategy>,
        /// Device state supplied at construction.
        state: Arc<DeviceState>,
        /// On-demand-grown external buffer (in u8). Never shrunk.
        workspace: Mutex<Option<CudaSlice<u8>>>,
    },
    /// No cuSPARSE handle; see `mock_reply` for how messages are answered.
    Mock,
}
124
125impl Drop for SparseInner {
126    fn drop(&mut self) {
127        if let SparseInner::Real { handle, .. } = self {
128            let h = handle.lock();
129            unsafe {
130                let _ = cusparse_sys::cusparseDestroy(h.0);
131            }
132        }
133    }
134}
135
136impl SparseActor {
137    pub fn props(
138        stream: Arc<cudarc::driver::CudaStream>,
139        _allocator: Arc<dyn StreamAllocator>,
140        completion: Arc<dyn CompletionStrategy>,
141        state: Arc<DeviceState>,
142    ) -> Props<Self> {
143        Props::create(move || {
144            let mut h: cusparse_sys::cusparseHandle_t = std::ptr::null_mut();
145            let s = unsafe { cusparse_sys::cusparseCreate(&mut h as *mut _) };
146            if s != cusparse_sys::cusparseStatus_t::CUSPARSE_STATUS_SUCCESS {
147                panic!("ContextPoisoned: cusparseCreate failed: {s:?}");
148            }
149            let s = unsafe { cusparse_sys::cusparseSetStream(h, stream.cu_stream() as *mut _) };
150            if s != cusparse_sys::cusparseStatus_t::CUSPARSE_STATUS_SUCCESS {
151                unsafe {
152                    let _ = cusparse_sys::cusparseDestroy(h);
153                }
154                panic!("ContextPoisoned: cusparseSetStream failed: {s:?}");
155            }
156            SparseActor {
157                inner: SparseInner::Real {
158                    handle: Mutex::new(SendSparseHandle(h)),
159                    stream: stream.clone(),
160                    completion: completion.clone(),
161                    state: state.clone(),
162                    workspace: Mutex::new(None),
163                },
164            }
165        })
166    }
167
168    pub fn mock_props() -> Props<Self> {
169        Props::create(|| SparseActor {
170            inner: SparseInner::Mock,
171        })
172    }
173}
174
175#[async_trait]
176impl Actor for SparseActor {
177    type Msg = SparseMsg;
178
179    async fn handle(&mut self, _ctx: &mut Context<Self>, msg: SparseMsg) {
180        match &self.inner {
181            SparseInner::Mock => mock_reply(msg),
182            SparseInner::Real {
183                handle,
184                stream,
185                completion,
186                workspace,
187                ..
188            } =>
189            {
190                #[allow(deprecated)]
191                match msg {
192                    SparseMsg::Op(op) => {
193                        let ctx = SparseDispatchCtx {
194                            handle,
195                            stream,
196                            completion,
197                            workspace,
198                        };
199                        op.dispatch(&ctx);
200                    }
201                    SparseMsg::SpMv {
202                        csr,
203                        x,
204                        y,
205                        alpha,
206                        beta,
207                        reply,
208                    } => {
209                        handle_spmv(
210                            handle, stream, completion, workspace, csr, x, y, alpha, beta, reply,
211                        );
212                    }
213                    SparseMsg::SpMm {
214                        csr,
215                        b,
216                        c,
217                        b_cols,
218                        ldb,
219                        ldc,
220                        alpha,
221                        beta,
222                        reply,
223                    } => {
224                        handle_spmm(
225                            handle, stream, completion, workspace, csr, b, c, b_cols, ldb, ldc,
226                            alpha, beta, reply,
227                        );
228                    }
229                }
230            }
231        }
232    }
233}
234
235fn mock_reply(msg: SparseMsg) {
236    let err = || GpuError::Unrecoverable("SparseActor in mock mode".into());
237    #[allow(deprecated)]
238    match msg {
239        SparseMsg::Op(op) => {
240            // We can't dispatch without a handle, so surface the error
241            // via the boxed op's own dispatch path. The ctx is unused
242            // for the mock case — but we still need to give the op a
243            // place to put its reply. Drop the box; the reply oneshot
244            // inside is dropped, which the caller observes as a
245            // `RecvError`.
246            //
247            // For symmetry with the typed variants, we surface a typed
248            // error via the dispatch trait's own context — but no ctx
249            // exists here. Drop is the documented mock-mode behaviour.
250            drop(op);
251        }
252        SparseMsg::SpMv { reply, .. } | SparseMsg::SpMm { reply, .. } => {
253            let _ = reply.send(Err(err()));
254        }
255    }
256}
257
258fn ensure_workspace(
259    workspace: &Mutex<Option<CudaSlice<u8>>>,
260    stream: &Arc<cudarc::driver::CudaStream>,
261    needed_bytes: usize,
262) -> Result<(), GpuError> {
263    let mut g = workspace.lock();
264    let cur = g.as_ref().map(|s| s.len()).unwrap_or(0);
265    if cur >= needed_bytes {
266        return Ok(());
267    }
268    *g = Some(stream.alloc_zeros::<u8>(needed_bytes.max(1)).map_err(|e| {
269        GpuError::OutOfMemory(format!("cusparse workspace ({needed_bytes}B): {e}"))
270    })?);
271    Ok(())
272}
273
/// Legacy f32 / CSR / 32-bit-index SpMV path behind the deprecated
/// `SparseMsg::SpMv` variant.
///
/// Steps, in order:
/// 1. Resolve every operand buffer through `GpuRef::access`, and take
///    exclusive ownership of the `y` buffer via `Arc::try_unwrap`.
/// 2. Build the cuSPARSE descriptors (CSR matrix, dense `x`, dense `y`).
/// 3. Query `cusparseSpMV_bufferSize` and grow the shared workspace.
/// 4. Hand a closure that calls `cusparseSpMV` to `envelope::run_kernel`.
///
/// Any failure in steps 1–3 is sent on `reply` (send errors are ignored)
/// and the function returns; descriptors created up to that point are
/// destroyed first. From step 4 on, `reply` is owned by `run_kernel`.
#[allow(clippy::too_many_arguments)]
fn handle_spmv(
    handle: &Mutex<SendSparseHandle>,
    stream: &Arc<cudarc::driver::CudaStream>,
    completion: &Arc<dyn CompletionStrategy>,
    workspace: &Mutex<Option<CudaSlice<u8>>>,
    csr: CsrMatrix,
    x: GpuRef<f32>,
    y: GpuRef<f32>,
    alpha: f32,
    beta: f32,
    reply: oneshot::Sender<Result<(), GpuError>>,
) {
    // --- Step 1: resolve operand buffers --------------------------------
    let row_off = match csr.row_offsets.access() {
        Ok(s) => s.clone(),
        Err(e) => {
            let _ = reply.send(Err(e));
            return;
        }
    };
    let col_idx = match csr.col_indices.access() {
        Ok(s) => s.clone(),
        Err(e) => {
            let _ = reply.send(Err(e));
            return;
        }
    };
    let vals = match csr.values.access() {
        Ok(s) => s.clone(),
        Err(e) => {
            let _ = reply.send(Err(e));
            return;
        }
    };
    let x_slice = match x.access() {
        Ok(s) => s.clone(),
        Err(e) => {
            let _ = reply.send(Err(e));
            return;
        }
    };
    let y_slice = match y.access() {
        Ok(s) => s.clone(),
        Err(e) => {
            let _ = reply.send(Err(e));
            return;
        }
    };
    // The output buffer is written through a mutable device pointer, so it
    // must not be shared with any other holder.
    let mut y_owned = match Arc::try_unwrap(y_slice) {
        Ok(s) => s,
        Err(_) => {
            let _ = reply.send(Err(GpuError::Unrecoverable(
                "SpMv y has multiple live references".into(),
            )));
            return;
        }
    };

    // --- Step 2: build descriptors --------------------------------------
    // The handle lock and the device-pointer guards are held until the
    // buffer-size query below has finished.
    let h = handle.lock();
    let (row_off_ptr, _g0) = row_off.device_ptr(stream);
    let (col_idx_ptr, _g1) = col_idx.device_ptr(stream);
    let (vals_ptr, _g2) = vals.device_ptr(stream);
    let (x_ptr, _g3) = x_slice.device_ptr(stream);
    let (y_ptr, _g4) = y_owned.device_ptr_mut(stream);

    let mut mat_desc: cusparse_sys::cusparseSpMatDescr_t = std::ptr::null_mut();
    let s = unsafe {
        cusparse_sys::cusparseCreateCsr(
            &mut mat_desc as *mut _,
            csr.rows,
            csr.cols,
            csr.nnz,
            row_off_ptr as *mut _,
            col_idx_ptr as *mut _,
            vals_ptr as *mut _,
            cusparse_sys::cusparseIndexType_t::CUSPARSE_INDEX_32I,
            cusparse_sys::cusparseIndexType_t::CUSPARSE_INDEX_32I,
            cusparse_sys::cusparseIndexBase_t::CUSPARSE_INDEX_BASE_ZERO,
            cusparse_sys::cudaDataType::CUDA_R_32F,
        )
    };
    if s != cusparse_sys::cusparseStatus_t::CUSPARSE_STATUS_SUCCESS {
        let _ = reply.send(Err(GpuError::LibraryError {
            lib: LIB,
            msg: format!("CreateCsr: {s:?}"),
        }));
        return;
    }
    // `x` is described with `csr.cols` elements.
    let mut x_desc: cusparse_sys::cusparseDnVecDescr_t = std::ptr::null_mut();
    let s = unsafe {
        cusparse_sys::cusparseCreateDnVec(
            &mut x_desc as *mut _,
            csr.cols,
            x_ptr as *mut _,
            cusparse_sys::cudaDataType::CUDA_R_32F,
        )
    };
    if s != cusparse_sys::cusparseStatus_t::CUSPARSE_STATUS_SUCCESS {
        unsafe {
            let _ = cusparse_sys::cusparseDestroySpMat(mat_desc);
        }
        let _ = reply.send(Err(GpuError::LibraryError {
            lib: LIB,
            msg: format!("CreateDnVec(x): {s:?}"),
        }));
        return;
    }
    // `y` is described with `csr.rows` elements.
    let mut y_desc: cusparse_sys::cusparseDnVecDescr_t = std::ptr::null_mut();
    let s = unsafe {
        cusparse_sys::cusparseCreateDnVec(
            &mut y_desc as *mut _,
            csr.rows,
            y_ptr as *mut _,
            cusparse_sys::cudaDataType::CUDA_R_32F,
        )
    };
    if s != cusparse_sys::cusparseStatus_t::CUSPARSE_STATUS_SUCCESS {
        unsafe {
            let _ = cusparse_sys::cusparseDestroyDnVec(x_desc);
            let _ = cusparse_sys::cusparseDestroySpMat(mat_desc);
        }
        let _ = reply.send(Err(GpuError::LibraryError {
            lib: LIB,
            msg: format!("CreateDnVec(y): {s:?}"),
        }));
        return;
    }

    // --- Step 3: size and provision the workspace -----------------------
    // The scalars are passed to cuSPARSE by host pointer.
    let alpha_h = alpha;
    let beta_h = beta;
    let mut buf_size: usize = 0;
    let s = unsafe {
        cusparse_sys::cusparseSpMV_bufferSize(
            h.0,
            cusparse_sys::cusparseOperation_t::CUSPARSE_OPERATION_NON_TRANSPOSE,
            &alpha_h as *const f32 as *const _,
            mat_desc,
            x_desc,
            &beta_h as *const f32 as *const _,
            y_desc,
            cusparse_sys::cudaDataType::CUDA_R_32F,
            cusparse_sys::cusparseSpMVAlg_t::CUSPARSE_SPMV_ALG_DEFAULT,
            &mut buf_size as *mut _,
        )
    };
    if s != cusparse_sys::cusparseStatus_t::CUSPARSE_STATUS_SUCCESS {
        unsafe {
            let _ = cusparse_sys::cusparseDestroyDnVec(y_desc);
            let _ = cusparse_sys::cusparseDestroyDnVec(x_desc);
            let _ = cusparse_sys::cusparseDestroySpMat(mat_desc);
        }
        let _ = reply.send(Err(GpuError::LibraryError {
            lib: LIB,
            msg: format!("SpMV_bufferSize: {s:?}"),
        }));
        return;
    }
    // Release the pointer guards and the handle lock before touching the
    // workspace; the closure below re-acquires the handle lock itself.
    drop((_g0, _g1, _g2, _g3, _g4));
    drop(h);

    if let Err(e) = ensure_workspace(workspace, stream, buf_size) {
        unsafe {
            let _ = cusparse_sys::cusparseDestroyDnVec(y_desc);
            let _ = cusparse_sys::cusparseDestroyDnVec(x_desc);
            let _ = cusparse_sys::cusparseDestroySpMat(mat_desc);
        }
        let _ = reply.send(Err(e));
        return;
    }

    y.record_write(stream);

    // --- Step 4: submit the kernel --------------------------------------
    let handle_clone = handle;
    let workspace_ref = workspace;
    let stream_for_check = stream.clone();
    // Raw descriptor handles are wrapped so they can be moved into the
    // `move` closure below.
    struct SendDesc<T>(T);
    unsafe impl<T> Send for SendDesc<T> {}
    let mat = SendDesc(mat_desc);
    let xd = SendDesc(x_desc);
    let yd = SendDesc(y_desc);

    envelope::run_kernel(LIB, stream, completion, (), reply, move || {
        let h = handle_clone.lock();
        let mut ws = workspace_ref.lock();
        // The pointer value itself is unused here — `yd` already carries it.
        let (y_ptr, _g) = y_owned.device_ptr_mut(&stream_for_check);
        let _ = y_ptr;
        // Panics if the workspace slot is still `None`; this relies on the
        // earlier `ensure_workspace` call having populated it.
        let ws_slice = ws.as_mut().expect("workspace ensured");
        let (ws_ptr, _gws) = ws_slice.device_ptr_mut(&stream_for_check);
        let s = unsafe {
            cusparse_sys::cusparseSpMV(
                h.0,
                cusparse_sys::cusparseOperation_t::CUSPARSE_OPERATION_NON_TRANSPOSE,
                &alpha_h as *const f32 as *const _,
                mat.0,
                xd.0,
                &beta_h as *const f32 as *const _,
                yd.0,
                cusparse_sys::cudaDataType::CUDA_R_32F,
                cusparse_sys::cusparseSpMVAlg_t::CUSPARSE_SPMV_ALG_DEFAULT,
                ws_ptr as *mut _,
            )
        };
        drop((_g, _gws));
        if s != cusparse_sys::cusparseStatus_t::CUSPARSE_STATUS_SUCCESS {
            // Failure path: destroy the descriptors by hand, since no
            // `DescGuard` has been built yet.
            unsafe {
                let _ = cusparse_sys::cusparseDestroyDnVec(yd.0);
                let _ = cusparse_sys::cusparseDestroyDnVec(xd.0);
                let _ = cusparse_sys::cusparseDestroySpMat(mat.0);
            }
            return Err(GpuError::LibraryError {
                lib: LIB,
                msg: format!("SpMV: {s:?}"),
            });
        }
        // Success path: descriptor destruction is tied to the drop of this
        // guard, which is returned together with the operand buffers.
        // NOTE(review): presumably `run_kernel` keeps the returned tuple
        // alive until the stream work completes — confirm in `envelope`.
        struct DescGuard {
            mat: cusparse_sys::cusparseSpMatDescr_t,
            x: cusparse_sys::cusparseDnVecDescr_t,
            y: cusparse_sys::cusparseDnVecDescr_t,
        }
        impl Drop for DescGuard {
            fn drop(&mut self) {
                unsafe {
                    let _ = cusparse_sys::cusparseDestroyDnVec(self.y);
                    let _ = cusparse_sys::cusparseDestroyDnVec(self.x);
                    let _ = cusparse_sys::cusparseDestroySpMat(self.mat);
                }
            }
        }
        unsafe impl Send for DescGuard {}
        let guard = DescGuard {
            mat: mat.0,
            x: xd.0,
            y: yd.0,
        };
        Ok((y_owned, row_off, col_idx, vals, x_slice, guard))
    });
}
511
/// Legacy f32 / CSR / 32-bit-index SpMM path behind the deprecated
/// `SparseMsg::SpMm` variant.
///
/// Steps, in order:
/// 1. Resolve every operand buffer through `GpuRef::access`, and take
///    exclusive ownership of the `c` buffer via `Arc::try_unwrap`.
/// 2. Build the cuSPARSE descriptors (CSR matrix, dense `b`, dense `c`;
///    both dense matrices are declared column-major).
/// 3. Query `cusparseSpMM_bufferSize` and grow the shared workspace.
/// 4. Hand a closure that calls `cusparseSpMM` to `envelope::run_kernel`.
///
/// Any failure in steps 1–3 is sent on `reply` (send errors are ignored)
/// and the function returns; descriptors created up to that point are
/// destroyed first. From step 4 on, `reply` is owned by `run_kernel`.
#[allow(clippy::too_many_arguments)]
fn handle_spmm(
    handle: &Mutex<SendSparseHandle>,
    stream: &Arc<cudarc::driver::CudaStream>,
    completion: &Arc<dyn CompletionStrategy>,
    workspace: &Mutex<Option<CudaSlice<u8>>>,
    csr: CsrMatrix,
    b: GpuRef<f32>,
    c: GpuRef<f32>,
    b_cols: i64,
    ldb: i64,
    ldc: i64,
    alpha: f32,
    beta: f32,
    reply: oneshot::Sender<Result<(), GpuError>>,
) {
    // --- Step 1: resolve operand buffers --------------------------------
    let row_off = match csr.row_offsets.access() {
        Ok(s) => s.clone(),
        Err(e) => {
            let _ = reply.send(Err(e));
            return;
        }
    };
    let col_idx = match csr.col_indices.access() {
        Ok(s) => s.clone(),
        Err(e) => {
            let _ = reply.send(Err(e));
            return;
        }
    };
    let vals = match csr.values.access() {
        Ok(s) => s.clone(),
        Err(e) => {
            let _ = reply.send(Err(e));
            return;
        }
    };
    let b_slice = match b.access() {
        Ok(s) => s.clone(),
        Err(e) => {
            let _ = reply.send(Err(e));
            return;
        }
    };
    let c_slice = match c.access() {
        Ok(s) => s.clone(),
        Err(e) => {
            let _ = reply.send(Err(e));
            return;
        }
    };
    // The output buffer is written through a mutable device pointer, so it
    // must not be shared with any other holder.
    let mut c_owned = match Arc::try_unwrap(c_slice) {
        Ok(s) => s,
        Err(_) => {
            let _ = reply.send(Err(GpuError::Unrecoverable(
                "SpMm c has multiple live references".into(),
            )));
            return;
        }
    };

    // --- Step 2: build descriptors --------------------------------------
    // The handle lock and the device-pointer guards are held until the
    // buffer-size query below has finished.
    let h = handle.lock();
    let (row_off_ptr, _g0) = row_off.device_ptr(stream);
    let (col_idx_ptr, _g1) = col_idx.device_ptr(stream);
    let (vals_ptr, _g2) = vals.device_ptr(stream);
    let (b_ptr, _g3) = b_slice.device_ptr(stream);
    let (c_ptr, _g4) = c_owned.device_ptr_mut(stream);

    let mut mat_desc: cusparse_sys::cusparseSpMatDescr_t = std::ptr::null_mut();
    let s = unsafe {
        cusparse_sys::cusparseCreateCsr(
            &mut mat_desc as *mut _,
            csr.rows,
            csr.cols,
            csr.nnz,
            row_off_ptr as *mut _,
            col_idx_ptr as *mut _,
            vals_ptr as *mut _,
            cusparse_sys::cusparseIndexType_t::CUSPARSE_INDEX_32I,
            cusparse_sys::cusparseIndexType_t::CUSPARSE_INDEX_32I,
            cusparse_sys::cusparseIndexBase_t::CUSPARSE_INDEX_BASE_ZERO,
            cusparse_sys::cudaDataType::CUDA_R_32F,
        )
    };
    if s != cusparse_sys::cusparseStatus_t::CUSPARSE_STATUS_SUCCESS {
        let _ = reply.send(Err(GpuError::LibraryError {
            lib: LIB,
            msg: format!("CreateCsr: {s:?}"),
        }));
        return;
    }
    // `b` is described as `csr.cols` × `b_cols` with leading dimension `ldb`.
    let mut b_desc: cusparse_sys::cusparseDnMatDescr_t = std::ptr::null_mut();
    let s = unsafe {
        cusparse_sys::cusparseCreateDnMat(
            &mut b_desc as *mut _,
            csr.cols,
            b_cols,
            ldb,
            b_ptr as *mut _,
            cusparse_sys::cudaDataType::CUDA_R_32F,
            cusparse_sys::cusparseOrder_t::CUSPARSE_ORDER_COL,
        )
    };
    if s != cusparse_sys::cusparseStatus_t::CUSPARSE_STATUS_SUCCESS {
        unsafe {
            let _ = cusparse_sys::cusparseDestroySpMat(mat_desc);
        }
        let _ = reply.send(Err(GpuError::LibraryError {
            lib: LIB,
            msg: format!("CreateDnMat(b): {s:?}"),
        }));
        return;
    }
    // `c` is described as `csr.rows` × `b_cols` with leading dimension `ldc`.
    let mut c_desc: cusparse_sys::cusparseDnMatDescr_t = std::ptr::null_mut();
    let s = unsafe {
        cusparse_sys::cusparseCreateDnMat(
            &mut c_desc as *mut _,
            csr.rows,
            b_cols,
            ldc,
            c_ptr as *mut _,
            cusparse_sys::cudaDataType::CUDA_R_32F,
            cusparse_sys::cusparseOrder_t::CUSPARSE_ORDER_COL,
        )
    };
    if s != cusparse_sys::cusparseStatus_t::CUSPARSE_STATUS_SUCCESS {
        unsafe {
            let _ = cusparse_sys::cusparseDestroyDnMat(b_desc);
            let _ = cusparse_sys::cusparseDestroySpMat(mat_desc);
        }
        let _ = reply.send(Err(GpuError::LibraryError {
            lib: LIB,
            msg: format!("CreateDnMat(c): {s:?}"),
        }));
        return;
    }

    // --- Step 3: size and provision the workspace -----------------------
    // The scalars are passed to cuSPARSE by host pointer.
    let alpha_h = alpha;
    let beta_h = beta;
    let mut buf_size: usize = 0;
    let s = unsafe {
        cusparse_sys::cusparseSpMM_bufferSize(
            h.0,
            cusparse_sys::cusparseOperation_t::CUSPARSE_OPERATION_NON_TRANSPOSE,
            cusparse_sys::cusparseOperation_t::CUSPARSE_OPERATION_NON_TRANSPOSE,
            &alpha_h as *const f32 as *const _,
            mat_desc,
            b_desc,
            &beta_h as *const f32 as *const _,
            c_desc,
            cusparse_sys::cudaDataType::CUDA_R_32F,
            cusparse_sys::cusparseSpMMAlg_t::CUSPARSE_SPMM_ALG_DEFAULT,
            &mut buf_size as *mut _,
        )
    };
    if s != cusparse_sys::cusparseStatus_t::CUSPARSE_STATUS_SUCCESS {
        unsafe {
            let _ = cusparse_sys::cusparseDestroyDnMat(c_desc);
            let _ = cusparse_sys::cusparseDestroyDnMat(b_desc);
            let _ = cusparse_sys::cusparseDestroySpMat(mat_desc);
        }
        let _ = reply.send(Err(GpuError::LibraryError {
            lib: LIB,
            msg: format!("SpMM_bufferSize: {s:?}"),
        }));
        return;
    }
    // Release the pointer guards and the handle lock before touching the
    // workspace; the closure below re-acquires the handle lock itself.
    drop((_g0, _g1, _g2, _g3, _g4));
    drop(h);

    if let Err(e) = ensure_workspace(workspace, stream, buf_size) {
        unsafe {
            let _ = cusparse_sys::cusparseDestroyDnMat(c_desc);
            let _ = cusparse_sys::cusparseDestroyDnMat(b_desc);
            let _ = cusparse_sys::cusparseDestroySpMat(mat_desc);
        }
        let _ = reply.send(Err(e));
        return;
    }

    c.record_write(stream);

    // --- Step 4: submit the kernel --------------------------------------
    let handle_clone = handle;
    let workspace_ref = workspace;
    let stream_for_check = stream.clone();
    // Raw descriptor handles are wrapped so they can be moved into the
    // `move` closure below.
    struct SendDesc<T>(T);
    unsafe impl<T> Send for SendDesc<T> {}
    let mat = SendDesc(mat_desc);
    let bd = SendDesc(b_desc);
    let cd = SendDesc(c_desc);

    envelope::run_kernel(LIB, stream, completion, (), reply, move || {
        let h = handle_clone.lock();
        let mut ws = workspace_ref.lock();
        // The pointer value itself is unused here — `cd` already carries it.
        let (_c_ptr, _g) = c_owned.device_ptr_mut(&stream_for_check);
        // Panics if the workspace slot is still `None`; this relies on the
        // earlier `ensure_workspace` call having populated it.
        let ws_slice = ws.as_mut().expect("workspace ensured");
        let (ws_ptr, _gws) = ws_slice.device_ptr_mut(&stream_for_check);
        let s = unsafe {
            cusparse_sys::cusparseSpMM(
                h.0,
                cusparse_sys::cusparseOperation_t::CUSPARSE_OPERATION_NON_TRANSPOSE,
                cusparse_sys::cusparseOperation_t::CUSPARSE_OPERATION_NON_TRANSPOSE,
                &alpha_h as *const f32 as *const _,
                mat.0,
                bd.0,
                &beta_h as *const f32 as *const _,
                cd.0,
                cusparse_sys::cudaDataType::CUDA_R_32F,
                cusparse_sys::cusparseSpMMAlg_t::CUSPARSE_SPMM_ALG_DEFAULT,
                ws_ptr as *mut _,
            )
        };
        drop((_g, _gws));
        if s != cusparse_sys::cusparseStatus_t::CUSPARSE_STATUS_SUCCESS {
            // Failure path: destroy the descriptors by hand, since no
            // `DescGuard` has been built yet.
            unsafe {
                let _ = cusparse_sys::cusparseDestroyDnMat(cd.0);
                let _ = cusparse_sys::cusparseDestroyDnMat(bd.0);
                let _ = cusparse_sys::cusparseDestroySpMat(mat.0);
            }
            return Err(GpuError::LibraryError {
                lib: LIB,
                msg: format!("SpMM: {s:?}"),
            });
        }
        // Success path: descriptor destruction is tied to the drop of this
        // guard, which is returned together with the operand buffers.
        // NOTE(review): presumably `run_kernel` keeps the returned tuple
        // alive until the stream work completes — confirm in `envelope`.
        struct DescGuard {
            mat: cusparse_sys::cusparseSpMatDescr_t,
            b: cusparse_sys::cusparseDnMatDescr_t,
            c: cusparse_sys::cusparseDnMatDescr_t,
        }
        impl Drop for DescGuard {
            fn drop(&mut self) {
                unsafe {
                    let _ = cusparse_sys::cusparseDestroyDnMat(self.c);
                    let _ = cusparse_sys::cusparseDestroyDnMat(self.b);
                    let _ = cusparse_sys::cusparseDestroySpMat(self.mat);
                }
            }
        }
        unsafe impl Send for DescGuard {}
        let guard = DescGuard {
            mat: mat.0,
            b: bd.0,
            c: cd.0,
        };
        Ok((c_owned, row_off, col_idx, vals, b_slice, guard))
    });
}
759
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Compile-time guard for the deprecated typed variants: the F-9 e2e
    /// test still emits `SparseMsg::SpMv`, so both legacy variants must stay
    /// constructible by field name even though new callers go through
    /// `SparseMsg::Op(...)`.
    #[test]
    #[allow(deprecated)]
    fn deprecated_spmv_alias_still_constructs() {
        type Reply = oneshot::Sender<Result<(), GpuError>>;

        // Accepts any builder closure; it only has to type-check.
        fn _type_check<F>(_builder: F) {}

        // A real `GpuRef` needs a `CudaSlice`, which a host-only test cannot
        // mint, so only the enum shape is exercised. The device state is
        // created and read once to keep the host-only path lint-clean.
        let device_state = Arc::new(DeviceState::new(0));
        assert_eq!(device_state.generation(), 0);

        let build_spmv = |csr: CsrMatrix, x: GpuRef<f32>, y: GpuRef<f32>, reply: Reply| {
            SparseMsg::SpMv {
                csr,
                x,
                y,
                alpha: 1.0,
                beta: 0.0,
                reply,
            }
        };
        _type_check(build_spmv);

        let build_spmm = |csr: CsrMatrix, b: GpuRef<f32>, c: GpuRef<f32>, reply: Reply| {
            SparseMsg::SpMm {
                csr,
                b,
                c,
                b_cols: 1,
                ldb: 1,
                ldc: 1,
                alpha: 1.0,
                beta: 0.0,
                reply,
            }
        };
        _type_check(build_spmm);
    }
}