1use std::any::Any;
35use std::sync::Arc;
36
37use cudarc::driver::{CudaSlice, DeviceRepr, LaunchArgs, PushKernelArg};
38
39use atomr_accel::DType;
40
41use crate::completion::CompletionStrategy;
42use crate::device::DeviceState;
43use crate::dtype::CudaDtype;
44use crate::error::GpuError;
45use crate::gpu_ref::GpuRef;
46
/// A type-erased NVRTC launch request that can be boxed and handed off for dispatch.
///
/// Implementors must be `Send + 'static`, so a boxed request may be moved to
/// another thread before it is dispatched.
pub trait NvrtcLaunchDispatch: Send + 'static {
    /// Static name of the operation this request represents.
    fn op_name(&self) -> &'static str;

    /// Element type associated with the request, or `None` when it reports none.
    fn dtype(&self) -> Option<DType>;

    /// Consumes the boxed request and runs it against the resources in `ctx`.
    ///
    /// Nothing is returned, so any outcome has to be reported by the
    /// implementation itself.
    // NOTE(review): presumably via a reply channel owned by the request — confirm.
    fn dispatch(self: Box<Self>, ctx: &NvrtcDispatchCtx<'_>);
}
77
/// Borrowed resources passed to [`NvrtcLaunchDispatch::dispatch`].
///
/// Every field is a reference, so the context cannot outlive the owner of
/// the underlying `Arc`s.
pub struct NvrtcDispatchCtx<'a> {
    /// CUDA stream made available to the request.
    pub stream: &'a Arc<cudarc::driver::CudaStream>,
    /// Completion strategy made available to the request.
    pub completion: &'a Arc<dyn CompletionStrategy>,
    /// Shared device state made available to the request.
    pub state: &'a Arc<DeviceState>,
}
91
/// Borrowed resources passed to the cuBLAS dispatch traits in this module
/// ([`GemmDispatch`], [`GemmStridedBatchedDispatch`], [`BlasL1Dispatch`],
/// [`BlasL2Dispatch`], [`BlasL3Dispatch`]).
pub struct BlasDispatchCtx<'a> {
    /// cuBLAS handle wrapper.
    pub cublas: &'a Arc<cudarc::cublas::CudaBlas>,
    /// CUDA stream made available to the request.
    pub stream: &'a Arc<cudarc::driver::CudaStream>,
    /// Completion strategy made available to the request.
    pub completion: &'a Arc<dyn CompletionStrategy>,
    /// Shared device state made available to the request.
    pub state: &'a Arc<DeviceState>,
}
99
/// Object-safe view of a device buffer that can be passed as a kernel argument.
///
/// Implemented in this file for [`GpuRef<T>`] where `T: CudaDtype`.
pub trait DevSliceArg: Send + Sync + 'static {
    /// Checks that the argument is usable and returns an opaque, sendable value.
    ///
    /// The [`GpuRef`] implementation returns a boxed clone of the buffer's
    /// `Arc<CudaSlice<T>>`.
    // NOTE(review): presumably callers hold the box to keep the buffer alive
    // for the duration of the launch — confirm.
    ///
    /// # Errors
    /// Returns a [`GpuError`] when the argument cannot be accessed.
    fn validate(&self) -> Result<Box<dyn Any + Send>, GpuError>;

    /// Appends this argument to `builder`.
    ///
    /// The pushed reference borrows from `self` for `'a`, the same lifetime
    /// the builder is parameterised with.
    ///
    /// # Errors
    /// Returns a [`GpuError`] when the argument cannot be accessed.
    fn push<'a>(&'a self, builder: &mut LaunchArgs<'a>) -> Result<(), GpuError>;

    /// Element type of the buffer, or `None` when it reports none.
    fn dtype(&self) -> Option<DType>;

    /// Length of the buffer.
    // NOTE(review): presumably counted in elements rather than bytes — confirm
    // against `GpuRef::len`.
    fn len(&self) -> usize;

    /// `true` when [`len`](Self::len) is zero.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
