1pub mod convert;
29pub mod descriptor;
30pub mod dispatch_impls;
31pub mod format;
32pub mod sddmm;
33pub mod spgemm;
34pub mod spmm;
35pub mod spmv;
36pub mod spsv;
37
38use std::sync::Arc;
39
40use async_trait::async_trait;
41use atomr_core::actor::{Actor, Context, Props};
42use cudarc::cusparse::sys as cusparse_sys;
43use cudarc::driver::{CudaSlice, DevicePtr, DevicePtrMut};
44use parking_lot::Mutex;
45use tokio::sync::oneshot;
46
47use crate::completion::CompletionStrategy;
48use crate::device::DeviceState;
49use crate::error::GpuError;
50use crate::gpu_ref::GpuRef;
51use crate::kernel::dispatch::{SendSparseHandle, SparseDispatch, SparseDispatchCtx};
52use crate::kernel::envelope;
53use crate::stream::StreamAllocator;
54
/// Library tag used for every `GpuError::LibraryError` raised in this module and
/// passed as the first argument to `envelope::run_kernel` by the legacy handlers.
const LIB: &str = "cusparse";
56
/// A sparse matrix in CSR (compressed sparse row) form whose arrays live in
/// device buffers.
///
/// The legacy handlers in this module pass these buffers to `cusparseCreateCsr`
/// with 32-bit indices (`CUSPARSE_INDEX_32I`), zero-based indexing
/// (`CUSPARSE_INDEX_BASE_ZERO`) and `f32` values (`CUDA_R_32F`).
#[derive(Clone)]
pub struct CsrMatrix {
    /// Row offsets array; CSR expects `rows + 1` entries.
    pub row_offsets: GpuRef<i32>,
    /// Column index of each stored entry (`nnz` entries).
    pub col_indices: GpuRef<i32>,
    /// Value of each stored entry (`nnz` entries).
    pub values: GpuRef<f32>,
    /// Number of rows of the matrix.
    pub rows: i64,
    /// Number of columns of the matrix.
    pub cols: i64,
    /// Number of stored entries.
    pub nnz: i64,
}
68
/// Messages accepted by `SparseActor`.
pub enum SparseMsg {
    /// A dtype-generic sparse operation.
    ///
    /// A real actor calls `SparseDispatch::dispatch` on it with a
    /// `SparseDispatchCtx` borrowing the actor's handle, stream, completion
    /// strategy and workspace. A mock actor drops it without sending anything.
    Op(Box<dyn SparseDispatch>),

