atomr_accel_cuda/graph/record/
cusparse.rs1#![cfg(feature = "cusparse")]
9
10use crate::error::GpuError;
11use crate::gpu_ref::GpuRef;
12use crate::graph::{GraphOpRecord, GraphRecordCtx};
13use crate::kernel::CsrMatrix;
14
15pub struct SpMvOp {
17 pub csr: CsrMatrix,
18 pub x: GpuRef<f32>,
19 pub y: GpuRef<f32>,
20 pub alpha: f32,
21 pub beta: f32,
22}
23
24pub struct SpMmOp {
26 pub csr: CsrMatrix,
27 pub b: GpuRef<f32>,
28 pub c: GpuRef<f32>,
29 pub b_cols: i64,
30 pub ldb: i64,
31 pub ldc: i64,
32 pub alpha: f32,
33 pub beta: f32,
34}
35
36impl GraphOpRecord for SpMvOp {
37 fn record(&self, ctx: &GraphRecordCtx<'_>) -> Result<(), GpuError> {
38 validate_csr(&self.csr)?;
39 let _ = self.x.access()?;
40 let _ = self.y.access()?;
41 let _ = ctx;
42 Err(GpuError::Unrecoverable(
48 "graph::record::cusparse::SpMv: cuSPARSE capture-mode \
49 entry not yet wired (Phase 4 will revisit when the \
50 actor surface expands)"
51 .into(),
52 ))
53 }
54}
55
56impl GraphOpRecord for SpMmOp {
57 fn record(&self, ctx: &GraphRecordCtx<'_>) -> Result<(), GpuError> {
58 validate_csr(&self.csr)?;
59 let _ = self.b.access()?;
60 let _ = self.c.access()?;
61 if self.b_cols <= 0 || self.ldb <= 0 || self.ldc <= 0 {
62 return Err(GpuError::Unrecoverable(format!(
63 "SpMm: non-positive (b_cols, ldb, ldc) = ({}, {}, {})",
64 self.b_cols, self.ldb, self.ldc
65 )));
66 }
67 let _ = ctx;
68 Err(GpuError::Unrecoverable(
69 "graph::record::cusparse::SpMm: cuSPARSE capture-mode \
70 entry not yet wired"
71 .into(),
72 ))
73 }
74}
75
76fn validate_csr(c: &CsrMatrix) -> Result<(), GpuError> {
77 if c.rows <= 0 || c.cols <= 0 || c.nnz < 0 {
78 return Err(GpuError::Unrecoverable(format!(
79 "CsrMatrix: invalid dims (rows={}, cols={}, nnz={})",
80 c.rows, c.cols, c.nnz
81 )));
82 }
83 let _ = c.row_offsets.access()?;
84 let _ = c.col_indices.access()?;
85 let _ = c.values.access()?;
86 Ok(())
87}
88
89#[cfg(test)]
90mod tests {
91
92 #[test]
93 fn spmv_op_records() {
94 struct Dims {
101 rows: i64,
102 cols: i64,
103 nnz: i64,
104 }
105 let bad = Dims {
106 rows: 4,
107 cols: 4,
108 nnz: -1,
109 };
110 assert!(bad.nnz < 0);
113
114 let good = Dims {
115 rows: 4,
116 cols: 4,
117 nnz: 0,
118 };
119 assert!(good.nnz >= 0);
120 }
121
122 #[test]
123 fn spmm_dim_validation_rejects_zero() {
124 let bad: (i64, i64, i64) = (0, 1, 1);
127 assert!(bad.0 <= 0 || bad.1 <= 0 || bad.2 <= 0);
128 let good: (i64, i64, i64) = (4, 4, 4);
129 assert!(good.0 > 0 && good.1 > 0 && good.2 > 0);
130 }
131}