150
151impl<T> DevSliceArg for GpuRef<T>
152where
153 T: CudaDtype,
154{
155 #[inline]
156 fn validate(&self) -> Result<Box<dyn Any + Send>, GpuError> {
157 let arc: Arc<CudaSlice<T>> = self.access()?.clone();
158 Ok(Box::new(arc))
159 }
160
161 #[inline]
162 fn push<'a>(&'a self, builder: &mut LaunchArgs<'a>) -> Result<(), GpuError> {
163 let arc = self.access()?;
164 builder.arg(&**arc);
167 Ok(())
168 }
169
170 #[inline]
171 fn dtype(&self) -> Option<DType> {
172 Some(<T as atomr_accel::AccelDtype>::KIND)
173 }
174
175 #[inline]
176 fn len(&self) -> usize {
177 GpuRef::<T>::len(self)
178 }
179}
180
/// Object-safe view of a by-value scalar that can be passed as a kernel argument.
///
/// A blanket implementation in this file covers every
/// `T: CudaDtype + DeviceRepr + Sync`.
pub trait ScalarArg: Send + Sync + 'static {
    /// Appends this scalar to `builder`; the pushed reference borrows from
    /// `self` for `'a`.
    fn push<'a>(&'a self, builder: &mut LaunchArgs<'a>);

    /// Element type of the scalar, or `None` when it reports none.
    fn dtype(&self) -> Option<DType>;
}
198
199impl<T> ScalarArg for T
200where
201 T: CudaDtype + DeviceRepr + Sync,
202{
203 #[inline]
204 fn push<'a>(&'a self, builder: &mut LaunchArgs<'a>) {
205 builder.arg(self);
206 }
207
208 #[inline]
209 fn dtype(&self) -> Option<DType> {
210 Some(<T as atomr_accel::AccelDtype>::KIND)
211 }
212}
213
/// Alias of [`BlasDispatchCtx`] under a GEMM-specific name.
pub type GemmDispatchCtx<'a> = BlasDispatchCtx<'a>;
224
/// A type-erased GEMM request dispatched against a [`BlasDispatchCtx`].
pub trait GemmDispatch: Send + 'static {
    /// Static name of the element type the request operates on.
    fn dtype_name(&self) -> &'static str;
    /// Static name of the operation.
    fn op_name(&self) -> &'static str;
    /// Consumes the boxed request and runs it against `ctx`; nothing is returned.
    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>);
}
231
/// A type-erased strided-batched GEMM request dispatched against a
/// [`BlasDispatchCtx`].
pub trait GemmStridedBatchedDispatch: Send + 'static {
    /// Static name of the element type the request operates on.
    fn dtype_name(&self) -> &'static str;
    /// Static name of the operation.
    fn op_name(&self) -> &'static str;
    /// Consumes the boxed request and runs it against `ctx`; nothing is returned.
    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>);
}
238
/// A type-erased request of the `BlasL1` family, dispatched against a
/// [`BlasDispatchCtx`].
// NOTE(review): the name suggests BLAS level-1 routines — confirm against implementors.
pub trait BlasL1Dispatch: Send + 'static {
    /// Static name of the element type the request operates on.
    fn dtype_name(&self) -> &'static str;
    /// Static name of the operation.
    fn op_name(&self) -> &'static str;
    /// Consumes the boxed request and runs it against `ctx`; nothing is returned.
    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>);
}
245
/// A type-erased request of the `BlasL2` family, dispatched against a
/// [`BlasDispatchCtx`].
// NOTE(review): the name suggests BLAS level-2 routines — confirm against implementors.
pub trait BlasL2Dispatch: Send + 'static {
    /// Static name of the element type the request operates on.
    fn dtype_name(&self) -> &'static str;
    /// Static name of the operation.
    fn op_name(&self) -> &'static str;
    /// Consumes the boxed request and runs it against `ctx`; nothing is returned.
    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>);
}
252
/// A type-erased request of the `BlasL3` family, dispatched against a
/// [`BlasDispatchCtx`].
// NOTE(review): the name suggests BLAS level-3 routines — confirm against implementors.
pub trait BlasL3Dispatch: Send + 'static {
    /// Static name of the element type the request operates on.
    fn dtype_name(&self) -> &'static str;
    /// Static name of the operation.
    fn op_name(&self) -> &'static str;
    /// Consumes the boxed request and runs it against `ctx`; nothing is returned.
    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>);
}
259
260#[cfg(feature = "cublaslt")]
261mod blaslt_dispatch_internal {
262 use std::sync::Arc;
265
266 use cudarc::cublaslt::CudaBlasLT;
267 use tokio::sync::oneshot;
268
269 use crate::completion::CompletionStrategy;
270 use crate::error::GpuError;
271 use crate::kernel::blas_lt::heuristic::HeuristicCacheRef;
272 use crate::kernel::blas_lt::workspace::WorkspacePool;
273
    /// Resources passed to `BlasLtDispatch::dispatch`.
    ///
    /// Unlike the other fields, `blas_lt` and `heuristic` are held by value
    /// rather than by reference.
    pub struct BlasLtDispatchCtx<'a> {
        /// Shared cuBLASLt handle wrapper.
        pub blas_lt: Arc<CudaBlasLT>,
        /// CUDA stream made available to the request.
        pub stream: &'a Arc<cudarc::driver::CudaStream>,
        /// Completion strategy made available to the request.
        pub completion: &'a Arc<dyn CompletionStrategy>,
        /// Workspace pool borrowed from the caller.
        pub workspace: &'a WorkspacePool,
        /// Handle to the heuristic cache.
        pub heuristic: HeuristicCacheRef,
        // NOTE(review): presumably the device's SM architecture number —
        // confirm the encoding with the code that fills it in.
        pub sm_arch: u32,
    }
