1use std::ffi::c_int;
21use std::sync::Arc;
22
23use cudarc::cusolver::sys as cs;
24use cudarc::driver::{DevicePtr, DevicePtrMut};
25use parking_lot::Mutex;
26use tokio::sync::oneshot;
27
28use crate::dtype::SolverSupported;
29use crate::error::GpuError;
30use crate::gpu_ref::GpuRef;
31use crate::kernel::envelope;
32use crate::sys::cusolver::{status_to_result, SolverScalar, LIB};
33
34use super::workspace::{check_info_array, ensure_workspace_bytes, lwork_bytes};
35use super::{SolverCells, SolverDispatch, Uplo};
36
37pub struct GetrfBatchedRequest<T: SolverSupported> {
42 pub a: GpuRef<T>,
44 pub n: i32,
46 pub batch_size: i32,
48 pub ipiv: GpuRef<i32>,
50 pub reply: oneshot::Sender<Result<(), GpuError>>,
51}
52
53trait BatchedLu: SolverScalar {
58 unsafe fn getrf_batched(
59 handle: cudarc::cublas::sys::cublasHandle_t,
60 n: c_int,
61 a_array: *const *mut Self,
62 lda: c_int,
63 pivots: *mut c_int,
64 info_array: *mut c_int,
65 batch_size: c_int,
66 ) -> cudarc::cublas::sys::cublasStatus_t;
67}
68
69impl BatchedLu for f32 {
70 unsafe fn getrf_batched(
71 handle: cudarc::cublas::sys::cublasHandle_t,
72 n: c_int,
73 a_array: *const *mut Self,
74 lda: c_int,
75 pivots: *mut c_int,
76 info_array: *mut c_int,
77 batch_size: c_int,
78 ) -> cudarc::cublas::sys::cublasStatus_t {
79 cudarc::cublas::sys::cublasSgetrfBatched(
80 handle, n, a_array, lda, pivots, info_array, batch_size,
81 )
82 }
83}
84
85impl BatchedLu for f64 {
86 unsafe fn getrf_batched(
87 handle: cudarc::cublas::sys::cublasHandle_t,
88 n: c_int,
89 a_array: *const *mut Self,
90 lda: c_int,
91 pivots: *mut c_int,
92 info_array: *mut c_int,
93 batch_size: c_int,
94 ) -> cudarc::cublas::sys::cublasStatus_t {
95 cudarc::cublas::sys::cublasDgetrfBatched(
96 handle, n, a_array, lda, pivots, info_array, batch_size,
97 )
98 }
99}
100
101impl<T> SolverDispatch for GetrfBatchedRequest<T>
102where
103 T: SolverSupported + SolverScalar + BatchedLu,
104{
105 fn dispatch(self: Box<Self>, cells: SolverCells<'_>) {
106 let GetrfBatchedRequest {
107 a,
108 n,
109 batch_size,
110 ipiv,
111 reply,
112 } = *self;
113 run_getrf_batched::<T>(cells, a, n, batch_size, ipiv, reply);
114 }
115
116 fn dispatch_mock(self: Box<Self>) {
117 let _ = self.reply.send(Err(GpuError::Unrecoverable(
118 "SolverActor in mock mode".into(),
119 )));
120 }
121}
122
123fn upload_pointer_table<T>(
128 stream: &Arc<cudarc::driver::CudaStream>,
129 base: *mut T,
130 batch: i32,
131 n: i32,
132) -> Result<cudarc::driver::CudaSlice<u64>, GpuError> {
133 let count = batch.max(0) as usize;
134 let stride_bytes = (n.max(0) as usize) * (n.max(0) as usize) * std::mem::size_of::<T>();
135 let mut ptrs = Vec::with_capacity(count);
136 for i in 0..count {
137 let p = (base as usize).saturating_add(i * stride_bytes);
138 ptrs.push(p as u64);
139 }
140 let mut buf = stream
141 .alloc_zeros::<u64>(count.max(1))
142 .map_err(|e| GpuError::OutOfMemory(format!("ptr table ({count}): {e}")))?;
143 stream
144 .memcpy_htod(&ptrs, &mut buf)
145 .map_err(|e| GpuError::lib(LIB, format!("upload ptr table: {e}")))?;
146 Ok(buf)
147}
148
149fn run_getrf_batched<T: SolverScalar + BatchedLu>(
150 cells: SolverCells<'_>,
151 a: GpuRef<T>,
152 n: i32,
153 batch_size: i32,
154 ipiv: GpuRef<i32>,
155 reply: oneshot::Sender<Result<(), GpuError>>,
156) {
157 let SolverCells {
158 stream, completion, ..
159 } = cells;
160
161 let a_slice = match a.access() {
162 Ok(s) => s.clone(),
163 Err(e) => {
164 let _ = reply.send(Err(e));
165 return;
166 }
167 };
168 let ipiv_slice = match ipiv.access() {
169 Ok(s) => s.clone(),
170 Err(e) => {
171 let _ = reply.send(Err(e));
172 return;
173 }
174 };
175 let mut a_owned = match Arc::try_unwrap(a_slice) {
176 Ok(s) => s,
177 Err(_) => {
178 let _ = reply.send(Err(GpuError::Unrecoverable(
179 "GetrfBatched a has multiple live references".into(),
180 )));
181 return;
182 }
183 };
184 let mut ipiv_owned = match Arc::try_unwrap(ipiv_slice) {
185 Ok(s) => s,
186 Err(_) => {
187 let _ = reply.send(Err(GpuError::Unrecoverable(
188 "GetrfBatched ipiv has multiple live references".into(),
189 )));
190 return;
191 }
192 };
193
194 let blas = match cudarc::cublas::CudaBlas::new(stream.clone()) {
200 Ok(b) => b,
201 Err(e) => {
202 let _ = reply.send(Err(GpuError::lib(LIB, format!("CudaBlas::new: {e}"))));
203 return;
204 }
205 };
206 let blas_handle = *blas.handle();
207
208 let (a_base_ptr, _g_base) = a_owned.device_ptr_mut(stream);
210 let ptr_table = match upload_pointer_table::<T>(stream, a_base_ptr as *mut T, batch_size, n) {
211 Ok(t) => t,
212 Err(e) => {
213 let _ = reply.send(Err(e));
214 return;
215 }
216 };
217 drop(_g_base);
218
219 let info_array = match stream.alloc_zeros::<i32>(batch_size.max(1) as usize) {
221 Ok(b) => b,
222 Err(e) => {
223 let _ = reply.send(Err(GpuError::OutOfMemory(format!(
224 "GetrfBatched info: {e}"
225 ))));
226 return;
227 }
228 };
229
230 a.record_write(stream);
231 ipiv.record_write(stream);
232
233 let stream_for_check = stream.clone();
234 envelope::run_kernel(LIB, stream, completion, (), reply, move || {
235 let (ptrs_dev, _gp) = ptr_table.device_ptr(&stream_for_check);
236 let (ipiv_ptr, _gpiv) = ipiv_owned.device_ptr_mut(&stream_for_check);
237 let (info_ptr, _ginfo) = info_array.device_ptr(&stream_for_check);
238 let status = unsafe {
239 T::getrf_batched(
240 blas_handle,
241 n,
242 ptrs_dev as *const *mut T,
243 n,
244 ipiv_ptr as *mut c_int,
245 info_ptr as *mut c_int,
246 batch_size,
247 )
248 };
249 drop((_gp, _gpiv, _ginfo));
250 if status != cudarc::cublas::sys::cublasStatus_t::CUBLAS_STATUS_SUCCESS {
251 return Err(GpuError::lib(LIB, format!("getrfBatched: {status:?}")));
252 }
253 check_info_array(
254 &info_array,
255 &stream_for_check,
256 "getrfBatched",
257 batch_size.max(0) as usize,
258 )?;
259 Ok((a_owned, ipiv_owned, ptr_table, info_array, blas))
261 });
262}
263
264pub struct PotrfBatchedRequest<T: SolverSupported> {
269 pub a: GpuRef<T>,
270 pub n: i32,
271 pub batch_size: i32,
272 pub uplo: Uplo,
273 pub reply: oneshot::Sender<Result<(), GpuError>>,
274}
275
276impl<T> SolverDispatch for PotrfBatchedRequest<T>
277where
278 T: SolverSupported + SolverScalar,
279{
280 fn dispatch(self: Box<Self>, cells: SolverCells<'_>) {
281 let PotrfBatchedRequest {
282 a,
283 n,
284 batch_size,
285 uplo,
286 reply,
287 } = *self;
288 run_potrf_batched::<T>(cells, a, n, batch_size, uplo, reply);
289 }
290
291 fn dispatch_mock(self: Box<Self>) {
292 let _ = self.reply.send(Err(GpuError::Unrecoverable(
293 "SolverActor in mock mode".into(),
294 )));
295 }
296}
297
298fn run_potrf_batched<T: SolverScalar>(
299 cells: SolverCells<'_>,
300 a: GpuRef<T>,
301 n: i32,
302 batch_size: i32,
303 uplo: Uplo,
304 reply: oneshot::Sender<Result<(), GpuError>>,
305) {
306 let SolverCells {
307 handle,
308 stream,
309 completion,
310 ..
311 } = cells;
312
313 let a_slice = match a.access() {
314 Ok(s) => s.clone(),
315 Err(e) => {
316 let _ = reply.send(Err(e));
317 return;
318 }
319 };
320 let mut a_owned = match Arc::try_unwrap(a_slice) {
321 Ok(s) => s,
322 Err(_) => {
323 let _ = reply.send(Err(GpuError::Unrecoverable(
324 "PotrfBatched a has multiple live references".into(),
325 )));
326 return;
327 }
328 };
329
330 let (a_base_ptr, _g_base) = a_owned.device_ptr_mut(stream);
331 let ptr_table = match upload_pointer_table::<T>(stream, a_base_ptr as *mut T, batch_size, n) {
332 Ok(t) => t,
333 Err(e) => {
334 let _ = reply.send(Err(e));
335 return;
336 }
337 };
338 drop(_g_base);
339
340 let info_array = match stream.alloc_zeros::<i32>(batch_size.max(1) as usize) {
341 Ok(b) => b,
342 Err(e) => {
343 let _ = reply.send(Err(GpuError::OutOfMemory(format!(
344 "PotrfBatched info: {e}"
345 ))));
346 return;
347 }
348 };
349
350 a.record_write(stream);
351 let fill = uplo.as_cusolver_fill();
352 let stream_for_check = stream.clone();
353
354 envelope::run_kernel(LIB, stream, completion, (), reply, move || {
355 let h = handle.lock();
356 let (ptrs_dev, _gp) = ptr_table.device_ptr(&stream_for_check);
357 let (info_ptr, _ginfo) = info_array.device_ptr(&stream_for_check);
358 let status = unsafe {
359 T::potrf_batched(
360 h.0.cu(),
361 fill,
362 n,
363 ptrs_dev as *mut *mut T,
364 n,
365 info_ptr as *mut i32,
366 batch_size,
367 )
368 };
369 drop((_gp, _ginfo));
370 status_to_result(status, "potrfBatched")?;
371 check_info_array(
372 &info_array,
373 &stream_for_check,
374 "potrfBatched",
375 batch_size.max(0) as usize,
376 )?;
377 Ok((a_owned, ptr_table, info_array))
378 });
379}
380
381pub struct GesvdjBatchedRequest<T: SolverSupported> {
386 pub a: GpuRef<T>,
388 pub m: i32,
389 pub n: i32,
390 pub batch_size: i32,
391 pub s: GpuRef<T>,
393 pub u: Option<GpuRef<T>>,
396 pub v: Option<GpuRef<T>>,
399 pub reply: oneshot::Sender<Result<(), GpuError>>,
400}
401
402struct GesvdjParams(cs::gesvdjInfo_t);
405unsafe impl Send for GesvdjParams {}
406impl Drop for GesvdjParams {
407 fn drop(&mut self) {
408 unsafe {
409 let _ = cs::cusolverDnDestroyGesvdjInfo(self.0);
410 }
411 }
412}
413
414impl<T> SolverDispatch for GesvdjBatchedRequest<T>
415where
416 T: SolverSupported + SolverScalar,
417{
418 fn dispatch(self: Box<Self>, cells: SolverCells<'_>) {
419 let GesvdjBatchedRequest {
420 a,
421 m,
422 n,
423 batch_size,
424 s,
425 u,
426 v,
427 reply,
428 } = *self;
429 run_gesvdj_batched::<T>(cells, a, m, n, batch_size, s, u, v, reply);
430 }
431
432 fn dispatch_mock(self: Box<Self>) {
433 let _ = self.reply.send(Err(GpuError::Unrecoverable(
434 "SolverActor in mock mode".into(),
435 )));
436 }
437}
438
439fn run_gesvdj_batched<T: SolverScalar>(
440 cells: SolverCells<'_>,
441 a: GpuRef<T>,
442 m: i32,
443 n: i32,
444 batch_size: i32,
445 s: GpuRef<T>,
446 u: Option<GpuRef<T>>,
447 v: Option<GpuRef<T>>,
448 reply: oneshot::Sender<Result<(), GpuError>>,
449) {
450 let SolverCells {
451 handle,
452 stream,
453 completion,
454 workspace,
455 ..
456 } = cells;
457
458 let a_slice = match a.access() {
459 Ok(sl) => sl.clone(),
460 Err(e) => {
461 let _ = reply.send(Err(e));
462 return;
463 }
464 };
465 let s_slice = match s.access() {
466 Ok(sl) => sl.clone(),
467 Err(e) => {
468 let _ = reply.send(Err(e));
469 return;
470 }
471 };
472 let mut a_owned = match Arc::try_unwrap(a_slice) {
473 Ok(sl) => sl,
474 Err(_) => {
475 let _ = reply.send(Err(GpuError::Unrecoverable(
476 "GesvdjBatched a has multiple live references".into(),
477 )));
478 return;
479 }
480 };
481 let mut s_owned = match Arc::try_unwrap(s_slice) {
482 Ok(sl) => sl,
483 Err(_) => {
484 let _ = reply.send(Err(GpuError::Unrecoverable(
485 "GesvdjBatched s has multiple live references".into(),
486 )));
487 return;
488 }
489 };
490
491 let mut u_owned = match u.as_ref().map(|g| g.access().map(|sl| sl.clone())) {
492 Some(Ok(sl)) => match Arc::try_unwrap(sl) {
493 Ok(o) => Some(o),
494 Err(_) => {
495 let _ = reply.send(Err(GpuError::Unrecoverable(
496 "GesvdjBatched u has multiple live references".into(),
497 )));
498 return;
499 }
500 },
501 Some(Err(e)) => {
502 let _ = reply.send(Err(e));
503 return;
504 }
505 None => None,
506 };
507 let mut v_owned = match v.as_ref().map(|g| g.access().map(|sl| sl.clone())) {
508 Some(Ok(sl)) => match Arc::try_unwrap(sl) {
509 Ok(o) => Some(o),
510 Err(_) => {
511 let _ = reply.send(Err(GpuError::Unrecoverable(
512 "GesvdjBatched v has multiple live references".into(),
513 )));
514 return;
515 }
516 },
517 Some(Err(e)) => {
518 let _ = reply.send(Err(e));
519 return;
520 }
521 None => None,
522 };
523
524 let mut info_handle: cs::gesvdjInfo_t = std::ptr::null_mut();
526 let st = unsafe { cs::cusolverDnCreateGesvdjInfo(&mut info_handle as *mut _) };
527 if let Err(e) = status_to_result(st, "CreateGesvdjInfo") {
528 let _ = reply.send(Err(e));
529 return;
530 }
531 let params = GesvdjParams(info_handle);
532
533 let info_array = match stream.alloc_zeros::<i32>(batch_size.max(1) as usize) {
536 Ok(b) => b,
537 Err(e) => {
538 let _ = reply.send(Err(GpuError::OutOfMemory(format!(
539 "GesvdjBatched info: {e}"
540 ))));
541 return;
542 }
543 };
544
545 let jobz = if u_owned.is_some() && v_owned.is_some() {
546 cs::cusolverEigMode_t::CUSOLVER_EIG_MODE_VECTOR
547 } else {
548 cs::cusolverEigMode_t::CUSOLVER_EIG_MODE_NOVECTOR
549 };
550
551 let ldu = m;
553 let ldv = n;
554 let mut lwork = 0i32;
555 {
556 let h = handle.lock();
557 let (a_ptr, _ga) = a_owned.device_ptr(stream);
558 let (s_ptr, _gs) = s_owned.device_ptr(stream);
559 let u_ptr: *const T = match u_owned.as_ref() {
560 Some(o) => {
561 let (p, _g) = o.device_ptr(stream);
562 p as *const T
563 }
564 None => std::ptr::null(),
565 };
566 let v_ptr: *const T = match v_owned.as_ref() {
567 Some(o) => {
568 let (p, _g) = o.device_ptr(stream);
569 p as *const T
570 }
571 None => std::ptr::null(),
572 };
573 let status = unsafe {
574 T::gesvdj_batched_buffer_size(
575 h.0.cu(),
576 jobz,
577 m,
578 n,
579 a_ptr as *const T,
580 m,
581 s_ptr as *const T,
582 u_ptr,
583 ldu,
584 v_ptr,
585 ldv,
586 &mut lwork as *mut _,
587 params.0,
588 batch_size,
589 )
590 };
591 drop((_ga, _gs));
592 if let Err(e) = status_to_result(status, "gesvdjBatched_bufferSize") {
593 let _ = reply.send(Err(e));
594 return;
595 }
596 }
597 if let Err(e) = ensure_workspace_bytes(workspace, stream, lwork_bytes::<T>(lwork)) {
598 let _ = reply.send(Err(e));
599 return;
600 }
601
602 a.record_write(stream);
603 s.record_write(stream);
604 if let Some(g) = &u {
605 g.record_write(stream);
606 }
607 if let Some(g) = &v {
608 g.record_write(stream);
609 }
610
611 let stream_for_check = stream.clone();
612 let workspace_ref: &Mutex<Option<cudarc::driver::CudaSlice<u8>>> = workspace;
613
614 envelope::run_kernel(LIB, stream, completion, (), reply, move || {
615 let h = handle.lock();
616 let mut ws = workspace_ref.lock();
617 let (a_ptr, _g1) = a_owned.device_ptr_mut(&stream_for_check);
618 let (s_ptr, _g2) = s_owned.device_ptr_mut(&stream_for_check);
619 let (u_ptr, _gu_opt): (*mut T, _) = match u_owned.as_mut() {
620 Some(o) => {
621 let (p, g) = o.device_ptr_mut(&stream_for_check);
622 (p as *mut T, Some(g))
623 }
624 None => (std::ptr::null_mut(), None),
625 };
626 let (v_ptr, _gv_opt): (*mut T, _) = match v_owned.as_mut() {
627 Some(o) => {
628 let (p, g) = o.device_ptr_mut(&stream_for_check);
629 (p as *mut T, Some(g))
630 }
631 None => (std::ptr::null_mut(), None),
632 };
633 let ws_slice = ws.as_mut().expect("workspace ensured");
634 let (ws_ptr, _g5) = ws_slice.device_ptr_mut(&stream_for_check);
635 let (info_ptr, _ginfo) = info_array.device_ptr(&stream_for_check);
636 let status = unsafe {
637 T::gesvdj_batched(
638 h.0.cu(),
639 jobz,
640 m,
641 n,
642 a_ptr as *mut T,
643 m,
644 s_ptr as *mut T,
645 u_ptr,
646 ldu,
647 v_ptr,
648 ldv,
649 ws_ptr as *mut T,
650 lwork,
651 info_ptr as *mut i32,
652 params.0,
653 batch_size,
654 )
655 };
656 drop((_g1, _g2, _g5, _ginfo, _gu_opt, _gv_opt));
657 status_to_result(status, "gesvdjBatched")?;
658 check_info_array(
659 &info_array,
660 &stream_for_check,
661 "gesvdjBatched",
662 batch_size.max(0) as usize,
663 )?;
664 Ok((a_owned, s_owned, u_owned, v_owned, info_array, params))
665 });
666}
667
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time check only: each request type must implement
    // `SolverDispatch` for both supported float widths. Nothing runs on a GPU.
    #[test]
    fn batched_request_round_trip() {
        fn requires_dispatch<R>()
        where
            R: SolverDispatch,
        {
        }

        requires_dispatch::<GetrfBatchedRequest<f32>>();
        requires_dispatch::<GetrfBatchedRequest<f64>>();
        requires_dispatch::<PotrfBatchedRequest<f32>>();
        requires_dispatch::<PotrfBatchedRequest<f64>>();
        requires_dispatch::<GesvdjBatchedRequest<f32>>();
        requires_dispatch::<GesvdjBatchedRequest<f64>>();
    }
}