Skip to main content

atomr_accel_cuda/kernel/
dispatch.rs

1//! Per-actor `*Dispatch` traits + their `*DispatchCtx` bundles
2//! (Phase 0.3).
3//!
4//! Each kernel actor (cuBLAS, cuBLASLt, cuDNN, cuFFT, cuRAND,
5//! cuSOLVER, cuSPARSE, cuTENSOR, NCCL, NVRTC) eventually exposes a
6//! typed public API like:
7//!
8//! ```ignore
9//! blas_actor.tell(BlasMsg::gemm::<f16>(GemmRequest { ... }));
10//! ```
11//!
12//! Internally, the request is boxed as `Box<dyn GemmDispatch>` and
13//! the actor's handle loop calls `dispatch(self, &ctx)` which carries
14//! the cuBLAS handle, stream, and completion strategy. This avoids
15//! the N-fold (op × dtype) variant explosion in the actor's `Msg`
16//! enum without giving up typed `GpuRef<T>` requests on the public
17//! API.
18//!
19//! ## Status
20//!
//! Phase 0.3 introduced the pattern with **NVRTC** as the first adopter:
//! see [`NvrtcLaunchDispatch`] + [`NvrtcDispatchCtx`] and the migrated
//! [`NvrtcActor`](super::NvrtcActor). The remaining actor traits
//! ([`GemmDispatch`], [`BlasLtDispatch`], …) started out as stub trait
//! declarations; each actor's `*DispatchCtx<'a>` bundle gains its real
//! handle fields (library handle, stream, completion strategy, …) as that
//! actor migrates in a follow-up PR (per the migration order in the
//! Phase 0 plan).
28//!
29//! The shared kernel-arg traits ([`DevSliceArg`], [`ScalarArg`])
30//! collapse the per-dtype `KernelArg::DevSlice*` / `KernelArg::Scalar*`
31//! variants into a single boxed-dyn pair plus a literal `Usize(usize)`
32//! variant.
33
34use std::any::Any;
35use std::sync::Arc;
36
37use cudarc::driver::{CudaSlice, DeviceRepr, LaunchArgs, PushKernelArg};
38
39use atomr_accel::DType;
40
41use crate::completion::CompletionStrategy;
42use crate::device::DeviceState;
43use crate::dtype::CudaDtype;
44use crate::error::GpuError;
45use crate::gpu_ref::GpuRef;
46
47// ---------------------------------------------------------------------------
48// NVRTC — the only adopter in Phase 0.3.
49// ---------------------------------------------------------------------------
50
51/// Boxed-dispatch trait for an NVRTC kernel launch request.
52///
53/// `NvrtcMsg::Launch` carries an `args: Vec<KernelArg>`; each arg
54/// implements either [`DevSliceArg`] (typed device buffers) or
55/// [`ScalarArg`] (typed host scalars). The `Launch` variant itself
56/// does **not** need a `Box<dyn NvrtcLaunchDispatch>` payload because
57/// `KernelHandle` already carries the typed `CudaFunction` —
58/// [`NvrtcLaunchDispatch`] is a marker for the *request as a whole*
59/// so cross-actor tooling (NVTX naming, `KernelTrace`, future graph
60/// recording) sees a uniform interface across all actors.
61pub trait NvrtcLaunchDispatch: Send + 'static {
62    /// Static op identifier surfaced to NVTX / `KernelTrace`. NVRTC
63    /// kernels are user-supplied so this returns `"nvrtc_launch"` by
64    /// default; callers may override with the kernel name.
65    fn op_name(&self) -> &'static str;
66
67    /// Element dtype, when the request has a single-dtype identity.
68    /// `None` for traceless / multi-dtype actors (NVRTC is multi-dtype
69    /// because user kernels can mix arg types — implementors return
70    /// `None`).
71    fn dtype(&self) -> Option<DType>;
72
73    /// Run the dispatch: validate inputs, enqueue the kernel, deliver
74    /// the reply via the completion strategy in `ctx`.
75    fn dispatch(self: Box<Self>, ctx: &NvrtcDispatchCtx<'_>);
76}
77
78/// Per-launch context bundle for [`NvrtcLaunchDispatch::dispatch`].
79///
80/// Pulled by reference from the `NvrtcActor` `Real { ... }` variant.
81/// Trait-object implementations consume this when actually wiring up
82/// the cudarc `launch_builder().arg(..)` chain — Phase 0.3's NVRTC
83/// migration uses the per-arg [`DevSliceArg`] / [`ScalarArg`] traits
84/// rather than a request-level dispatcher, so the only reason this
85/// type exists today is API symmetry with the other actors below.
86pub struct NvrtcDispatchCtx<'a> {
87    pub stream: &'a Arc<cudarc::driver::CudaStream>,
88    pub completion: &'a Arc<dyn CompletionStrategy>,
89    pub state: &'a Arc<DeviceState>,
90}
91
92/// Bundle of resources every cuBLAS dispatcher needs to run an op.
93pub struct BlasDispatchCtx<'a> {
94    pub cublas: &'a Arc<cudarc::cublas::CudaBlas>,
95    pub stream: &'a Arc<cudarc::driver::CudaStream>,
96    pub completion: &'a Arc<dyn CompletionStrategy>,
97    pub state: &'a Arc<DeviceState>,
98}
99
100// ---------------------------------------------------------------------------
101// Shared kernel-arg traits (NVRTC adopts these in Phase 0.3).
102// ---------------------------------------------------------------------------
103
104/// Type-erased typed-device-slice argument for an NVRTC kernel
105/// launch.
106///
107/// Boxed as `Box<dyn DevSliceArg>` inside `KernelArg::DevSlice` so the
108/// per-launch `Vec<KernelArg>` does not need one variant per dtype.
109/// The blanket impl `impl<T: CudaDtype> DevSliceArg for GpuRef<T>`
110/// covers every dtype the runtime understands; raw byte buffers
111/// (`GpuRef<u8>`) are handled by the same impl because `u8` is a
112/// `CudaDtype`.
113///
114/// # Object safety
115///
116/// Methods take `&self` and never return `Self` so the trait is
117/// usable as `dyn DevSliceArg`.
118pub trait DevSliceArg: Send + Sync + 'static {
119    /// Validate the underlying [`GpuRef`] and return a keep-alive
120    /// owner. The caller stores this `Box<dyn Any + Send>` in a `Vec`
121    /// to keep the device buffer alive until the kernel completes.
122    fn validate(&self) -> Result<Box<dyn Any + Send>, GpuError>;
123
124    /// Push the device-pointer reference onto `builder`. Implementors
125    /// re-`access()` the `GpuRef` (cheap — pointer-equality check
126    /// against `DeviceState.generation`) and call
127    /// [`PushKernelArg::arg`] with `&CudaSlice<T>`.
128    ///
129    /// The `'a` lifetime ties `&self` to the builder so the pushed
130    /// device-pointer reference is borrowed from `Self` for as long as
131    /// `builder` lives.
132    ///
133    /// Returns `Err(GpuError::GpuRefStale)` if the buffer has gone
134    /// stale between `validate` and `push` (rare — only happens if a
135    /// context rebuild raced inside the actor).
136    fn push<'a>(&'a self, builder: &mut LaunchArgs<'a>) -> Result<(), GpuError>;
137
138    /// Element dtype for tracing / debugging. Always `Some(..)` for
139    /// the default `GpuRef<T: CudaDtype>` impl.
140    fn dtype(&self) -> Option<DType>;
141
142    /// Length of the underlying slice in elements.
143    fn len(&self) -> usize;
144
145    /// True iff the slice has zero elements.
146    fn is_empty(&self) -> bool {
147        self.len() == 0
148    }
149}
150
151impl<T> DevSliceArg for GpuRef<T>
152where
153    T: CudaDtype,
154{
155    #[inline]
156    fn validate(&self) -> Result<Box<dyn Any + Send>, GpuError> {
157        let arc: Arc<CudaSlice<T>> = self.access()?.clone();
158        Ok(Box::new(arc))
159    }
160
161    #[inline]
162    fn push<'a>(&'a self, builder: &mut LaunchArgs<'a>) -> Result<(), GpuError> {
163        let arc = self.access()?;
164        // `arc` is `&Arc<CudaSlice<T>>`; `&**arc` is `&CudaSlice<T>`,
165        // which is exactly what `LaunchArgs::arg` accepts.
166        builder.arg(&**arc);
167        Ok(())
168    }
169
170    #[inline]
171    fn dtype(&self) -> Option<DType> {
172        Some(<T as atomr_accel::AccelDtype>::KIND)
173    }
174
175    #[inline]
176    fn len(&self) -> usize {
177        GpuRef::<T>::len(self)
178    }
179}
180
181/// Type-erased typed-host-scalar argument for an NVRTC kernel launch.
182///
183/// Boxed as `Box<dyn ScalarArg>` inside `KernelArg::Scalar`. A blanket
184/// impl `impl<T: CudaDtype> ScalarArg for T` covers every dtype the
185/// runtime understands. `usize` and `bool` are *not* `CudaDtype` so
186/// callers use the dedicated [`super::nvrtc::KernelArg::Usize`]
187/// variant for sizes — the most common scalar arg by far.
188pub trait ScalarArg: Send + Sync + 'static {
189    /// Push the scalar value onto `builder` by reference.
190    ///
191    /// The `'a` lifetime ties `&self` to the builder so the borrowed
192    /// scalar reference lives at least as long as `builder`.
193    fn push<'a>(&'a self, builder: &mut LaunchArgs<'a>);
194
195    /// Dtype for tracing / debugging.
196    fn dtype(&self) -> Option<DType>;
197}
198
199impl<T> ScalarArg for T
200where
201    T: CudaDtype + DeviceRepr + Sync,
202{
203    #[inline]
204    fn push<'a>(&'a self, builder: &mut LaunchArgs<'a>) {
205        builder.arg(self);
206    }
207
208    #[inline]
209    fn dtype(&self) -> Option<DType> {
210        Some(<T as atomr_accel::AccelDtype>::KIND)
211    }
212}
213
214// ---------------------------------------------------------------------------
215// Stub trait declarations for the remaining kernel actors. Each
216// per-actor migration ships its own impl alongside the migrated actor
217// file in a follow-up PR (Phases 1.x onward).
218// ---------------------------------------------------------------------------
219
220/// `GemmDispatchCtx` is now an alias for `BlasDispatchCtx` (the cuBLAS
221/// agent unified the per-op contexts since every cuBLAS dispatcher
222/// needs the same set: cuBLAS handle, stream, completion, state).
223pub type GemmDispatchCtx<'a> = BlasDispatchCtx<'a>;
224
225/// Erased `GemmRequest<T>`. Implementors live in `kernel::blas::gemm`.
226pub trait GemmDispatch: Send + 'static {
227    fn dtype_name(&self) -> &'static str;
228    fn op_name(&self) -> &'static str;
229    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>);
230}
231
232/// Erased `GemmStridedBatchedRequest<T>`.
233pub trait GemmStridedBatchedDispatch: Send + 'static {
234    fn dtype_name(&self) -> &'static str;
235    fn op_name(&self) -> &'static str;
236    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>);
237}
238
239/// Erased L1 ops: axpy, dot, nrm2, scal, asum, iamax, iamin, copy, swap, rot.
240pub trait BlasL1Dispatch: Send + 'static {
241    fn dtype_name(&self) -> &'static str;
242    fn op_name(&self) -> &'static str;
243    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>);
244}
245
246/// Erased L2 ops: gemv, ger.
247pub trait BlasL2Dispatch: Send + 'static {
248    fn dtype_name(&self) -> &'static str;
249    fn op_name(&self) -> &'static str;
250    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>);
251}
252
253/// Erased L3 ops other than gemm: geam, syrk, trsm.
254pub trait BlasL3Dispatch: Send + 'static {
255    fn dtype_name(&self) -> &'static str;
256    fn op_name(&self) -> &'static str;
257    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>);
258}
259
#[cfg(feature = "cublaslt")]
mod blaslt_dispatch_internal {
    //! cuBLASLt dispatch context and "unsupported dtype" reply helper.
    //!
    //! Kept in its own module so the cudarc / crate-internal imports it
    //! needs (`CudaBlasLT`, `WorkspacePool`, `HeuristicCacheRef`,
    //! `oneshot`) stay local to it. Both items are re-exported publicly by
    //! the parent module, so this module does not hide them from users.
    use std::sync::Arc;

    use cudarc::cublaslt::CudaBlasLT;
    use tokio::sync::oneshot;

    use crate::completion::CompletionStrategy;
    use crate::error::GpuError;
    use crate::kernel::blas_lt::heuristic::HeuristicCacheRef;
    use crate::kernel::blas_lt::workspace::WorkspacePool;

    /// Per-call context handed to a `BlasLtDispatch::dispatch` impl.
    pub struct BlasLtDispatchCtx<'a> {
        /// cuBLASLt handle, held by value (an `Arc` clone) rather than
        /// borrowed.
        pub blas_lt: Arc<CudaBlasLT>,
        /// Stream to enqueue work on.
        pub stream: &'a Arc<cudarc::driver::CudaStream>,
        /// Strategy used to deliver the reply.
        pub completion: &'a Arc<dyn CompletionStrategy>,
        /// Workspace pool from `kernel::blas_lt::workspace`.
        pub workspace: &'a WorkspacePool,
        /// Heuristic-cache handle from `kernel::blas_lt::heuristic`.
        pub heuristic: HeuristicCacheRef,
        /// SM architecture of the device. NOTE(review): the numeric
        /// encoding (e.g. `major * 10 + minor`) is not visible here —
        /// confirm at the construction site.
        pub sm_arch: u32,
    }

    /// Replies on `reply` with `GpuError::Unrecoverable`, reporting that
    /// `dtype_name` is unsupported in this build.
    ///
    /// A failed send (the requester already dropped its receiver) is
    /// deliberately ignored: there is nobody left to notify.
    pub fn reply_unsupported(
        reply: oneshot::Sender<Result<(), GpuError>>,
        dtype_name: &'static str,
    ) {
        let _ = reply.send(Err(GpuError::Unrecoverable(format!(
            "BlasLtDispatch: dtype {dtype_name} unsupported in this build"
        ))));
    }
}
293
294#[cfg(feature = "cublaslt")]
295pub use blaslt_dispatch_internal::{reply_unsupported, BlasLtDispatchCtx};
296
/// Boxed-dispatch trait through which the cuBLASLt actor calls into a
/// typed `MatmulRequest<T>` once the request has been type-erased through
/// the mailbox.
#[cfg(feature = "cublaslt")]
pub trait BlasLtDispatch: Send + 'static {
    /// Element dtype of the matmul request.
    fn dtype_kind(&self) -> crate::dtype::DTypeKind;
    /// Consume the request and run it against `ctx`.
    fn dispatch(self: Box<Self>, ctx: &BlasLtDispatchCtx<'_>);
}
304
/// `CudnnDispatch` serves `kernel::cudnn` (Phase 2 cuDNN). It and its
/// context are declared in the private `cudnn_dispatch` module below and
/// re-exported here for symmetry with other actors' dispatch traits.
307#[cfg(feature = "cudnn")]
308pub use cudnn_dispatch::{CudnnDispatch, CudnnDispatchCtx};
309
#[cfg(feature = "cudnn")]
mod cudnn_dispatch {
    use std::sync::Arc;

    use parking_lot::Mutex;

    use crate::completion::CompletionStrategy;

    /// Everything a [`CudnnDispatch::dispatch`] implementation needs.
    ///
    /// The handle, stream, and completion strategy are owned `Arc` clones.
    /// The plan cache and workspace are borrowed and sit behind
    /// `parking_lot` mutexes.
    pub struct CudnnDispatchCtx<'a> {
        /// cuDNN library handle.
        pub handle: Arc<cudarc::cudnn::Cudnn>,
        /// Stream to enqueue work on.
        pub stream: Arc<cudarc::driver::CudaStream>,
        /// Strategy used to deliver the reply.
        pub completion: Arc<dyn CompletionStrategy>,
        /// Cache of cuDNN graph plans from `kernel::cudnn::graph`.
        pub plan_cache: &'a Mutex<crate::kernel::cudnn::graph::PlanCache>,
        /// Optional device scratch buffer (`None` when none is held).
        pub workspace: &'a Mutex<Option<cudarc::driver::CudaSlice<u8>>>,
    }

    /// Boxed-dispatch trait for typed cuDNN ops.
    pub trait CudnnDispatch: Send + 'static {
        /// Static name of the op kind.
        fn op_kind(&self) -> &'static str;
        /// Static name of the element dtype.
        fn dtype_name(&self) -> &'static str;
        /// Consume the request and run it against `ctx`.
        fn dispatch(self: Box<Self>, ctx: &CudnnDispatchCtx<'_>);
    }
}
334
/// Context bundle passed to every [`FftDispatch::dispatch`] call.
///
/// The actor fills in its current stream, its completion strategy, and
/// the cuFFT plan it has already resolved against its LRU cache, so
/// dispatch implementations stay small.
#[cfg(feature = "cufft")]
pub struct FftDispatchCtx<'a> {
    /// Stream to enqueue work on.
    pub stream: &'a Arc<cudarc::driver::CudaStream>,
    /// Strategy used to deliver the reply.
    pub completion: &'a Arc<dyn CompletionStrategy>,
    /// The resolved cuFFT plan (an `Arc<CudaFft>`), erased to `dyn Any` so
    /// this module stays import-light; it is downcast inside
    /// `kernel::fft`.
    pub plan: Arc<dyn Any + Send + Sync>,
}
347
/// Boxed-dispatch trait for typed cuFFT requests: `FftRequest<T>` where
/// `T: FftSupported`.
#[cfg(feature = "cufft")]
pub trait FftDispatch: Send + 'static {
    /// Key the actor resolves against its plan cache before dispatching.
    fn plan_key(&self) -> crate::kernel::fft::PlanKey;
    /// Element dtype of the transform.
    fn dtype_kind(&self) -> DType;
    /// Consume the request and run it against `ctx`.
    fn dispatch(self: Box<Self>, ctx: &FftDispatchCtx<'_>);
}
356
357/// Erased payload accepted by `RngActor` via `RngMsg::Fill`.
358///
359/// The actor takes the cuRAND generator lock and hands it to `fill` along
360/// with the stream + completion strategy. Implementors call
361/// `cudarc::curand::sys::curandGenerate*` (or the safe wrapper),
362/// keep-alive their `GpuRef<T>` via `kernel::envelope::run_kernel`,
363/// and reply on the embedded `oneshot` channel.
364pub trait RngDispatch: Send + 'static {
365    fn fill(
366        self: Box<Self>,
367        generator: cudarc::curand::sys::curandGenerator_t,
368        stream: &Arc<cudarc::driver::CudaStream>,
369        completion: &Arc<dyn CompletionStrategy>,
370    ) -> Result<(), GpuError>;
371}
372
373/// `SolverDispatch` is owned by `kernel::solver` (Phase 1 cuSOLVER).
374/// Re-exported here for API symmetry with other actors' dispatch traits.
375#[cfg(feature = "cusolver")]
376pub use crate::kernel::solver::SolverDispatch;
377
/// Phase 4 cuSPARSE handle wrapper.
///
/// The raw `cusparseHandle_t` is a pointer and therefore `!Send`/`!Sync`
/// by default. This newtype opts back in so the handle can be stored in
/// the sparse actor and reached by dispatchers through
/// `SparseDispatchCtx::handle` (a `parking_lot::Mutex<SendSparseHandle>`).
#[cfg(feature = "cusparse")]
pub struct SendSparseHandle(pub cudarc::cusparse::sys::cusparseHandle_t);

// SAFETY: cuSPARSE is thread-safe per handle as long as a given handle is
// only touched by one stream at a time. The crate upholds this through the
// actor-per-handle invariant, so moving the handle to another thread (e.g.
// the actor's executor thread) is sound.
#[cfg(feature = "cusparse")]
unsafe impl Send for SendSparseHandle {}

// SAFETY: `&SendSparseHandle` exposes the raw handle through the public
// field, so concurrent use has to be prevented externally. Within this
// crate the handle is reached through `parking_lot::Mutex<SendSparseHandle>`
// (see `SparseDispatchCtx::handle`), which serializes access, plus the
// actor-per-handle invariant above.
// NOTE(review): `Mutex<T>: Sync` only requires `T: Send`, so this impl is
// not needed for the Mutex usage; confirm whether any other caller relies
// on it before removing it.
#[cfg(feature = "cusparse")]
unsafe impl Sync for SendSparseHandle {}
387
/// Context passed to a `SparseDispatch::dispatch` implementation for a
/// single call.
#[cfg(feature = "cusparse")]
pub struct SparseDispatchCtx<'a> {
    /// cuSPARSE handle, serialized behind a mutex.
    pub handle: &'a parking_lot::Mutex<SendSparseHandle>,
    /// Stream to enqueue work on.
    pub stream: &'a Arc<cudarc::driver::CudaStream>,
    /// Strategy used to deliver the reply.
    pub completion: &'a Arc<dyn CompletionStrategy>,
    /// Optional device scratch buffer (`None` when none is held).
    pub workspace: &'a parking_lot::Mutex<Option<cudarc::driver::CudaSlice<u8>>>,
}
396
/// Tag identifying which operation a `SparseDispatch` performs.
#[cfg(feature = "cusparse")]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SparseOp {
    SpMv,
    SpMm,
    SpGemm,
    SpSv,
    Sddmm,
    DenseToSparse,
    SparseToDense,
    Convert,
}

#[cfg(feature = "cusparse")]
impl SparseOp {
    /// Lowercase `snake_case` name of the op, e.g. `"spmv"` or
    /// `"dense_to_sparse"`.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::SpMv => "spmv",
            Self::SpMm => "spmm",
            Self::SpGemm => "spgemm",
            Self::SpSv => "spsv",
            Self::Sddmm => "sddmm",
            Self::DenseToSparse => "dense_to_sparse",
            Self::SparseToDense => "sparse_to_dense",
            Self::Convert => "convert",
        }
    }
}
426
/// Type-erased cuSPARSE op (Phase 4).
#[cfg(feature = "cusparse")]
pub trait SparseDispatch: Send + 'static {
    /// Element dtype of the op.
    fn dtype(&self) -> DType;
    /// Which sparse operation this request performs. Despite the name, it
    /// returns the [`SparseOp`] tag; use [`SparseOp::as_str`] for a string.
    fn op_name(&self) -> SparseOp;
    /// Consume the request and run it against `ctx`.
    fn dispatch(self: Box<Self>, ctx: &SparseDispatchCtx<'_>);
}
434
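/// `TensorDispatch` serves `kernel::tensor` (Phase 2 cuTENSOR). It, its
/// context, and the cuTENSOR `WorkspacePool` are declared in the private
/// `cutensor_dispatch` module below and re-exported here.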
436#[cfg(feature = "cutensor")]
437pub use cutensor_dispatch::{TensorDispatch, TensorDispatchCtx, WorkspacePool};
438
#[cfg(feature = "cutensor")]
mod cutensor_dispatch {
    //! cuTENSOR dispatch surface: the workspace pool, the per-call context
    //! bundle, and the boxed-dispatch trait (all re-exported by the parent
    //! module).
    use std::sync::Arc;

    use parking_lot::Mutex;

    use crate::completion::CompletionStrategy;
    use crate::error::GpuError;
    use crate::kernel::tensor::plan_cache::PlanCache;
    use crate::kernel::tensor::SendHandle;

    /// Pool of zero-initialised device scratch buffers for cuTENSOR,
    /// bucketed by power-of-two size.
    ///
    /// A request for `n` bytes is served by the bucket of size
    /// `n.next_power_of_two()`. At most one buffer is kept per bucket, and
    /// buffers are never freed while the pool is alive.
    pub struct WorkspacePool {
        /// Stream used for every bucket allocation.
        stream: Arc<cudarc::driver::CudaStream>,
        /// Allocated buckets, searched linearly. The list stays short
        /// because every size is a distinct power of two.
        buckets: Mutex<Vec<Bucket>>,
    }

    /// One cached workspace allocation.
    struct Bucket {
        /// Power-of-two byte size the buffer was allocated with.
        size: usize,
        slice: cudarc::driver::CudaSlice<u8>,
    }

    impl WorkspacePool {
        /// Creates an empty pool that allocates its buffers on `stream`.
        pub fn new(stream: Arc<cudarc::driver::CudaStream>) -> Self {
            Self {
                stream,
                buckets: Mutex::new(Vec::new()),
            }
        }

        /// Ensures a buffer of at least `n` bytes is available and returns
        /// the size of the bucket that backs it.
        ///
        /// `n == 0` means "no workspace needed": it returns `Ok(0)` without
        /// allocating. Otherwise `n` is rounded up to the next power of two,
        /// and that bucket is allocated on first use.
        ///
        /// # Errors
        ///
        /// [`GpuError::OutOfMemory`] if the rounded size does not fit in
        /// `usize` or the device allocation fails.
        pub fn ensure(&self, n: usize) -> Result<usize, GpuError> {
            if n == 0 {
                return Ok(0);
            }
            // Above `usize::MAX / 2 + 1`, `next_power_of_two` panics (debug)
            // or wraps to 0 (release). A 0-sized bucket would silently hand
            // out an empty workspace, so reject the request instead.
            let bucket_size = n.checked_next_power_of_two().ok_or_else(|| {
                GpuError::OutOfMemory(format!(
                    "cutensor workspace: {n} bytes exceeds the largest bucket size"
                ))
            })?;
            // Keep the lock across the allocation so concurrent callers
            // cannot both allocate the same bucket.
            let mut buckets = self.buckets.lock();
            if buckets.iter().any(|b| b.size == bucket_size) {
                return Ok(bucket_size);
            }
            let slice = self
                .stream
                .alloc_zeros::<u8>(bucket_size)
                .map_err(|e| GpuError::OutOfMemory(format!("cutensor workspace: {e}")))?;
            buckets.push(Bucket {
                size: bucket_size,
                slice,
            });
            Ok(bucket_size)
        }

        /// Runs `f` with exclusive access to the buffer that backs `n`-byte
        /// requests and returns its result.
        ///
        /// Returns `None` if `n == 0`, if `n` is too large to have a
        /// bucket, or if [`Self::ensure`] has not yet created the bucket.
        /// The pool lock is held while `f` runs, and `parking_lot::Mutex`
        /// is not reentrant, so `f` must not call back into this pool.
        pub fn with_bucket<F, R>(&self, n: usize, f: F) -> Option<R>
        where
            F: FnOnce(&mut cudarc::driver::CudaSlice<u8>) -> R,
        {
            if n == 0 {
                return None;
            }
            // Same overflow guard as `ensure`: no bucket can exist for `n`.
            let bucket_size = n.checked_next_power_of_two()?;
            let mut buckets = self.buckets.lock();
            let bucket = buckets.iter_mut().find(|b| b.size == bucket_size)?;
            Some(f(&mut bucket.slice))
        }
    }

    /// Per-call context handed to [`TensorDispatch::dispatch`]. Every field
    /// is a shared `Arc`, so the context has no borrow lifetime.
    pub struct TensorDispatchCtx {
        /// cuTENSOR handle (`kernel::tensor::SendHandle`) behind a mutex.
        pub handle: Arc<Mutex<SendHandle>>,
        /// Stream to enqueue work on.
        pub stream: Arc<cudarc::driver::CudaStream>,
        /// Strategy used to deliver the reply.
        pub completion: Arc<dyn CompletionStrategy>,
        /// Shared cuTENSOR plan cache.
        pub plan_cache: Arc<PlanCache>,
        /// Shared scratch-buffer pool.
        pub workspace: Arc<WorkspacePool>,
    }

    /// Boxed-dispatch trait for typed cuTENSOR ops.
    pub trait TensorDispatch: Send + 'static {
        /// Static name of the operation.
        fn op_tag(&self) -> &'static str;
        /// Static name of the element dtype.
        fn dtype_tag(&self) -> &'static str;
        /// Consume the request and run it against `ctx`.
        fn dispatch(self: Box<Self>, ctx: &TensorDispatchCtx);
        /// Consume the request without running it. NOTE(review): judging by
        /// the name, this fails the request's reply when the actor runs in
        /// mock (no-GPU) mode — confirm against `kernel::tensor`.
        fn fail_mock(self: Box<Self>);
    }
}
517
518/// Alias used by the NCCL CollectiveDispatch (Phase 2). Maps onto the
519/// canonical `atomr_accel::DType`.
520#[cfg(feature = "nccl")]
521pub use atomr_accel::DType as DispatchDType;
522
/// Boxed-dispatch trait for NCCL collectives.
///
/// The `CollectiveActor` owns the message envelope. Each typed request
/// struct (e.g. `AllReduceRequest<T: NcclReduceSupported>`) implements this
/// trait, so the actor keeps a single mailbox while the dtype dimension
/// travels inside the box.
#[cfg(feature = "nccl")]
pub trait CollectiveDispatch: Send + 'static {
    /// Element dtype of the collective.
    fn dtype_kind(&self) -> DispatchDType;
    /// Optional device id associated with the request.
    fn device_id(&self) -> Option<u32>;
    /// Consume the request and run it against `ctx`.
    fn dispatch(self: Box<Self>, ctx: &CollectiveDispatchCtx<'_>);
}
534
/// Context passed to a `CollectiveDispatch::dispatch` implementation for a
/// single call: the cudarc-wrapped NCCL communicator, the device state, and
/// the completion strategy used to reply.
#[cfg(feature = "nccl")]
pub struct CollectiveDispatchCtx<'a> {
    /// NCCL communicator (as wrapped by cudarc).
    pub comm: &'a cudarc::nccl::Comm,
    /// Shared per-device state.
    pub state: &'a Arc<DeviceState>,
    /// Strategy used to deliver the reply.
    pub completion: &'a Arc<dyn CompletionStrategy>,
}
544
#[cfg(test)]
mod tests {
    use super::*;

    /// A no-GPU stand-in for `NvrtcLaunchDispatch` that reports fixed
    /// `op_name` / `dtype` values. `dispatch` is never invoked in these
    /// tests because building an `NvrtcDispatchCtx` needs a real CUDA
    /// stream; the GPU-side dispatch path is exercised by the migrated
    /// NVRTC actor.
    struct DummyNvrtc {
        op: &'static str,
        d: Option<DType>,
    }

    impl NvrtcLaunchDispatch for DummyNvrtc {
        fn op_name(&self) -> &'static str {
            self.op
        }
        fn dtype(&self) -> Option<DType> {
            self.d
        }
        fn dispatch(self: Box<Self>, _ctx: &NvrtcDispatchCtx<'_>) {
            // Intentionally empty: only the trait-object surface is tested.
        }
    }

    #[test]
    fn nvrtc_dispatch_box_round_trip() {
        // Single-dtype and multi-dtype (`None`) requests must both report
        // their metadata unchanged through `Box<dyn NvrtcLaunchDispatch>`.
        let cases = [("relu", Some(DType::F32)), ("noop", None)];
        for (op, d) in cases {
            let boxed: Box<dyn NvrtcLaunchDispatch> = Box::new(DummyNvrtc { op, d });
            assert_eq!(boxed.op_name(), op);
            assert_eq!(boxed.dtype(), d);
        }
    }

    /// Compile-time witness that `GpuRef<T>` coerces to `Box<dyn DevSliceArg>`
    /// for several dtypes (plus `half::f16` / `half::bf16` under the `f16`
    /// feature). A real `GpuRef` cannot be built without a CUDA context, so
    /// only the coercions are checked.
    fn _assert_dev_slice_arg_object_safe() {
        fn takes_box(_: Box<dyn DevSliceArg>) {}
        let _: fn(GpuRef<f32>) -> Box<dyn DevSliceArg> = |g| Box::new(g);
        let _: fn(GpuRef<f64>) -> Box<dyn DevSliceArg> = |g| Box::new(g);
        let _: fn(GpuRef<u8>) -> Box<dyn DevSliceArg> = |g| Box::new(g);
        let _: fn(GpuRef<i32>) -> Box<dyn DevSliceArg> = |g| Box::new(g);
        let _ = takes_box;
        #[cfg(feature = "f16")]
        {
            let _: fn(GpuRef<half::f16>) -> Box<dyn DevSliceArg> = |g| Box::new(g);
            let _: fn(GpuRef<half::bf16>) -> Box<dyn DevSliceArg> = |g| Box::new(g);
        }
    }

    #[test]
    fn dev_slice_arg_for_gpu_ref() {
        // Realize the witness closures; the real check happens at compile time.
        _assert_dev_slice_arg_object_safe();
    }

    /// Witness that `Box<dyn ScalarArg>` can be built from a representative
    /// set of primitive `CudaDtype`s.
    #[test]
    fn scalar_arg_blanket_impls_compile() {
        fn takes(_: Box<dyn ScalarArg>) {}
        takes(Box::new(1.0f32));
        takes(Box::new(2.0f64));
        takes(Box::new(3i32));
        takes(Box::new(4u32));
        takes(Box::new(5u64));
        #[cfg(feature = "f16")]
        {
            takes(Box::new(half::f16::ONE));
            takes(Box::new(half::bf16::ONE));
        }
    }

    /// Sanity check: every non-NVRTC dispatch trait compiled into this
    /// build is nameable as a boxed trait object.
    #[test]
    fn stub_dispatch_traits_compile() {
        fn _gemm(_: Box<dyn GemmDispatch>) {}
        #[cfg(feature = "cublaslt")]
        fn _blaslt(_: Box<dyn BlasLtDispatch>) {}
        #[cfg(feature = "cudnn")]
        fn _cudnn(_: Box<dyn CudnnDispatch>) {}
        #[cfg(feature = "cufft")]
        fn _fft(_: Box<dyn FftDispatch>) {}
        fn _rng(_: Box<dyn RngDispatch>) {}
        #[cfg(feature = "cusolver")]
        fn _solver(_: Box<dyn crate::kernel::solver::SolverDispatch>) {}
        #[cfg(feature = "cusparse")]
        fn _sparse(_: Box<dyn SparseDispatch>) {}
        #[cfg(feature = "cutensor")]
        fn _tensor(_: Box<dyn TensorDispatch>) {}
        #[cfg(feature = "nccl")]
        fn _coll(_: Box<dyn CollectiveDispatch>) {}
    }
}