283
284 pub fn reply_unsupported(
285 reply: oneshot::Sender<Result<(), GpuError>>,
286 dtype_name: &'static str,
287 ) {
288 let _ = reply.send(Err(GpuError::Unrecoverable(format!(
289 "BlasLtDispatch: dtype {dtype_name} unsupported in this build"
290 ))));
291 }
292}
293
294#[cfg(feature = "cublaslt")]
295pub use blaslt_dispatch_internal::{reply_unsupported, BlasLtDispatchCtx};
296
/// A type-erased cuBLASLt request dispatched against a [`BlasLtDispatchCtx`].
///
/// Only available with the `cublaslt` feature.
#[cfg(feature = "cublaslt")]
pub trait BlasLtDispatch: Send + 'static {
    /// Element-type kind the request operates on.
    fn dtype_kind(&self) -> crate::dtype::DTypeKind;
    /// Consumes the boxed request and runs it against `ctx`; nothing is returned.
    fn dispatch(self: Box<Self>, ctx: &BlasLtDispatchCtx<'_>);
}
304
305#[cfg(feature = "cudnn")]
308pub use cudnn_dispatch::{CudnnDispatch, CudnnDispatchCtx};
309
310#[cfg(feature = "cudnn")]
311mod cudnn_dispatch {
312 use std::sync::Arc;
313
314 use parking_lot::Mutex;
315
316 use crate::completion::CompletionStrategy;
317
    /// Resources passed to [`CudnnDispatch::dispatch`].
    ///
    /// The handle, stream and completion strategy are owned `Arc`s; the plan
    /// cache and workspace are borrowed behind mutexes.
    pub struct CudnnDispatchCtx<'a> {
        /// Shared cuDNN handle wrapper.
        pub handle: Arc<cudarc::cudnn::Cudnn>,
        /// CUDA stream made available to the request.
        pub stream: Arc<cudarc::driver::CudaStream>,
        /// Completion strategy made available to the request.
        pub completion: Arc<dyn CompletionStrategy>,
        /// Mutex-guarded plan cache.
        pub plan_cache: &'a Mutex<crate::kernel::cudnn::graph::PlanCache>,
        /// Mutex-guarded optional device byte buffer; `None` when no buffer is held.
        pub workspace: &'a Mutex<Option<cudarc::driver::CudaSlice<u8>>>,
    }
