Skip to main content

atomr_accel_cuda/sys/
cusparse.rs

1//! Thin safe-ish wrappers over `cudarc::cusparse::sys` entry points
2//! Phase 4 needs but the safe layer doesn't provide.
3//!
4//! Every wrapper:
5//! 1. Translates a `cusparseStatus_t` into a [`crate::error::GpuError::LibraryError`]
6//!    tagged `"cusparse"`.
7//! 2. Funnels `cusparseDestroy*` calls into a `Drop` so descriptors do
8//!    not leak when an error short-circuits the pipeline.
9//!
10//! Higher-level orchestration (descriptor caches, workspace pooling)
11//! lives in `crate::kernel::sparse`.
12
13use cudarc::cusparse::sys as cs;
14
15use crate::error::GpuError;
16
17const LIB: &str = "cusparse";
18
19/// Convert a cuSPARSE status into a `Result<(), GpuError>`.
20#[inline]
21pub fn ok(status: cs::cusparseStatus_t, what: &'static str) -> Result<(), GpuError> {
22    if status == cs::cusparseStatus_t::CUSPARSE_STATUS_SUCCESS {
23        Ok(())
24    } else {
25        Err(GpuError::LibraryError {
26            lib: LIB,
27            msg: format!("{what}: {status:?}"),
28        })
29    }
30}
31
32/// RAII guard for a `cusparseSpMatDescr_t`.
33pub struct SpMatGuard(pub cs::cusparseSpMatDescr_t);
34unsafe impl Send for SpMatGuard {}
35impl Drop for SpMatGuard {
36    fn drop(&mut self) {
37        unsafe {
38            let _ = cs::cusparseDestroySpMat(self.0);
39        }
40    }
41}
42
43/// RAII guard for a `cusparseDnVecDescr_t`.
44pub struct DnVecGuard(pub cs::cusparseDnVecDescr_t);
45unsafe impl Send for DnVecGuard {}
46impl Drop for DnVecGuard {
47    fn drop(&mut self) {
48        unsafe {
49            let _ = cs::cusparseDestroyDnVec(self.0);
50        }
51    }
52}
53
54/// RAII guard for a `cusparseDnMatDescr_t`.
55pub struct DnMatGuard(pub cs::cusparseDnMatDescr_t);
56unsafe impl Send for DnMatGuard {}
57impl Drop for DnMatGuard {
58    fn drop(&mut self) {
59        unsafe {
60            let _ = cs::cusparseDestroyDnMat(self.0);
61        }
62    }
63}
64
65/// RAII guard for a `cusparseSpGEMMDescr_t`.
66pub struct SpGemmDescGuard(pub cs::cusparseSpGEMMDescr_t);
67unsafe impl Send for SpGemmDescGuard {}
68impl Drop for SpGemmDescGuard {
69    fn drop(&mut self) {
70        unsafe {
71            let _ = cs::cusparseSpGEMM_destroyDescr(self.0);
72        }
73    }
74}
75
76/// RAII guard for a `cusparseSpSVDescr_t`.
77pub struct SpSvDescGuard(pub cs::cusparseSpSVDescr_t);
78unsafe impl Send for SpSvDescGuard {}
79impl Drop for SpSvDescGuard {
80    fn drop(&mut self) {
81        unsafe {
82            let _ = cs::cusparseSpSV_destroyDescr(self.0);
83        }
84    }
85}
86
87/// Algorithm tag to use for `cusparseSpMV`. Public so `kernel/sparse/spmv.rs`
88/// can plumb it through the request struct in a future phase.
89#[derive(Debug, Clone, Copy, PartialEq, Eq)]
90pub enum SpMvAlg {
91    Default,
92    Csr,
93    Coo,
94}
95
96impl SpMvAlg {
97    pub fn raw(self) -> cs::cusparseSpMVAlg_t {
98        match self {
99            SpMvAlg::Default => cs::cusparseSpMVAlg_t::CUSPARSE_SPMV_ALG_DEFAULT,
100            SpMvAlg::Csr => cs::cusparseSpMVAlg_t::CUSPARSE_SPMV_CSR_ALG1,
101            SpMvAlg::Coo => cs::cusparseSpMVAlg_t::CUSPARSE_SPMV_COO_ALG1,
102        }
103    }
104}
105
106/// Algorithm tag for `cusparseSpMM`.
107#[derive(Debug, Clone, Copy, PartialEq, Eq)]
108pub enum SpMmAlg {
109    Default,
110    Csr,
111    BlockedEll,
112}
113
114impl SpMmAlg {
115    pub fn raw(self) -> cs::cusparseSpMMAlg_t {
116        match self {
117            SpMmAlg::Default => cs::cusparseSpMMAlg_t::CUSPARSE_SPMM_ALG_DEFAULT,
118            SpMmAlg::Csr => cs::cusparseSpMMAlg_t::CUSPARSE_SPMM_CSR_ALG2,
119            SpMmAlg::BlockedEll => cs::cusparseSpMMAlg_t::CUSPARSE_SPMM_BLOCKED_ELL_ALG1,
120        }
121    }
122}
123
124/// Algorithm tag for `cusparseSpGEMM`.
125#[derive(Debug, Clone, Copy, PartialEq, Eq)]
126pub enum SpGemmAlg {
127    Default,
128}
129
130impl SpGemmAlg {
131    pub fn raw(self) -> cs::cusparseSpGEMMAlg_t {
132        match self {
133            SpGemmAlg::Default => cs::cusparseSpGEMMAlg_t::CUSPARSE_SPGEMM_DEFAULT,
134        }
135    }
136}
137
138/// Algorithm tag for `cusparseSDDMM`.
139#[derive(Debug, Clone, Copy, PartialEq, Eq)]
140pub enum SddmmAlg {
141    Default,
142}
143
144impl SddmmAlg {
145    pub fn raw(self) -> cs::cusparseSDDMMAlg_t {
146        match self {
147            SddmmAlg::Default => cs::cusparseSDDMMAlg_t::CUSPARSE_SDDMM_ALG_DEFAULT,
148        }
149    }
150}
151
152/// Algorithm tag for `cusparseSpSV`.
153#[derive(Debug, Clone, Copy, PartialEq, Eq)]
154pub enum SpSvAlg {
155    Default,
156}
157
158impl SpSvAlg {
159    pub fn raw(self) -> cs::cusparseSpSVAlg_t {
160        match self {
161            SpSvAlg::Default => cs::cusparseSpSVAlg_t::CUSPARSE_SPSV_ALG_DEFAULT,
162        }
163    }
164}