    /// Legacy f32 CSR sparse matrix–vector product
    /// `y = alpha * A * x + beta * y`, with `A = csr` used non-transposed.
    /// The outcome is sent on `reply`.
    #[deprecated(
        note = "use SparseMsg::Op(Box::new(SpMvRequest::new(...))) for the dtype-generic path"
    )]
    SpMv {
        /// The sparse matrix `A`.
        csr: CsrMatrix,
        /// Dense input vector (described to cuSPARSE with length `csr.cols`).
        x: GpuRef<f32>,
        /// Dense output vector (described to cuSPARSE with length `csr.rows`).
        y: GpuRef<f32>,
        /// Scale applied to `A * x`.
        alpha: f32,
        /// Scale applied to the previous contents of `y`.
        beta: f32,
        /// Receives `Ok(())` or the first error encountered.
        reply: oneshot::Sender<Result<(), GpuError>>,
    },

    /// Legacy f32 CSR × dense product `c = alpha * A * b + beta * c`, with
    /// `A = csr` used non-transposed. Both dense matrices are column-major.
    /// The outcome is sent on `reply`.
    #[deprecated(
        note = "use SparseMsg::Op(Box::new(SpMmRequest::new(...))) for the dtype-generic path"
    )]
    SpMm {
        /// The sparse matrix `A`.
        csr: CsrMatrix,
        /// Dense input, `csr.cols × b_cols`, column-major.
        b: GpuRef<f32>,
        /// Dense output, `csr.rows × b_cols`, column-major.
        c: GpuRef<f32>,
        /// Number of columns of `b` and `c`.
        b_cols: i64,
        /// Leading dimension of `b`.
        ldb: i64,
        /// Leading dimension of `c`.
        ldc: i64,
        /// Scale applied to `A * b`.
        alpha: f32,
        /// Scale applied to the previous contents of `c`.
        beta: f32,
        /// Receives `Ok(())` or the first error encountered.
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
}
107
/// Actor that executes `SparseMsg` requests, either against a cuSPARSE handle
/// bound to one CUDA stream (`SparseActor::props`) or in mock mode
/// (`SparseActor::mock_props`).
pub struct SparseActor {
    inner: SparseInner,
}
111
/// Backing state of a `SparseActor`.
// NOTE(review): `dead_code` is presumably allowed for `state`, which nothing in
// this file reads — confirm before removing the attribute.
#[allow(dead_code)]
enum SparseInner {
    /// Live cuSPARSE backend.
    Real {
        /// cuSPARSE handle, bound to `stream` at construction and destroyed in `Drop`.
        handle: Mutex<SendSparseHandle>,
        /// Stream on which all cuSPARSE work of this actor is enqueued.
        stream: Arc<cudarc::driver::CudaStream>,
        /// Completion strategy handed to `envelope::run_kernel` and to `SparseDispatchCtx`.
        completion: Arc<dyn CompletionStrategy>,
        /// Device state captured at construction; not read in this file.
        state: Arc<DeviceState>,
        /// Shared cuSPARSE scratch buffer, grown on demand by `ensure_workspace`.
        workspace: Mutex<Option<CudaSlice<u8>>>,
    },
    /// No GPU backend: every message is routed to `mock_reply`.
    Mock,
}
124
125impl Drop for SparseInner {
126 fn drop(&mut self) {
127 if let SparseInner::Real { handle, .. } = self {
128 let h = handle.lock();
129 unsafe {
130 let _ = cusparse_sys::cusparseDestroy(h.0);
131 }
132 }
133 }
134}
135
136impl SparseActor {
137 pub fn props(
138 stream: Arc<cudarc::driver::CudaStream>,
139 _allocator: Arc<dyn StreamAllocator>,
140 completion: Arc<dyn CompletionStrategy>,
141 state: Arc<DeviceState>,
142 ) -> Props<Self> {
143 Props::create(move || {
144 let mut h: cusparse_sys::cusparseHandle_t = std::ptr::null_mut();
145 let s = unsafe { cusparse_sys::cusparseCreate(&mut h as *mut _) };
146 if s != cusparse_sys::cusparseStatus_t::CUSPARSE_STATUS_SUCCESS {
147 panic!("ContextPoisoned: cusparseCreate failed: {s:?}");
148 }
149 let s = unsafe { cusparse_sys::cusparseSetStream(h, stream.cu_stream() as *mut _) };
150 if s != cusparse_sys::cusparseStatus_t::CUSPARSE_STATUS_SUCCESS {
151 unsafe {
152 let _ = cusparse_sys::cusparseDestroy(h);
153 }
154 panic!("ContextPoisoned: cusparseSetStream failed: {s:?}");
155 }
156 SparseActor {
157 inner: SparseInner::Real {
158 handle: Mutex::new(SendSparseHandle(h)),
159 stream: stream.clone(),
160 completion: completion.clone(),
161 state: state.clone(),
162 workspace: Mutex::new(None),
163 },
164 }
165 })
166 }
167
168 pub fn mock_props() -> Props<Self> {
169 Props::create(|| SparseActor {
170 inner: SparseInner::Mock,
171 })
172 }
173}
174
175#[async_trait]
176impl Actor for SparseActor {
177 type Msg = SparseMsg;
178
179 async fn handle(&mut self, _ctx: &mut Context<Self>, msg: SparseMsg) {
180 match &self.inner {
181 SparseInner::Mock => mock_reply(msg),
182 SparseInner::Real {
183 handle,
184 stream,
185 completion,
186 workspace,
187 ..
188 } =>
189 {
190 #[allow(deprecated)]
191 match msg {
192 SparseMsg::Op(op) => {
193 let ctx = SparseDispatchCtx {
194 handle,
195 stream,
196 completion,
197 workspace,
198 };
199 op.dispatch(&ctx);
200 }
201 SparseMsg::SpMv {
202 csr,
203 x,
204 y,
205 alpha,
206 beta,
207 reply,
208 } => {
209 handle_spmv(
210 handle, stream, completion, workspace, csr, x, y, alpha, beta, reply,
211 );
212 }
213 SparseMsg::SpMm {
214 csr,
215 b,
216 c,
217 b_cols,
218 ldb,
219 ldc,
220 alpha,
221 beta,
222 reply,
223 } => {
224 handle_spmm(
225 handle, stream, completion, workspace, csr, b, c, b_cols, ldb, ldc,
226 alpha, beta, reply,
227 );
228 }
229 }
230 }
231 }
232 }
233}
234
235fn mock_reply(msg: SparseMsg) {
236 let err = || GpuError::Unrecoverable("SparseActor in mock mode".into());
237 #[allow(deprecated)]
238 match msg {
239 SparseMsg::Op(op) => {
240 drop(op);
251 }
252 SparseMsg::SpMv { reply, .. } | SparseMsg::SpMm { reply, .. } => {
253 let _ = reply.send(Err(err()));
254 }
255 }
256}
257
258fn ensure_workspace(
259 workspace: &Mutex<Option<CudaSlice<u8>>>,
260 stream: &Arc<cudarc::driver::CudaStream>,
261 needed_bytes: usize,
262) -> Result<(), GpuError> {
263 let mut g = workspace.lock();
264 let cur = g.as_ref().map(|s| s.len()).unwrap_or(0);
265 if cur >= needed_bytes {
266 return Ok(());
267 }
268 *g = Some(stream.alloc_zeros::<u8>(needed_bytes.max(1)).map_err(|e| {
269 GpuError::OutOfMemory(format!("cusparse workspace ({needed_bytes}B): {e}"))
270 })?);
271 Ok(())
272}
273
274#[allow(clippy::too_many_arguments)]
275fn handle_spmv(
276 handle: &Mutex<SendSparseHandle>,
277 stream: &Arc<cudarc::driver::CudaStream>,
278 completion: &Arc<dyn CompletionStrategy>,
279 workspace: &Mutex<Option<CudaSlice<u8>>>,
280 csr: CsrMatrix,
281 x: GpuRef<f32>,
282 y: GpuRef<f32>,
283 alpha: f32,
284 beta: f32,
285 reply: oneshot::Sender<Result<(), GpuError>>,
286) {
287 let row_off = match csr.row_offsets.access() {
288 Ok(s) => s.clone(),
289 Err(e) => {
290 let _ = reply.send(Err(e));
291 return;
292 }
293 };
294 let col_idx = match csr.col_indices.access() {
295 Ok(s) => s.clone(),
296 Err(e) => {
297 let _ = reply.send(Err(e));
298 return;
299 }
300 };
301 let vals = match csr.values.access() {
302 Ok(s) => s.clone(),
303 Err(e) => {
304 let _ = reply.send(Err(e));
305 return;
306 }
307 };
308 let x_slice = match x.access() {
309 Ok(s) => s.clone(),
310 Err(e) => {
311 let _ = reply.send(Err(e));
312 return;
313 }
314 };
315 let y_slice = match y.access() {
316 Ok(s) => s.clone(),
317 Err(e) => {
318 let _ = reply.send(Err(e));
319 return;
320 }
321 };
322 let mut y_owned = match Arc::try_unwrap(y_slice) {
323 Ok(s) => s,
324 Err(_) => {
325 let _ = reply.send(Err(GpuError::Unrecoverable(
326 "SpMv y has multiple live references".into(),
327 )));
328 return;
329 }
330 };
331
332 let h = handle.lock();
333 let (row_off_ptr, _g0) = row_off.device_ptr(stream);
334 let (col_idx_ptr, _g1) = col_idx.device_ptr(stream);
335 let (vals_ptr, _g2) = vals.device_ptr(stream);
336 let (x_ptr, _g3) = x_slice.device_ptr(stream);
337 let (y_ptr, _g4) = y_owned.device_ptr_mut(stream);
338
339 let mut mat_desc: cusparse_sys::cusparseSpMatDescr_t = std::ptr::null_mut();
340 let s = unsafe {
341 cusparse_sys::cusparseCreateCsr(
342 &mut mat_desc as *mut _,
343 csr.rows,
344 csr.cols,
345 csr.nnz,
346 row_off_ptr as *mut _,
347 col_idx_ptr as *mut _,
348 vals_ptr as *mut _,
349 cusparse_sys::cusparseIndexType_t::CUSPARSE_INDEX_32I,
350 cusparse_sys::cusparseIndexType_t::CUSPARSE_INDEX_32I,
351 cusparse_sys::cusparseIndexBase_t::CUSPARSE_INDEX_BASE_ZERO,
352 cusparse_sys::cudaDataType::CUDA_R_32F,
353 )
354 };
355 if s != cusparse_sys::cusparseStatus_t::CUSPARSE_STATUS_SUCCESS {
356 let _ = reply.send(Err(GpuError::LibraryError {
357 lib: LIB,
358 msg: format!("CreateCsr: {s:?}"),
359 }));
360 return;
361 }
362 let mut x_desc: cusparse_sys::cusparseDnVecDescr_t = std::ptr::null_mut();
363 let s = unsafe {
364 cusparse_sys::cusparseCreateDnVec(
365 &mut x_desc as *mut _,
366 csr.cols,
367 x_ptr as *mut _,
368 cusparse_sys::cudaDataType::CUDA_R_32F,
369 )
370 };
371 if s != cusparse_sys::cusparseStatus_t::CUSPARSE_STATUS_SUCCESS {
372 unsafe {
373 let _ = cusparse_sys::cusparseDestroySpMat(mat_desc);
374 }
375 let _ = reply.send(Err(GpuError::LibraryError {
376 lib: LIB,
377 msg: format!("CreateDnVec(x): {s:?}"),
378 }));
379 return;
380 }
381 let mut y_desc: cusparse_sys::cusparseDnVecDescr_t = std::ptr::null_mut();
382 let s = unsafe {
383 cusparse_sys::cusparseCreateDnVec(
384 &mut y_desc as *mut _,
385 csr.rows,
386 y_ptr as *mut _,
387 cusparse_sys::cudaDataType::CUDA_R_32F,
388 )
389 };
390 if s != cusparse_sys::cusparseStatus_t::CUSPARSE_STATUS_SUCCESS {
391 unsafe {
392 let _ = cusparse_sys::cusparseDestroyDnVec(x_desc);
393 let _ = cusparse_sys::cusparseDestroySpMat(mat_desc);
394 }
395 let _ = reply.send(Err(GpuError::LibraryError {
396 lib: LIB,
397 msg: format!("CreateDnVec(y): {s:?}"),
398 }));
399 return;
400 }
401
402 let alpha_h = alpha;
403 let beta_h = beta;
404 let mut buf_size: usize = 0;
405 let s = unsafe {
406 cusparse_sys::cusparseSpMV_bufferSize(
407 h.0,
408 cusparse_sys::cusparseOperation_t::CUSPARSE_OPERATION_NON_TRANSPOSE,
409 &alpha_h as *const f32 as *const _,
410 mat_desc,
411 x_desc,
412 &beta_h as *const f32 as *const _,
413 y_desc,
414 cusparse_sys::cudaDataType::CUDA_R_32F,
415 cusparse_sys::cusparseSpMVAlg_t::CUSPARSE_SPMV_ALG_DEFAULT,
416 &mut buf_size as *mut _,
417 )
418 };
419 if s != cusparse_sys::cusparseStatus_t::CUSPARSE_STATUS_SUCCESS {
420 unsafe {
421 let _ = cusparse_sys::cusparseDestroyDnVec(y_desc);
422 let _ = cusparse_sys::cusparseDestroyDnVec(x_desc);
423 let _ = cusparse_sys::cusparseDestroySpMat(mat_desc);
424 }
425 let _ = reply.send(Err(GpuError::LibraryError {
426 lib: LIB,
427 msg: format!("SpMV_bufferSize: {s:?}"),
428 }));
429 return;
430 }
431 drop((_g0, _g1, _g2, _g3, _g4));
432 drop(h);
433
434 if let Err(e) = ensure_workspace(workspace, stream, buf_size) {
435 unsafe {
436 let _ = cusparse_sys::cusparseDestroyDnVec(y_desc);
437 let _ = cusparse_sys::cusparseDestroyDnVec(x_desc);
438 let _ = cusparse_sys::cusparseDestroySpMat(mat_desc);
439 }
440 let _ = reply.send(Err(e));
441 return;
442 }
443
444 y.record_write(stream);
445
446 let handle_clone = handle;
447 let workspace_ref = workspace;
448 let stream_for_check = stream.clone();
449 struct SendDesc<T>(T);
450 unsafe impl<T> Send for SendDesc<T> {}
451 let mat = SendDesc(mat_desc);
452 let xd = SendDesc(x_desc);
453 let yd = SendDesc(y_desc);
454
455 envelope::run_kernel(LIB, stream, completion, (), reply, move || {
456 let h = handle_clone.lock();
457 let mut ws = workspace_ref.lock();
458 let (y_ptr, _g) = y_owned.device_ptr_mut(&stream_for_check);
459 let _ = y_ptr;
460 let ws_slice = ws.as_mut().expect("workspace ensured");
461 let (ws_ptr, _gws) = ws_slice.device_ptr_mut(&stream_for_check);
462 let s = unsafe {
463 cusparse_sys::cusparseSpMV(
464 h.0,
465 cusparse_sys::cusparseOperation_t::CUSPARSE_OPERATION_NON_TRANSPOSE,
466 &alpha_h as *const f32 as *const _,
467 mat.0,
468 xd.0,
469 &beta_h as *const f32 as *const _,
470 yd.0,
471 cusparse_sys::cudaDataType::CUDA_R_32F,
472 cusparse_sys::cusparseSpMVAlg_t::CUSPARSE_SPMV_ALG_DEFAULT,
473 ws_ptr as *mut _,
474 )
475 };
476 drop((_g, _gws));
477 if s != cusparse_sys::cusparseStatus_t::CUSPARSE_STATUS_SUCCESS {
478 unsafe {
479 let _ = cusparse_sys::cusparseDestroyDnVec(yd.0);
480 let _ = cusparse_sys::cusparseDestroyDnVec(xd.0);
481 let _ = cusparse_sys::cusparseDestroySpMat(mat.0);
482 }
483 return Err(GpuError::LibraryError {
484 lib: LIB,
485 msg: format!("SpMV: {s:?}"),
486 });
487 }
488 struct DescGuard {
489 mat: cusparse_sys::cusparseSpMatDescr_t,
490 x: cusparse_sys::cusparseDnVecDescr_t,
491 y: cusparse_sys::cusparseDnVecDescr_t,
492 }
493 impl Drop for DescGuard {
494 fn drop(&mut self) {
495 unsafe {
496 let _ = cusparse_sys::cusparseDestroyDnVec(self.y);
497 let _ = cusparse_sys::cusparseDestroyDnVec(self.x);
498 let _ = cusparse_sys::cusparseDestroySpMat(self.mat);
499 }
500 }
501 }
502 unsafe impl Send for DescGuard {}
503 let guard = DescGuard {
504 mat: mat.0,
505 x: xd.0,
506 y: yd.0,
507 };
508 Ok((y_owned, row_off, col_idx, vals, x_slice, guard))
509 });
510}
511
512#[allow(clippy::too_many_arguments)]
513fn handle_spmm(
514 handle: &Mutex<SendSparseHandle>,
515 stream: &Arc<cudarc::driver::CudaStream>,
516 completion: &Arc<dyn CompletionStrategy>,
517 workspace: &Mutex<Option<CudaSlice<u8>>>,
518 csr: CsrMatrix,
519 b: GpuRef<f32>,
520 c: GpuRef<f32>,
521 b_cols: i64,
522 ldb: i64,
523 ldc: i64,
524 alpha: f32,
525 beta: f32,
526 reply: oneshot::Sender<Result<(), GpuError>>,
527) {
528 let row_off = match csr.row_offsets.access() {
529 Ok(s) => s.clone(),
530 Err(e) => {
531 let _ = reply.send(Err(e));
532 return;
533 }
534 };
535 let col_idx = match csr.col_indices.access() {
536 Ok(s) => s.clone(),
537 Err(e) => {
538 let _ = reply.send(Err(e));
539 return;
540 }
541 };
542 let vals = match csr.values.access() {
543 Ok(s) => s.clone(),
544 Err(e) => {
545 let _ = reply.send(Err(e));
546 return;
547 }
548 };
549 let b_slice = match b.access() {
550 Ok(s) => s.clone(),
551 Err(e) => {
552 let _ = reply.send(Err(e));
553 return;
554 }
555 };
556 let c_slice = match c.access() {
557 Ok(s) => s.clone(),
558 Err(e) => {
559 let _ = reply.send(Err(e));
560 return;
561 }
562 };
563 let mut c_owned = match Arc::try_unwrap(c_slice) {
564 Ok(s) => s,
565 Err(_) => {
566 let _ = reply.send(Err(GpuError::Unrecoverable(
567 "SpMm c has multiple live references".into(),
568 )));
569 return;
570 }
571 };
572
573 let h = handle.lock();
574 let (row_off_ptr, _g0) = row_off.device_ptr(stream);
575 let (col_idx_ptr, _g1) = col_idx.device_ptr(stream);
576 let (vals_ptr, _g2) = vals.device_ptr(stream);
577 let (b_ptr, _g3) = b_slice.device_ptr(stream);
578 let (c_ptr, _g4) = c_owned.device_ptr_mut(stream);
579
580 let mut mat_desc: cusparse_sys::cusparseSpMatDescr_t = std::ptr::null_mut();
581 let s = unsafe {
582 cusparse_sys::cusparseCreateCsr(
583 &mut mat_desc as *mut _,
584 csr.rows,
585 csr.cols,
586 csr.nnz,
587 row_off_ptr as *mut _,
588 col_idx_ptr as *mut _,
589 vals_ptr as *mut _,
590 cusparse_sys::cusparseIndexType_t::CUSPARSE_INDEX_32I,
591 cusparse_sys::cusparseIndexType_t::CUSPARSE_INDEX_32I,
592 cusparse_sys::cusparseIndexBase_t::CUSPARSE_INDEX_BASE_ZERO,
593 cusparse_sys::cudaDataType::CUDA_R_32F,
594 )
595 };
596 if s != cusparse_sys::cusparseStatus_t::CUSPARSE_STATUS_SUCCESS {
597 let _ = reply.send(Err(GpuError::LibraryError {
598 lib: LIB,
599 msg: format!("CreateCsr: {s:?}"),
600 }));
601 return;
602 }
603 let mut b_desc: cusparse_sys::cusparseDnMatDescr_t = std::ptr::null_mut();
604 let s = unsafe {
605 cusparse_sys::cusparseCreateDnMat(
606 &mut b_desc as *mut _,
607 csr.cols,
608 b_cols,
609 ldb,
610 b_ptr as *mut _,
611 cusparse_sys::cudaDataType::CUDA_R_32F,
612 cusparse_sys::cusparseOrder_t::CUSPARSE_ORDER_COL,
613 )
614 };
615 if s != cusparse_sys::cusparseStatus_t::CUSPARSE_STATUS_SUCCESS {
616 unsafe {
617 let _ = cusparse_sys::cusparseDestroySpMat(mat_desc);
618 }
619 let _ = reply.send(Err(GpuError::LibraryError {
620 lib: LIB,
621 msg: format!("CreateDnMat(b): {s:?}"),
622 }));
623 return;
624 }
625 let mut c_desc: cusparse_sys::cusparseDnMatDescr_t = std::ptr::null_mut();
626 let s = unsafe {
627 cusparse_sys::cusparseCreateDnMat(
628 &mut c_desc as *mut _,
629 csr.rows,
630 b_cols,
631 ldc,
632 c_ptr as *mut _,
633 cusparse_sys::cudaDataType::CUDA_R_32F,
634 cusparse_sys::cusparseOrder_t::CUSPARSE_ORDER_COL,
635 )
636 };
637 if s != cusparse_sys::cusparseStatus_t::CUSPARSE_STATUS_SUCCESS {
638 unsafe {
639 let _ = cusparse_sys::cusparseDestroyDnMat(b_desc);
640 let _ = cusparse_sys::cusparseDestroySpMat(mat_desc);
641 }
642 let _ = reply.send(Err(GpuError::LibraryError {
643 lib: LIB,
644 msg: format!("CreateDnMat(c): {s:?}"),
645 }));
646 return;
647 }
648
649 let alpha_h = alpha;
650 let beta_h = beta;
651 let mut buf_size: usize = 0;
652 let s = unsafe {
653 cusparse_sys::cusparseSpMM_bufferSize(
654 h.0,
655 cusparse_sys::cusparseOperation_t::CUSPARSE_OPERATION_NON_TRANSPOSE,
656 cusparse_sys::cusparseOperation_t::CUSPARSE_OPERATION_NON_TRANSPOSE,
657 &alpha_h as *const f32 as *const _,
658 mat_desc,
659 b_desc,
660 &beta_h as *const f32 as *const _,
661 c_desc,
662 cusparse_sys::cudaDataType::CUDA_R_32F,
663 cusparse_sys::cusparseSpMMAlg_t::CUSPARSE_SPMM_ALG_DEFAULT,
664 &mut buf_size as *mut _,
665 )
666 };
667 if s != cusparse_sys::cusparseStatus_t::CUSPARSE_STATUS_SUCCESS {
668 unsafe {
669 let _ = cusparse_sys::cusparseDestroyDnMat(c_desc);
670 let _ = cusparse_sys::cusparseDestroyDnMat(b_desc);
671 let _ = cusparse_sys::cusparseDestroySpMat(mat_desc);
672 }
673 let _ = reply.send(Err(GpuError::LibraryError {
674 lib: LIB,
675 msg: format!("SpMM_bufferSize: {s:?}"),
676 }));
677 return;
678 }
679 drop((_g0, _g1, _g2, _g3, _g4));
680 drop(h);
681
682 if let Err(e) = ensure_workspace(workspace, stream, buf_size) {
683 unsafe {
684 let _ = cusparse_sys::cusparseDestroyDnMat(c_desc);
685 let _ = cusparse_sys::cusparseDestroyDnMat(b_desc);
686 let _ = cusparse_sys::cusparseDestroySpMat(mat_desc);
687 }
688 let _ = reply.send(Err(e));
689 return;
690 }
691
692 c.record_write(stream);
693
694 let handle_clone = handle;
695 let workspace_ref = workspace;
696 let stream_for_check = stream.clone();
697 struct SendDesc<T>(T);
698 unsafe impl<T> Send for SendDesc<T> {}
699 let mat = SendDesc(mat_desc);
700 let bd = SendDesc(b_desc);
701 let cd = SendDesc(c_desc);
702
703 envelope::run_kernel(LIB, stream, completion, (), reply, move || {
704 let h = handle_clone.lock();
705 let mut ws = workspace_ref.lock();
706 let (_c_ptr, _g) = c_owned.device_ptr_mut(&stream_for_check);
707 let ws_slice = ws.as_mut().expect("workspace ensured");
708 let (ws_ptr, _gws) = ws_slice.device_ptr_mut(&stream_for_check);
709 let s = unsafe {
710 cusparse_sys::cusparseSpMM(
711 h.0,
712 cusparse_sys::cusparseOperation_t::CUSPARSE_OPERATION_NON_TRANSPOSE,
713 cusparse_sys::cusparseOperation_t::CUSPARSE_OPERATION_NON_TRANSPOSE,
714 &alpha_h as *const f32 as *const _,
715 mat.0,
716 bd.0,
717 &beta_h as *const f32 as *const _,
718 cd.0,
719 cusparse_sys::cudaDataType::CUDA_R_32F,
720 cusparse_sys::cusparseSpMMAlg_t::CUSPARSE_SPMM_ALG_DEFAULT,
721 ws_ptr as *mut _,
722 )
723 };
724 drop((_g, _gws));
725 if s != cusparse_sys::cusparseStatus_t::CUSPARSE_STATUS_SUCCESS {
726 unsafe {
727 let _ = cusparse_sys::cusparseDestroyDnMat(cd.0);
728 let _ = cusparse_sys::cusparseDestroyDnMat(bd.0);
729 let _ = cusparse_sys::cusparseDestroySpMat(mat.0);
730 }
731 return Err(GpuError::LibraryError {
732 lib: LIB,
733 msg: format!("SpMM: {s:?}"),
734 });
735 }
736 struct DescGuard {
737 mat: cusparse_sys::cusparseSpMatDescr_t,
738 b: cusparse_sys::cusparseDnMatDescr_t,
739 c: cusparse_sys::cusparseDnMatDescr_t,
740 }
741 impl Drop for DescGuard {
742 fn drop(&mut self) {
743 unsafe {
744 let _ = cusparse_sys::cusparseDestroyDnMat(self.c);
745 let _ = cusparse_sys::cusparseDestroyDnMat(self.b);
746 let _ = cusparse_sys::cusparseDestroySpMat(self.mat);
747 }
748 }
749 }
750 unsafe impl Send for DescGuard {}
751 let guard = DescGuard {
752 mat: mat.0,
753 b: bd.0,
754 c: cd.0,
755 };
756 Ok((c_owned, row_off, col_idx, vals, b_slice, guard))
757 });
758}
759
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// The deprecated `SpMv` / `SpMm` variants must keep their field layout so
    /// existing callers still compile. The closures below are never invoked:
    /// type-checking their bodies is what this test verifies.
    #[test]
    #[allow(deprecated)]
    fn deprecated_spmv_alias_still_constructs() {
        let state = Arc::new(DeviceState::new(0));
        assert_eq!(state.generation(), 0);

        let _build_spmv = |csr: CsrMatrix,
                           x: GpuRef<f32>,
                           y: GpuRef<f32>,
                           reply: oneshot::Sender<Result<(), GpuError>>| {
            SparseMsg::SpMv {
                csr,
                x,
                y,
                alpha: 1.0,
                beta: 0.0,
                reply,
            }
        };

        let _build_spmm = |csr: CsrMatrix,
                           b: GpuRef<f32>,
                           c: GpuRef<f32>,
                           reply: oneshot::Sender<Result<(), GpuError>>| {
            SparseMsg::SpMm {
                csr,
                b,
                c,
                b_cols: 1,
                ldb: 1,
                ldc: 1,
                alpha: 1.0,
                beta: 0.0,
                reply,
            }
        };
    }
}