326
    /// A type-erased cuDNN request dispatched against a [`CudnnDispatchCtx`].
    pub trait CudnnDispatch: Send + 'static {
        /// Static name of the element type the request operates on.
        fn dtype_name(&self) -> &'static str;
        /// Static name of the kind of operation.
        fn op_kind(&self) -> &'static str;
        /// Consumes the boxed request and runs it against `ctx`; nothing is returned.
        fn dispatch(self: Box<Self>, ctx: &CudnnDispatchCtx<'_>);
    }
333}
334
/// Resources passed to [`FftDispatch::dispatch`].
///
/// Only available with the `cufft` feature.
#[cfg(feature = "cufft")]
pub struct FftDispatchCtx<'a> {
    /// CUDA stream made available to the request.
    pub stream: &'a Arc<cudarc::driver::CudaStream>,
    /// Completion strategy made available to the request.
    pub completion: &'a Arc<dyn CompletionStrategy>,
    /// Type-erased, shared plan object.
    // NOTE(review): presumably downcast by implementors to the concrete plan
    // type selected via `FftDispatch::plan_key` — confirm.
    pub plan: Arc<dyn std::any::Any + Send + Sync>,
}
347
/// A type-erased FFT request dispatched against an [`FftDispatchCtx`].
///
/// Only available with the `cufft` feature.
#[cfg(feature = "cufft")]
pub trait FftDispatch: Send + 'static {
    /// Element type the request operates on.
    fn dtype_kind(&self) -> DType;
    /// Key describing the plan this request needs.
    fn plan_key(&self) -> crate::kernel::fft::PlanKey;
    /// Consumes the boxed request and runs it against `ctx`; nothing is returned.
    fn dispatch(self: Box<Self>, ctx: &FftDispatchCtx<'_>);
}
356
/// A type-erased random-number fill request.
///
/// Unlike the other dispatch traits in this file, the resources are passed
/// as individual parameters rather than through a context struct, and the
/// outcome is returned directly.
pub trait RngDispatch: Send + 'static {
    /// Consumes the boxed request and performs the fill.
    ///
    /// * `generator` — raw cuRAND generator handle.
    /// * `stream` — CUDA stream made available to the request.
    /// * `completion` — completion strategy made available to the request.
    ///
    /// # Errors
    /// Returns a [`GpuError`] when the fill fails.
    fn fill(
        self: Box<Self>,
        generator: cudarc::curand::sys::curandGenerator_t,
        stream: &Arc<cudarc::driver::CudaStream>,
        completion: &Arc<dyn CompletionStrategy>,
    ) -> Result<(), GpuError>;
}
372
373#[cfg(feature = "cusolver")]
376pub use crate::kernel::solver::SolverDispatch;
377
/// Newtype around the raw `cusparseHandle_t` that is manually marked
/// `Send` and `Sync`.
#[cfg(feature = "cusparse")]
pub struct SendSparseHandle(pub cudarc::cusparse::sys::cusparseHandle_t);
// SAFETY: NOTE(review): the soundness argument is not stated in this file.
// Presumably the handle is only ever used while holding the mutex that
// `SparseDispatchCtx::handle` wraps it in — confirm and record the invariant here.
#[cfg(feature = "cusparse")]
unsafe impl Send for SendSparseHandle {}
// SAFETY: NOTE(review): same unstated invariant as the `Send` impl above — confirm.
#[cfg(feature = "cusparse")]
unsafe impl Sync for SendSparseHandle {}
387
/// Borrowed resources passed to [`SparseDispatch::dispatch`].
///
/// Only available with the `cusparse` feature.
#[cfg(feature = "cusparse")]
pub struct SparseDispatchCtx<'a> {
    /// Mutex-guarded raw cuSPARSE handle.
    pub handle: &'a parking_lot::Mutex<SendSparseHandle>,
    /// CUDA stream made available to the request.
    pub stream: &'a Arc<cudarc::driver::CudaStream>,
    /// Completion strategy made available to the request.
    pub completion: &'a Arc<dyn CompletionStrategy>,
    /// Mutex-guarded optional device byte buffer; `None` when no buffer is held.
    pub workspace: &'a parking_lot::Mutex<Option<cudarc::driver::CudaSlice<u8>>>,
}
396
/// Identifies which sparse operation a [`SparseDispatch`] request performs.
///
/// `as_str` maps each variant to a lower-case name.
// NOTE(review): variant names appear to follow cuSPARSE routine families
// (SpMV, SpMM, SpGEMM, SpSV, SDDMM, format conversions) — confirm.
#[cfg(feature = "cusparse")]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SparseOp {
    SpMv,
    SpMm,
    SpGemm,
    SpSv,
    Sddmm,
    DenseToSparse,
    SparseToDense,
    Convert,
}
410
#[cfg(feature = "cusparse")]
impl SparseOp {
    /// Returns the lower-case (snake_case for multi-word variants) name of
    /// the operation.
    pub fn as_str(self) -> &'static str {
        // Exhaustive on purpose: adding a variant must force a name to be chosen.
        match self {
            Self::SpMv => "spmv",
            Self::SpMm => "spmm",
            Self::SpGemm => "spgemm",
            Self::SpSv => "spsv",
            Self::Sddmm => "sddmm",
            Self::DenseToSparse => "dense_to_sparse",
            Self::SparseToDense => "sparse_to_dense",
            Self::Convert => "convert",
        }
    }
}
426
/// A type-erased sparse request dispatched against a [`SparseDispatchCtx`].
///
/// Only available with the `cusparse` feature.
#[cfg(feature = "cusparse")]
pub trait SparseDispatch: Send + 'static {
    /// Which sparse operation the request performs.
    fn op_name(&self) -> SparseOp;
    /// Element type the request operates on.
    fn dtype(&self) -> DType;
    /// Consumes the boxed request and runs it against `ctx`; nothing is returned.
    fn dispatch(self: Box<Self>, ctx: &SparseDispatchCtx<'_>);
}
434
435#[cfg(feature = "cutensor")]
437pub use cutensor_dispatch::{TensorDispatch, TensorDispatchCtx, WorkspacePool};
438
439#[cfg(feature = "cutensor")]
440mod cutensor_dispatch {
441 use std::sync::Arc;
442
443 use parking_lot::Mutex;
444
445 use crate::completion::CompletionStrategy;
446 use crate::error::GpuError;
447 use crate::kernel::tensor::plan_cache::PlanCache;
448 use crate::kernel::tensor::SendHandle;
449
    /// Pool of device byte buffers, at most one per power-of-two size class.
    ///
    /// Buffers are allocated on `stream` by `ensure` and are kept for the
    /// lifetime of the pool; nothing in this type removes a bucket.
    pub struct WorkspacePool {
        /// Stream the buckets are allocated on.
        stream: Arc<cudarc::driver::CudaStream>,
        /// Allocated buckets, guarded by a mutex.
        buckets: Mutex<Vec<Bucket>>,
    }
454
    /// One pooled allocation together with the size it was allocated with.
    struct Bucket {
        /// Size in bytes the bucket was allocated with (a power of two).
        size: usize,
        /// The device byte buffer itself.
        slice: cudarc::driver::CudaSlice<u8>,
    }
459
460 impl WorkspacePool {
461 pub fn new(stream: Arc<cudarc::driver::CudaStream>) -> Self {
462 Self {
463 stream,
464 buckets: Mutex::new(Vec::new()),
465 }
466 }
467
468 pub fn ensure(&self, n: usize) -> Result<usize, GpuError> {
469 if n == 0 {
470 return Ok(0);
471 }
472 let bucket_size = n.next_power_of_two();
473 let mut g = self.buckets.lock();
474 if g.iter().any(|b| b.size == bucket_size) {
475 return Ok(bucket_size);
476 }
477 let slice = self
478 .stream
479 .alloc_zeros::<u8>(bucket_size)
480 .map_err(|e| GpuError::OutOfMemory(format!("cutensor workspace: {e}")))?;
481 g.push(Bucket {
482 size: bucket_size,
483 slice,
484 });
485 Ok(bucket_size)
486 }
487
488 pub fn with_bucket<F, R>(&self, n: usize, f: F) -> Option<R>
489 where
490 F: FnOnce(&mut cudarc::driver::CudaSlice<u8>) -> R,
491 {
492 if n == 0 {
493 return None;
494 }
495 let bucket_size = n.next_power_of_two();
496 let mut g = self.buckets.lock();
497 let b = g.iter_mut().find(|b| b.size == bucket_size)?;
498 Some(f(&mut b.slice))
499 }
500 }
501
    /// Resources passed to [`TensorDispatch::dispatch`].
    ///
    /// Every field is an owned `Arc`, so the context carries no lifetime
    /// parameter.
    pub struct TensorDispatchCtx {
        /// Mutex-guarded handle wrapper.
        pub handle: Arc<Mutex<SendHandle>>,
        /// CUDA stream made available to the request.
        pub stream: Arc<cudarc::driver::CudaStream>,
        /// Completion strategy made available to the request.
        pub completion: Arc<dyn CompletionStrategy>,
        /// Shared plan cache.
        pub plan_cache: Arc<PlanCache>,
        /// Shared workspace pool.
        pub workspace: Arc<WorkspacePool>,
    }
509
    /// A type-erased tensor request dispatched against a [`TensorDispatchCtx`].
    pub trait TensorDispatch: Send + 'static {
        /// Static tag naming the operation.
        fn op_tag(&self) -> &'static str;
        /// Static tag naming the element type.
        fn dtype_tag(&self) -> &'static str;
        /// Consumes the boxed request and runs it against `ctx`; nothing is returned.
        fn dispatch(self: Box<Self>, ctx: &TensorDispatchCtx);
        /// Consumes the boxed request without any context.
        // NOTE(review): presumably fails the request when running against a
        // mock backend — confirm against implementors.
        fn fail_mock(self: Box<Self>);
    }
516}
517
518#[cfg(feature = "nccl")]
521pub use atomr_accel::DType as DispatchDType;
522
/// A type-erased collective request dispatched against a
/// [`CollectiveDispatchCtx`].
///
/// Only available with the `nccl` feature.
#[cfg(feature = "nccl")]
pub trait CollectiveDispatch: Send + 'static {
    /// Element type the request operates on.
    fn dtype_kind(&self) -> DispatchDType;
    /// Device id associated with the request, or `None` when it reports none.
    fn device_id(&self) -> Option<u32>;
    /// Consumes the boxed request and runs it against `ctx`; nothing is returned.
    fn dispatch(self: Box<Self>, ctx: &CollectiveDispatchCtx<'_>);
}
534
/// Borrowed resources passed to [`CollectiveDispatch::dispatch`].
///
/// Only available with the `nccl` feature.
#[cfg(feature = "nccl")]
pub struct CollectiveDispatchCtx<'a> {
    /// NCCL communicator.
    pub comm: &'a cudarc::nccl::Comm,
    /// Shared device state made available to the request.
    pub state: &'a Arc<DeviceState>,
    /// Completion strategy made available to the request.
    pub completion: &'a Arc<dyn CompletionStrategy>,
}
544
545#[cfg(test)]
546mod tests {
547 use super::*;
548
    /// Minimal [`NvrtcLaunchDispatch`] implementor used by the tests below.
    struct DummyNvrtc {
        /// Value returned from `op_name`.
        op: &'static str,
        /// Value returned from `dtype`.
        d: Option<DType>,
        /// Set to `true` by `dispatch`.
        called: std::sync::atomic::AtomicBool,
    }
556
557 impl NvrtcLaunchDispatch for DummyNvrtc {
558 fn op_name(&self) -> &'static str {
559 self.op
560 }
561 fn dtype(&self) -> Option<DType> {
562 self.d
563 }
564 fn dispatch(self: Box<Self>, _ctx: &NvrtcDispatchCtx<'_>) {
565 self.called.store(true, std::sync::atomic::Ordering::SeqCst);
566 }
567 }
568
569 #[test]
570 fn nvrtc_dispatch_box_round_trip() {
571 let req = DummyNvrtc {
572 op: "relu",
573 d: Some(DType::F32),
574 called: std::sync::atomic::AtomicBool::new(false),
575 };
576 let boxed: Box<dyn NvrtcLaunchDispatch> = Box::new(req);
578 assert_eq!(boxed.op_name(), "relu");
579 assert_eq!(boxed.dtype(), Some(DType::F32));
580
581 let req2 = DummyNvrtc {
588 op: "noop",
589 d: None,
590 called: std::sync::atomic::AtomicBool::new(false),
591 };
592 assert_eq!(req2.op_name(), "noop");
593 assert_eq!(req2.dtype(), None);
594 }
595
    /// Compile-time check that `GpuRef<T>` coerces to `Box<dyn DevSliceArg>`
    /// for each listed element type. Nothing is constructed at runtime: every
    /// statement only binds a function pointer.
    // NOTE(review): this function is called from `dev_slice_arg_for_gpu_ref`,
    // so the `allow(dead_code)` looks redundant — confirm before removing.
    #[allow(dead_code)]
    fn _assert_dev_slice_arg_object_safe() {
        fn takes_box(_: Box<dyn DevSliceArg>) {}
        // Each closure type-checks only if the unsizing coercion is valid.
        let _: fn(GpuRef<f32>) -> Box<dyn DevSliceArg> = |g| Box::new(g);
        let _: fn(GpuRef<f64>) -> Box<dyn DevSliceArg> = |g| Box::new(g);
        let _: fn(GpuRef<u8>) -> Box<dyn DevSliceArg> = |g| Box::new(g);
        let _: fn(GpuRef<i32>) -> Box<dyn DevSliceArg> = |g| Box::new(g);
        // Mentioned so the helper does not count as unused.
        let _ = takes_box;
        #[cfg(feature = "f16")]
        {
            let _: fn(GpuRef<half::f16>) -> Box<dyn DevSliceArg> = |g| Box::new(g);
            let _: fn(GpuRef<half::bf16>) -> Box<dyn DevSliceArg> = |g| Box::new(g);
        }
    }
615
    /// Runs the compile-time coercion check; the call has no runtime assertions.
    #[test]
    fn dev_slice_arg_for_gpu_ref() {
        _assert_dev_slice_arg_object_safe();
    }
622
623 #[test]
626 fn scalar_arg_blanket_impls_compile() {
627 fn takes(_: Box<dyn ScalarArg>) {}
628 takes(Box::new(1.0f32));
629 takes(Box::new(2.0f64));
630 takes(Box::new(3i32));
631 takes(Box::new(4u32));
632 takes(Box::new(5u64));
633 #[cfg(feature = "f16")]
634 {
635 takes(Box::new(half::f16::ONE));
636 takes(Box::new(half::bf16::ONE));
637 }
638 }
639
    /// Compile-time check that each dispatch trait can be used as
    /// `Box<dyn Trait>`. The body only declares functions taking such a box;
    /// none of them is called.
    #[test]
    fn stub_dispatch_traits_compile() {
        fn _gemm(_: Box<dyn GemmDispatch>) {}
        #[cfg(feature = "cublaslt")]
        fn _blaslt(_: Box<dyn BlasLtDispatch>) {}
        #[cfg(feature = "cudnn")]
        fn _cudnn(_: Box<dyn CudnnDispatch>) {}
        #[cfg(feature = "cufft")]
        fn _fft(_: Box<dyn FftDispatch>) {}
        fn _rng(_: Box<dyn RngDispatch>) {}
        #[cfg(feature = "cusolver")]
        fn _solver(_: Box<dyn crate::kernel::solver::SolverDispatch>) {}
        #[cfg(feature = "cusparse")]
        fn _sparse(_: Box<dyn SparseDispatch>) {}
        #[cfg(feature = "cutensor")]
        fn _tensor(_: Box<dyn TensorDispatch>) {}
        #[cfg(feature = "nccl")]
        fn _coll(_: Box<dyn CollectiveDispatch>) {}
    }
661}