Skip to main content

atomr_accel_cuda/device/
device_actor.rs

//! `DeviceActor` — the outer tier of the §5.11 two-tier supervision tree.
//!
//! Responsibilities:
//! - Provide a stable address: the `ActorRef<DeviceMsg>` survives any
//!   number of `ContextActor` restarts.
//! - Spawn the `ContextActor` child (which owns the `Arc<CudaContext>`)
//!   and queue `WorkRequest`s while the context is being (re)built.
//! - Hold the shared `Arc<DeviceState>`, which outlives any single
//!   `ContextActor` incarnation.
10
11use std::any::{Any, TypeId};
12use std::collections::{HashMap, VecDeque};
13use std::sync::Arc;
14
15use async_trait::async_trait;
16use atomr_core::actor::{Actor, ActorRef, Context, Props};
17use bitflags::bitflags;
18use parking_lot::RwLock;
19use tokio::sync::oneshot;
20use tracing::{debug, warn};
21
22use crate::dtype::CudaDtype;
23use crate::error::GpuError;
24use crate::gpu_ref::GpuRef;
25use crate::kernel::BlasMsg;
26
27use super::alloc_dispatch::{
28    AllocDispatch, AllocReq, CopyFromHostDispatch, CopyFromHostReq, CopyToHostDispatch,
29    CopyToHostReq,
30};
31use super::alloc_msg::{DeviceLoad, HostBuf};
32use super::context_actor::{ContextActor, ContextMsg};
33use super::state::DeviceState;
34
bitflags! {
    /// Per-device opt-in flags for which library actors to spawn.
    /// Compile-time `feature = "..."` gates still apply — a flag for
    /// a library that wasn't compiled in is silently ignored.
    ///
    /// Bit positions are part of the stable surface (this module's tests
    /// assert each value), so new flags must only ever be appended with a
    /// fresh bit — never renumber existing ones.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct EnabledLibraries: u32 {
        const BLAS     = 1 << 0;
        const CUDNN    = 1 << 1;
        const CUFFT    = 1 << 2;
        const CURAND   = 1 << 3;
        const CUSOLVER = 1 << 4;
        const CUBLASLT = 1 << 5;
        const NVRTC    = 1 << 6;
        // Phase 0.8 — additional library + extension actor opt-ins.
        const CUTENSOR   = 1 << 7;
        const CUSPARSE   = 1 << 8;
        const NCCL       = 1 << 9;
        const CUTLASS    = 1 << 10;
        const TENSORRT   = 1 << 11;
        const FLASHATTN  = 1 << 12;
        const CUB_THRUST = 1 << 13;
        const TELEMETRY  = 1 << 14;

        /// Union of every flag declared above. Extend this expression
        /// whenever a new flag is added, or `ALL` will silently omit it.
        const ALL = Self::BLAS.bits()
            | Self::CUDNN.bits()
            | Self::CUFFT.bits()
            | Self::CURAND.bits()
            | Self::CUSOLVER.bits()
            | Self::CUBLASLT.bits()
            | Self::NVRTC.bits()
            | Self::CUTENSOR.bits()
            | Self::CUSPARSE.bits()
            | Self::NCCL.bits()
            | Self::CUTLASS.bits()
            | Self::TENSORRT.bits()
            | Self::FLASHATTN.bits()
            | Self::CUB_THRUST.bits()
            | Self::TELEMETRY.bits();
    }
}
75
76impl Default for EnabledLibraries {
77    /// Sensible default: BLAS only (matches F1 semantics). Enable
78    /// other libraries explicitly via [`DeviceConfig::with_libraries`].
79    fn default() -> Self {
80        Self::BLAS
81    }
82}
83
/// Public configuration for a `DeviceActor`.
///
/// Build one with `DeviceConfig::new` (real device) or `DeviceConfig::mock`
/// (no GPU required), then refine it with builder methods such as
/// `DeviceConfig::with_libraries`.
#[derive(Debug, Clone)]
pub struct DeviceConfig {
    /// Identifier of the device this actor manages. It is passed to
    /// `DeviceState::new` and attached to the actor's log events.
    /// NOTE(review): presumably the CUDA device ordinal — confirm in
    /// `ContextActor`.
    pub device_id: u32,
    /// When true, the `ContextActor` skips real cudarc calls and just
    /// drives the supervision plumbing. Used by `examples/echo_no_gpu`
    /// and unit tests on hosts without a GPU.
    pub mock_mode: bool,
    /// Internal queue cap for work received before the context is ready
    /// (or while it is being rebuilt). Bounds the §5.4 backpressure
    /// surface: once the queue holds this many items, further requests
    /// are rejected instead of queued. `new` and `mock` set it to 1024.
    pub pending_queue_capacity: usize,
    /// Which library actors `ContextActor` should spawn under this
    /// device. Defaults to `BLAS` only (F1 behaviour). Compile-time
    /// `feature = "..."` gates still apply.
    pub enabled_libraries: EnabledLibraries,
}
101
102impl DeviceConfig {
103    pub fn new(device_id: u32) -> Self {
104        Self {
105            device_id,
106            mock_mode: false,
107            pending_queue_capacity: 1024,
108            enabled_libraries: EnabledLibraries::default(),
109        }
110    }
111
112    pub fn mock(device_id: u32) -> Self {
113        Self {
114            device_id,
115            mock_mode: true,
116            pending_queue_capacity: 1024,
117            enabled_libraries: EnabledLibraries::default(),
118        }
119    }
120
121    /// Builder: select which libraries' kernel actors to spawn.
122    pub fn with_libraries(mut self, libs: EnabledLibraries) -> Self {
123        self.enabled_libraries = libs;
124        self
125    }
126}
127
/// Public messages sent to a `DeviceActor`.
///
/// **Phase 0.4** — the formerly-21 dtype-enumerated `Allocate*` /
/// `CopyToHost*` / `CopyFromHost*` variants collapse into 3 boxed
/// dispatchers:
///
/// - [`DeviceMsg::Alloc`] — typed allocation
/// - [`DeviceMsg::CopyToHost`] — D2H async copy
/// - [`DeviceMsg::CopyFromHost`] — H2D async copy
///
/// Each carries a `Box<dyn …Dispatch>` whose concrete payload is an
/// `AllocReq<T>` / `CopyToHostReq<T>` / `CopyFromHostReq<T>` for some
/// `T: CudaDtype`. `GpuRef<T>` keeps its static dtype on both ends —
/// the box is purely a uniform mailbox surface.
///
/// The legacy `Allocate*` / `CopyToHost*` / `CopyFromHost*` variants
/// remain as `#[deprecated]` aliases. Existing call sites compile and
/// run unchanged; the handler arm constructs the equivalent
/// `Box<dyn …Dispatch>` and forwards through the new path.
///
/// Work that needs the context (alloc, copy, SGEMM) is queued by the
/// `DeviceActor` while the context is not ready and replayed once it is.
/// The query variants (`Snapshot*`, `WatchGeneration`, `Stats`) are never
/// queued.
pub enum DeviceMsg {
    /// Phase 0.4 generic alloc. Construct via
    /// [`DeviceMsg::alloc::<T>`](Self::alloc) or
    /// `Box::new(AllocReq::<T> { … })` directly.
    Alloc(Box<dyn AllocDispatch>),
    /// Phase 0.4 generic D2H copy. Construct via
    /// [`DeviceMsg::copy_to_host::<T>`](Self::copy_to_host).
    CopyToHost(Box<dyn CopyToHostDispatch>),
    /// Phase 0.4 generic H2D copy. Construct via
    /// [`DeviceMsg::copy_from_host::<T>`](Self::copy_from_host).
    CopyFromHost(Box<dyn CopyFromHostDispatch>),

    /// **Deprecated alias** for [`DeviceMsg::AllocateF32`]. F1
    /// callers wrote `Allocate { len, reply }` — kept for back-compat.
    #[deprecated(note = "use DeviceMsg::alloc::<f32>(len, reply)")]
    Allocate {
        len: usize,
        reply: oneshot::Sender<Result<GpuRef<f32>, GpuError>>,
    },
    /// Legacy per-dtype allocation. This and the other `Allocate*`
    /// variants are rewritten into [`DeviceMsg::Alloc`] for the matching
    /// element type before dispatch.
    #[deprecated(note = "use DeviceMsg::alloc::<f32>(len, reply)")]
    AllocateF32 {
        len: usize,
        reply: oneshot::Sender<Result<GpuRef<f32>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::alloc::<f64>(len, reply)")]
    AllocateF64 {
        len: usize,
        reply: oneshot::Sender<Result<GpuRef<f64>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::alloc::<i8>(len, reply)")]
    AllocateI8 {
        len: usize,
        reply: oneshot::Sender<Result<GpuRef<i8>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::alloc::<i32>(len, reply)")]
    AllocateI32 {
        len: usize,
        reply: oneshot::Sender<Result<GpuRef<i32>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::alloc::<i64>(len, reply)")]
    AllocateI64 {
        len: usize,
        reply: oneshot::Sender<Result<GpuRef<i64>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::alloc::<u8>(len, reply)")]
    AllocateU8 {
        len: usize,
        reply: oneshot::Sender<Result<GpuRef<u8>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::alloc::<u32>(len, reply)")]
    AllocateU32 {
        len: usize,
        reply: oneshot::Sender<Result<GpuRef<u32>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::alloc::<u64>(len, reply)")]
    AllocateU64 {
        len: usize,
        reply: oneshot::Sender<Result<GpuRef<u64>, GpuError>>,
    },
    #[cfg(feature = "f16")]
    #[deprecated(note = "use DeviceMsg::alloc::<half::f16>(len, reply)")]
    AllocateF16 {
        len: usize,
        reply: oneshot::Sender<Result<GpuRef<half::f16>, GpuError>>,
    },
    #[cfg(feature = "f16")]
    #[deprecated(note = "use DeviceMsg::alloc::<half::bf16>(len, reply)")]
    AllocateBf16 {
        len: usize,
        reply: oneshot::Sender<Result<GpuRef<half::bf16>, GpuError>>,
    },

    /// D2H async copy — buffer round-trips back via the reply so a
    /// pinned buffer can return to its pool. This and the other legacy
    /// `CopyToHost*` variants are rewritten into
    /// [`DeviceMsg::CopyToHost`] before dispatch.
    #[deprecated(note = "use DeviceMsg::copy_to_host::<f32>(src, dst, reply)")]
    CopyToHostF32 {
        src: GpuRef<f32>,
        dst: HostBuf<f32>,
        reply: oneshot::Sender<Result<HostBuf<f32>, GpuError>>,
    },
    /// Legacy H2D copy; the reply hands a `HostBuf` back to the caller.
    /// This and the other legacy `CopyFromHost*` variants are rewritten
    /// into [`DeviceMsg::CopyFromHost`] before dispatch.
    #[deprecated(note = "use DeviceMsg::copy_from_host::<f32>(src, dst, reply)")]
    CopyFromHostF32 {
        src: HostBuf<f32>,
        dst: GpuRef<f32>,
        reply: oneshot::Sender<Result<HostBuf<f32>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::copy_to_host::<f64>(src, dst, reply)")]
    CopyToHostF64 {
        src: GpuRef<f64>,
        dst: HostBuf<f64>,
        reply: oneshot::Sender<Result<HostBuf<f64>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::copy_from_host::<f64>(src, dst, reply)")]
    CopyFromHostF64 {
        src: HostBuf<f64>,
        dst: GpuRef<f64>,
        reply: oneshot::Sender<Result<HostBuf<f64>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::copy_to_host::<i32>(src, dst, reply)")]
    CopyToHostI32 {
        src: GpuRef<i32>,
        dst: HostBuf<i32>,
        reply: oneshot::Sender<Result<HostBuf<i32>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::copy_from_host::<i32>(src, dst, reply)")]
    CopyFromHostI32 {
        src: HostBuf<i32>,
        dst: GpuRef<i32>,
        reply: oneshot::Sender<Result<HostBuf<i32>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::copy_to_host::<u32>(src, dst, reply)")]
    CopyToHostU32 {
        src: GpuRef<u32>,
        dst: HostBuf<u32>,
        reply: oneshot::Sender<Result<HostBuf<u32>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::copy_from_host::<u32>(src, dst, reply)")]
    CopyFromHostU32 {
        src: HostBuf<u32>,
        dst: GpuRef<u32>,
        reply: oneshot::Sender<Result<HostBuf<u32>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::copy_to_host::<u8>(src, dst, reply)")]
    CopyToHostU8 {
        src: GpuRef<u8>,
        dst: HostBuf<u8>,
        reply: oneshot::Sender<Result<HostBuf<u8>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::copy_from_host::<u8>(src, dst, reply)")]
    CopyFromHostU8 {
        src: HostBuf<u8>,
        dst: GpuRef<u8>,
        reply: oneshot::Sender<Result<HostBuf<u8>, GpuError>>,
    },

    /// Fire an SGEMM through the context's BlasActor. Queued while the
    /// kernel children are not live.
    Sgemm(Box<SgemmRequest>),

    /// F4: Snapshot the underlying `Arc<CudaContext>` so a top-level
    /// observer (P2pTopology, NcclWorldActor) can build cross-device
    /// machinery. Replies `None` if the context isn't ready.
    SnapshotContext {
        reply: oneshot::Sender<Option<Arc<cudarc::driver::CudaContext>>>,
    },

    /// Phase 4.5++ — Snapshot the device's primary `Arc<CudaStream>`
    /// (the stream owned by `ContextActor`). Returned to downstream
    /// raw-pointer FFI users (TensorRT `enqueueV3`, custom kernel
    /// launchers) that need to share a single CUDA execution timeline
    /// with the rest of the device's library actors.
    ///
    /// Replies `None` if the context isn't ready (e.g. mock mode, or
    /// before `ContextReady`). On real hardware the returned stream
    /// is the same one that BLAS / cuDNN / cuFFT child actors were
    /// minted off.
    SnapshotStream {
        reply: oneshot::Sender<Option<Arc<cudarc::driver::CudaStream>>>,
    },

    /// F7: Snapshot the current `KernelChildren` so application code
    /// can talk to library actors directly (e.g. `RngActor`,
    /// `CudnnActor`). Replies `None` until the first `ContextReady`
    /// arrives, and again after a `ContextLost` until the next
    /// `ContextReady`.
    SnapshotChildren {
        reply: oneshot::Sender<Option<KernelChildren>>,
    },

    /// F9: Subscribe to the device's `DeviceState::generation_watch`.
    /// The receiver fires every time the underlying `CudaContext`
    /// rebuilds. Used by `NcclWorldActor` and `P2pTopology` to
    /// react to context loss.
    WatchGeneration {
        reply: oneshot::Sender<tokio::sync::watch::Receiver<u64>>,
    },

    /// F5: Per-device load snapshot for placement scheduling.
    Stats { reply: oneshot::Sender<DeviceLoad> },

    /// Internal: `ContextActor` has finished initialising and the
    /// kernel actors are live. Triggers a replay of queued work.
    ContextReady { children: KernelChildren },
    /// Internal: `ContextActor` notifies that the context was torn
    /// down (e.g. on poisoning); pending work should be re-stashed
    /// until a new `ContextReady` arrives.
    ContextLost,
}
331
/// Set of kernel-actor refs spawned by a `ContextActor`.
///
/// `blas` is always present. Each other library slot exists only when its
/// cargo feature is compiled in, and is `Some` only when both the cargo
/// feature is compiled in and the `DeviceConfig::enabled_libraries` flag
/// is set.
///
/// **Phase 0.8 extension.** In addition to the typed fields below
/// (which keep existing call sites compiling), `KernelChildren`
/// carries an open `extras` map keyed by [`TypeId`]. Future actor
/// crates (`atomr-accel-cutlass`, `-tensorrt`, `-flashattn`,
/// `-telemetry`, `-cub`) stash their `ActorRef` here so the device
/// supervisor can hand them out via [`KernelChildren::extra`]
/// without the core having to know their concrete message type.
#[derive(Clone)]
pub struct KernelChildren {
    /// cuBLAS kernel actor; mandatory.
    pub blas: ActorRef<BlasMsg>,
    /// cuDNN kernel actor (feature `cudnn`).
    #[cfg(feature = "cudnn")]
    pub cudnn: Option<ActorRef<crate::kernel::CudnnMsg>>,
    /// cuFFT kernel actor (feature `cufft`).
    #[cfg(feature = "cufft")]
    pub fft: Option<ActorRef<crate::kernel::FftMsg>>,
    /// cuRAND kernel actor (feature `curand`).
    #[cfg(feature = "curand")]
    pub rng: Option<ActorRef<crate::kernel::RngMsg>>,
    /// cuSOLVER kernel actor (feature `cusolver`).
    #[cfg(feature = "cusolver")]
    pub solver: Option<ActorRef<crate::kernel::SolverMsg>>,
    /// NVRTC kernel actor (feature `nvrtc`).
    #[cfg(feature = "nvrtc")]
    pub nvrtc: Option<ActorRef<crate::kernel::NvrtcMsg>>,
    /// TypeId-keyed registry for child actors not represented by a
    /// typed field above. The `Arc<RwLock<…>>` keeps `KernelChildren`
    /// `Clone` while letting later library crates register / look up
    /// their own refs. Because of the `Arc`, every clone of a
    /// `KernelChildren` shares this one map.
    extras: Arc<RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
}
362
363impl KernelChildren {
364    /// Construct a `KernelChildren` with the given `BlasActor` ref
365    /// and no library children or extras. Mirrors the `..Default::default()`
366    /// pattern used by callers but keeps `blas` mandatory.
367    pub fn new(blas: ActorRef<BlasMsg>) -> Self {
368        Self {
369            blas,
370            #[cfg(feature = "cudnn")]
371            cudnn: None,
372            #[cfg(feature = "cufft")]
373            fft: None,
374            #[cfg(feature = "curand")]
375            rng: None,
376            #[cfg(feature = "cusolver")]
377            solver: None,
378            #[cfg(feature = "nvrtc")]
379            nvrtc: None,
380            extras: Arc::new(RwLock::new(HashMap::new())),
381        }
382    }
383
384    /// Register an extra child actor (or any `Send + Sync` handle) by
385    /// type. Future actor crates (`atomr-accel-cutlass`, `-tensorrt`,
386    /// `-flashattn`, `-telemetry`, `-cub`) stash their `ActorRef`
387    /// here so the device supervisor can route stop/restart messages.
388    ///
389    /// If a value of the same type is already registered, it is
390    /// overwritten — typical use is one-shot registration during
391    /// `ContextActor::run_init`.
392    pub fn register_extra<T: Any + Send + Sync>(&self, value: T) {
393        let mut g = self.extras.write();
394        g.insert(TypeId::of::<T>(), Arc::new(value));
395    }
396
397    /// Look up a previously registered extra by type. Returns a clone
398    /// of the stored `T` if and only if a value of that exact type
399    /// was registered.
400    pub fn extra<T: Any + Send + Sync + Clone>(&self) -> Option<T> {
401        let g = self.extras.read();
402        g.get(&TypeId::of::<T>())
403            .and_then(|v| v.clone().downcast::<T>().ok())
404            .map(|arc| (*arc).clone())
405    }
406
407    /// Number of registered extras (for stats / observability).
408    pub fn extras_len(&self) -> usize {
409        self.extras.read().len()
410    }
411}
412
413impl DeviceMsg {
414    /// Phase 0.4: typed-allocation constructor. Boxes an
415    /// [`AllocReq<T>`] into the generic [`DeviceMsg::Alloc`] variant.
416    pub fn alloc<T: CudaDtype>(
417        len: usize,
418        reply: oneshot::Sender<Result<GpuRef<T>, GpuError>>,
419    ) -> Self {
420        DeviceMsg::Alloc(Box::new(AllocReq::<T> { len, reply }))
421    }
422
423    /// Phase 0.4: typed D2H copy constructor.
424    pub fn copy_to_host<T: CudaDtype>(
425        src: GpuRef<T>,
426        dst: HostBuf<T>,
427        reply: oneshot::Sender<Result<HostBuf<T>, GpuError>>,
428    ) -> Self {
429        DeviceMsg::CopyToHost(Box::new(CopyToHostReq::<T> { src, dst, reply }))
430    }
431
432    /// Phase 0.4: typed H2D copy constructor.
433    pub fn copy_from_host<T: CudaDtype>(
434        src: HostBuf<T>,
435        dst: GpuRef<T>,
436        reply: oneshot::Sender<Result<HostBuf<T>, GpuError>>,
437    ) -> Self {
438        DeviceMsg::CopyFromHost(Box::new(CopyFromHostReq::<T> { src, dst, reply }))
439    }
440}
441
/// Body of a `DeviceMsg::Sgemm` request.
///
/// It is boxed inside `DeviceMsg::Sgemm` so that variant stays
/// pointer-sized instead of inflating every `DeviceMsg` to this struct's
/// size. (`DeviceMsg` is not `Clone`; boxing only keeps it small to move
/// through the mailbox.) `reply` is a `oneshot::Sender`, consumed by a
/// single send, so each request is answered at most once.
pub struct SgemmRequest {
    /// Left operand.
    pub a: GpuRef<f32>,
    /// Right operand.
    pub b: GpuRef<f32>,
    /// Output / accumulator operand.
    pub c: GpuRef<f32>,
    /// GEMM dimensions. NOTE(review): presumably cuBLAS `m`/`n`/`k`
    /// (column-major) conventions — confirm in `BlasActor`.
    pub m: i32,
    pub n: i32,
    pub k: i32,
    /// Scaling factors. NOTE(review): assumed to follow standard SGEMM
    /// semantics `C = alpha·A·B + beta·C` — confirm in `BlasActor`.
    pub alpha: f32,
    pub beta: f32,
    /// Completion / error notification for the caller.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
}
456
/// Pending work item — anything `DeviceActor` stashes while the context
/// is not ready. Mirrors the user-facing variants of `DeviceMsg` minus
/// internal messages.
pub enum WorkRequest {
    /// Deferred forward to `ContextActor` (alloc / copy requests).
    ///
    /// Rather than one enum variant per dtype, the original boxed message
    /// is captured in a `FnOnce` that re-issues it verbatim once the
    /// context and BLAS refs are available. This keeps each pending op
    /// type-safe without N enum variants. Dropping the closure drops the
    /// caller's reply sender with it, so the caller sees a closed channel.
    Boxed(Box<dyn FnOnce(&ActorRef<ContextMsg>, &ActorRef<BlasMsg>) + Send>),
    /// Deferred SGEMM, forwarded to the `BlasActor` once children are live.
    Sgemm(Box<SgemmRequest>),
    /// Reply slot for callers who can't be re-driven (e.g. a
    /// `SnapshotContext` while the context isn't ready). It is answered
    /// with the then-current context when the queue drains, or with `None`
    /// if the request is rejected because the queue is full or the device
    /// stops first.
    SnapshotContext {
        reply: oneshot::Sender<Option<Arc<cudarc::driver::CudaContext>>>,
    },
}
473
/// Outer-tier device supervisor.
///
/// It owns the stable mailbox, the ref to its `ContextActor` child, and
/// the queue of work received before the context's kernel actors are live.
pub struct DeviceActor {
    /// Configuration this actor was built from.
    config: DeviceConfig,
    /// Shared per-device state. It is created in `new` and handed to the
    /// spawned `ContextActor`.
    state: Arc<DeviceState>,
    /// Ref to the `ContextActor` child; set in `pre_start`.
    context_ref: Option<ActorRef<ContextMsg>>,
    /// Live kernel-actor refs. `Some` after `ContextReady`; cleared on
    /// `ContextLost`.
    children: Option<KernelChildren>,
    /// Work stashed while not ready, bounded by
    /// `config.pending_queue_capacity`.
    pending: VecDeque<WorkRequest>,
}
481
482impl DeviceActor {
483    pub fn new(config: DeviceConfig) -> Self {
484        let state = Arc::new(DeviceState::new(config.device_id));
485        Self {
486            config,
487            state,
488            context_ref: None,
489            children: None,
490            pending: VecDeque::new(),
491        }
492    }
493
494    /// Construct a `Props<DeviceActor>` with the given configuration.
495    pub fn props(config: DeviceConfig) -> Props<Self> {
496        let cfg = config.clone();
497        Props::create(move || DeviceActor::new(cfg.clone()))
498    }
499
500    /// Shared device state — exposed for tests and for `KernelActor`s
501    /// that need to mint `GpuRef`s.
502    pub fn state(&self) -> &Arc<DeviceState> {
503        &self.state
504    }
505
506    fn enqueue_pending(&mut self, work: WorkRequest) {
507        if self.pending.len() >= self.config.pending_queue_capacity {
508            warn!(
509                device_id = self.config.device_id,
510                cap = self.config.pending_queue_capacity,
511                "dropping work — pending queue full"
512            );
513            // Drop on the floor with a typed error. The Boxed variant
514            // owns its reply channel internally so we just drop it
515            // and the caller observes oneshot::Receiver::Err.
516            match work {
517                WorkRequest::Sgemm(req) => {
518                    let _ = req.reply.send(Err(GpuError::Unrecoverable(
519                        "device pending queue full".into(),
520                    )));
521                }
522                WorkRequest::SnapshotContext { reply } => {
523                    let _ = reply.send(None);
524                }
525                WorkRequest::Boxed(_) => { /* reply drops with closure */ }
526            }
527            return;
528        }
529        self.pending.push_back(work);
530    }
531
532    fn drain_pending(&mut self) {
533        let Some(children) = self.children.clone() else {
534            return;
535        };
536        let Some(ctx) = self.context_ref.clone() else {
537            return;
538        };
539        while let Some(work) = self.pending.pop_front() {
540            match work {
541                WorkRequest::Boxed(f) => f(&ctx, &children.blas),
542                WorkRequest::Sgemm(req) => {
543                    children.blas.tell(BlasMsg::Sgemm(req));
544                }
545                WorkRequest::SnapshotContext { reply } => {
546                    // No way to ask a non-mailbox method to fetch the
547                    // current context here; the user can re-issue.
548                    let _ = reply.send(self.state.current_context());
549                }
550            }
551        }
552    }
553}
554
#[async_trait]
impl Actor for DeviceActor {
    type Msg = DeviceMsg;

    /// Spawn the `ContextActor` child (named `"ctx"`).
    ///
    /// The child receives the shared `DeviceState`, this actor's config,
    /// and a ref back to this actor, which it uses to report
    /// `ContextReady` / `ContextLost`.
    ///
    /// # Panics
    ///
    /// Panics if the child cannot be spawned.
    async fn pre_start(&mut self, ctx: &mut Context<Self>) {
        debug!(device_id = self.config.device_id, "DeviceActor pre_start");
        let parent_ref = ctx.self_ref().clone();
        let props = ContextActor::props(self.state.clone(), self.config.clone(), parent_ref);
        match ctx.spawn::<ContextActor>(props, "ctx") {
            Ok(r) => {
                self.context_ref = Some(r);
            }
            Err(e) => {
                // Spawn failure here is structural; surface via panic so
                // a user-installed root supervisor sees it.
                panic!("Unrecoverable: failed to spawn ContextActor: {e}");
            }
        }
    }

    /// Route one `DeviceMsg`.
    ///
    /// 1. Legacy deprecated variants are rewritten into their generic boxed
    ///    form.
    /// 2. Work that needs the context (alloc / copy / SGEMM) is forwarded
    ///    when ready and stashed via `enqueue_pending` otherwise.
    /// 3. Queries are answered (or, for `SnapshotStream`, forwarded)
    ///    immediately.
    /// 4. The internal `ContextReady` / `ContextLost` notifications update
    ///    `children`; on ready, the pending queue is replayed.
    async fn handle(&mut self, _ctx: &mut Context<Self>, msg: DeviceMsg) {
        // Phase 0.4: the alloc/copy fan-out collapses into 3 generic
        // arms. Legacy `Allocate*` / `CopyToHost*` / `CopyFromHost*`
        // variants are translated into the new boxed dispatchers
        // before forwarding, so the rest of the pipeline (stash /
        // drain / context handler) sees a single shape.
        #[allow(deprecated)]
        let msg = match msg {
            // -- legacy alloc → AllocReq<T> -------------------
            DeviceMsg::Allocate { len, reply } | DeviceMsg::AllocateF32 { len, reply } => {
                DeviceMsg::alloc::<f32>(len, reply)
            }
            DeviceMsg::AllocateF64 { len, reply } => DeviceMsg::alloc::<f64>(len, reply),
            DeviceMsg::AllocateI8 { len, reply } => DeviceMsg::alloc::<i8>(len, reply),
            DeviceMsg::AllocateI32 { len, reply } => DeviceMsg::alloc::<i32>(len, reply),
            DeviceMsg::AllocateI64 { len, reply } => DeviceMsg::alloc::<i64>(len, reply),
            DeviceMsg::AllocateU8 { len, reply } => DeviceMsg::alloc::<u8>(len, reply),
            DeviceMsg::AllocateU32 { len, reply } => DeviceMsg::alloc::<u32>(len, reply),
            DeviceMsg::AllocateU64 { len, reply } => DeviceMsg::alloc::<u64>(len, reply),
            #[cfg(feature = "f16")]
            DeviceMsg::AllocateF16 { len, reply } => DeviceMsg::alloc::<half::f16>(len, reply),
            #[cfg(feature = "f16")]
            DeviceMsg::AllocateBf16 { len, reply } => DeviceMsg::alloc::<half::bf16>(len, reply),
            // -- legacy copy_to_host → CopyToHostReq<T> -------
            DeviceMsg::CopyToHostF32 { src, dst, reply } => {
                DeviceMsg::copy_to_host::<f32>(src, dst, reply)
            }
            DeviceMsg::CopyToHostF64 { src, dst, reply } => {
                DeviceMsg::copy_to_host::<f64>(src, dst, reply)
            }
            DeviceMsg::CopyToHostI32 { src, dst, reply } => {
                DeviceMsg::copy_to_host::<i32>(src, dst, reply)
            }
            DeviceMsg::CopyToHostU32 { src, dst, reply } => {
                DeviceMsg::copy_to_host::<u32>(src, dst, reply)
            }
            DeviceMsg::CopyToHostU8 { src, dst, reply } => {
                DeviceMsg::copy_to_host::<u8>(src, dst, reply)
            }
            // -- legacy copy_from_host → CopyFromHostReq<T> ---
            DeviceMsg::CopyFromHostF32 { src, dst, reply } => {
                DeviceMsg::copy_from_host::<f32>(src, dst, reply)
            }
            DeviceMsg::CopyFromHostF64 { src, dst, reply } => {
                DeviceMsg::copy_from_host::<f64>(src, dst, reply)
            }
            DeviceMsg::CopyFromHostI32 { src, dst, reply } => {
                DeviceMsg::copy_from_host::<i32>(src, dst, reply)
            }
            DeviceMsg::CopyFromHostU32 { src, dst, reply } => {
                DeviceMsg::copy_from_host::<u32>(src, dst, reply)
            }
            DeviceMsg::CopyFromHostU8 { src, dst, reply } => {
                DeviceMsg::copy_from_host::<u8>(src, dst, reply)
            }
            // already-collapsed / non-alloc variants pass through
            other => other,
        };

        // Both refs are required: `ContextLost` clears `children` but keeps
        // `context_ref`, so `context_ref` alone does not mean the kernel
        // actors are live.
        let ready = self.context_ref.is_some() && self.children.is_some();

        match msg {
            // Phase 0.4: 3 arms for the generic forms.
            DeviceMsg::Alloc(boxed) => {
                if ready {
                    self.context_ref
                        .as_ref()
                        .unwrap()
                        .tell(ContextMsg::Alloc(boxed));
                } else {
                    self.enqueue_pending(WorkRequest::Boxed(Box::new(move |c, _b| {
                        c.tell(ContextMsg::Alloc(boxed))
                    })));
                }
            }
            DeviceMsg::CopyToHost(boxed) => {
                if ready {
                    self.context_ref
                        .as_ref()
                        .unwrap()
                        .tell(ContextMsg::CopyToHost(boxed));
                } else {
                    self.enqueue_pending(WorkRequest::Boxed(Box::new(move |c, _b| {
                        c.tell(ContextMsg::CopyToHost(boxed))
                    })));
                }
            }
            DeviceMsg::CopyFromHost(boxed) => {
                if ready {
                    self.context_ref
                        .as_ref()
                        .unwrap()
                        .tell(ContextMsg::CopyFromHost(boxed));
                } else {
                    self.enqueue_pending(WorkRequest::Boxed(Box::new(move |c, _b| {
                        c.tell(ContextMsg::CopyFromHost(boxed))
                    })));
                }
            }

            // Legacy variants are unreachable here: the upstream
            // translation stage (above) rewrote every one into the
            // generic form.
            #[allow(deprecated)]
            DeviceMsg::Allocate { .. }
            | DeviceMsg::AllocateF32 { .. }
            | DeviceMsg::AllocateF64 { .. }
            | DeviceMsg::AllocateI8 { .. }
            | DeviceMsg::AllocateI32 { .. }
            | DeviceMsg::AllocateI64 { .. }
            | DeviceMsg::AllocateU8 { .. }
            | DeviceMsg::AllocateU32 { .. }
            | DeviceMsg::AllocateU64 { .. }
            | DeviceMsg::CopyToHostF32 { .. }
            | DeviceMsg::CopyFromHostF32 { .. }
            | DeviceMsg::CopyToHostF64 { .. }
            | DeviceMsg::CopyFromHostF64 { .. }
            | DeviceMsg::CopyToHostI32 { .. }
            | DeviceMsg::CopyFromHostI32 { .. }
            | DeviceMsg::CopyToHostU32 { .. }
            | DeviceMsg::CopyFromHostU32 { .. }
            | DeviceMsg::CopyToHostU8 { .. }
            | DeviceMsg::CopyFromHostU8 { .. } => unreachable!(
                "Phase 0.4 translation collapses all legacy alloc/copy variants \
                 into DeviceMsg::Alloc / CopyToHost / CopyFromHost"
            ),
            #[cfg(feature = "f16")]
            #[allow(deprecated)]
            DeviceMsg::AllocateF16 { .. } | DeviceMsg::AllocateBf16 { .. } => {
                unreachable!(
                    "Phase 0.4 translation collapses all legacy alloc/copy variants \
                     into DeviceMsg::Alloc"
                )
            }

            // SGEMM goes straight to the BlasActor, so only `children`
            // needs to be present.
            DeviceMsg::Sgemm(req) => match &self.children {
                Some(c) => c.blas.tell(BlasMsg::Sgemm(req)),
                None => self.enqueue_pending(WorkRequest::Sgemm(req)),
            },

            DeviceMsg::SnapshotContext { reply } => {
                let _ = reply.send(self.state.current_context());
            }
            DeviceMsg::SnapshotStream { reply } => {
                // Forward to ContextActor — it owns the primary
                // `Arc<CudaStream>`. If the context isn't ready (mock
                // mode pre-ready or while a rebuild is in flight) the
                // ContextActor handler replies `None` to the same
                // oneshot. We don't stash this as a `WorkRequest`
                // because a `None` reply is correct in that case —
                // callers re-issue if they need the stream.
                if let Some(ctx) = self.context_ref.as_ref() {
                    ctx.tell(ContextMsg::SnapshotStream { reply });
                } else {
                    let _ = reply.send(None);
                }
            }
            DeviceMsg::SnapshotChildren { reply } => {
                let _ = reply.send(self.children.clone());
            }
            DeviceMsg::WatchGeneration { reply } => {
                let _ = reply.send(self.state.generation_watch());
            }
            DeviceMsg::Stats { reply } => {
                let _ = reply.send(self.snapshot_load());
            }

            DeviceMsg::ContextReady { children } => {
                debug!(device_id = self.config.device_id, "context ready");
                self.children = Some(children);
                // Now that the kernel actors are live, replay everything
                // stashed while the context was (re)building.
                self.drain_pending();
            }
            DeviceMsg::ContextLost => {
                debug!(device_id = self.config.device_id, "context lost");
                // `context_ref` is kept; clearing `children` makes
                // subsequent work stash until the next `ContextReady`.
                self.children = None;
            }
        }
    }

    /// Mark the shared device state as shutting down, then resolve every
    /// stashed request so no caller is left waiting.
    ///
    /// SGEMMs get `GpuError::GpuRefStale`, context snapshots get `None`,
    /// and boxed requests are dropped (closing their reply channels).
    async fn post_stop(&mut self, _ctx: &mut Context<Self>) {
        debug!(device_id = self.config.device_id, "DeviceActor post_stop");
        self.state.begin_shutdown();
        // Drain pending replies with stale errors so callers don't hang.
        while let Some(work) = self.pending.pop_front() {
            match work {
                WorkRequest::Boxed(_) => { /* reply drops with closure */ }
                WorkRequest::Sgemm(req) => {
                    let _ = req
                        .reply
                        .send(Err(GpuError::GpuRefStale("device shutting down")));
                }
                WorkRequest::SnapshotContext { reply } => {
                    let _ = reply.send(None);
                }
            }
        }
    }
}
773
774impl DeviceActor {
775    fn snapshot_load(&self) -> DeviceLoad {
776        DeviceLoad {
777            free_bytes: 0,
778            total_bytes: 0,
779            active_streams: 0,
780            queue_depth: self.pending.len() as u32,
781            compute_cap: (0, 0),
782        }
783    }
784}
785
#[cfg(test)]
#[allow(deprecated)] // exercised on purpose: legacy variants must keep routing.
mod tests {
    use super::*;
    use crate::dtype::DType;
    use atomr_config::Config;
    use atomr_core::actor::ActorSystem;
    use std::time::Duration;

    /// Type-keyed extras map used by the stand-in `KernelChildren` test.
    /// Aliased so the helper signatures stay readable (and clear of
    /// clippy's `type_complexity` lint).
    type ExtrasMap = Arc<RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>;

    /// Phase 0.8 — bit values are part of the on-the-wire surface
    /// (config files / persisted device specs). Lock them down so a
    /// future re-ordering of the bitflag declaration is caught.
    #[test]
    fn enabled_libraries_bit_values_are_stable() {
        assert_eq!(EnabledLibraries::BLAS.bits(), 1 << 0);
        assert_eq!(EnabledLibraries::CUDNN.bits(), 1 << 1);
        assert_eq!(EnabledLibraries::CUFFT.bits(), 1 << 2);
        assert_eq!(EnabledLibraries::CURAND.bits(), 1 << 3);
        assert_eq!(EnabledLibraries::CUSOLVER.bits(), 1 << 4);
        assert_eq!(EnabledLibraries::CUBLASLT.bits(), 1 << 5);
        assert_eq!(EnabledLibraries::NVRTC.bits(), 1 << 6);
        // Phase 0.8 additions.
        assert_eq!(EnabledLibraries::CUTENSOR.bits(), 1 << 7);
        assert_eq!(EnabledLibraries::CUSPARSE.bits(), 1 << 8);
        assert_eq!(EnabledLibraries::NCCL.bits(), 1 << 9);
        assert_eq!(EnabledLibraries::CUTLASS.bits(), 1 << 10);
        assert_eq!(EnabledLibraries::TENSORRT.bits(), 1 << 11);
        assert_eq!(EnabledLibraries::FLASHATTN.bits(), 1 << 12);
        assert_eq!(EnabledLibraries::CUB_THRUST.bits(), 1 << 13);
        assert_eq!(EnabledLibraries::TELEMETRY.bits(), 1 << 14);
    }

    /// A mixed set of old and Phase 0.8 flags survives a
    /// `bits()` → `from_bits()` round trip unchanged.
    #[test]
    fn enabled_libraries_round_trip_via_bits() {
        let original = EnabledLibraries::BLAS
            | EnabledLibraries::CUTENSOR
            | EnabledLibraries::FLASHATTN
            | EnabledLibraries::TELEMETRY;
        let bits = original.bits();
        let restored =
            EnabledLibraries::from_bits(bits).expect("known bits round-trip through from_bits");
        assert_eq!(original, restored);
        assert!(restored.contains(EnabledLibraries::FLASHATTN));
        assert!(!restored.contains(EnabledLibraries::CUDNN));
    }

    /// `EnabledLibraries::ALL` must include every individually declared flag.
    #[test]
    fn enabled_libraries_all_contains_every_phase_0_8_bit() {
        let all = EnabledLibraries::ALL;
        for bit in [
            EnabledLibraries::BLAS,
            EnabledLibraries::CUDNN,
            EnabledLibraries::CUFFT,
            EnabledLibraries::CURAND,
            EnabledLibraries::CUSOLVER,
            EnabledLibraries::CUBLASLT,
            EnabledLibraries::NVRTC,
            EnabledLibraries::CUTENSOR,
            EnabledLibraries::CUSPARSE,
            EnabledLibraries::NCCL,
            EnabledLibraries::CUTLASS,
            EnabledLibraries::TENSORRT,
            EnabledLibraries::FLASHATTN,
            EnabledLibraries::CUB_THRUST,
            EnabledLibraries::TELEMETRY,
        ] {
            assert!(all.contains(bit), "ALL missing {bit:?}");
        }
    }

    /// Phase 0.8 — stand-in for the `KernelChildren::register_extra` /
    /// `extra` round-trip. It operates on a bare extras map with dummy
    /// non-actor value types, because spawning a real actor here would pull
    /// in the full ActorSystem. The real API is exercised end-to-end by
    /// `kernel_children_extras_via_snapshot`.
    #[test]
    fn kernel_children_extras_register_and_retrieve_by_type() {
        // We can't easily construct an ActorRef<BlasMsg> outside an
        // ActorSystem, so no real KernelChildren is built; the test only
        // touches an extras map constructed directly.
        let extras: ExtrasMap = Arc::new(RwLock::new(HashMap::new()));

        // Stand-in helpers mirroring KernelChildren::register_extra / extra
        // semantics on the bare map; `extras.read().len()` plays the role of
        // extras_len. Keeps the test free of an ActorSystem dep.
        fn register<T: Any + Send + Sync>(map: &ExtrasMap, v: T) {
            map.write().insert(TypeId::of::<T>(), Arc::new(v));
        }
        fn lookup<T: Any + Send + Sync + Clone>(map: &ExtrasMap) -> Option<T> {
            map.read()
                .get(&TypeId::of::<T>())
                .and_then(|v| v.clone().downcast::<T>().ok())
                .map(|arc| (*arc).clone())
        }

        #[derive(Clone, PartialEq, Eq, Debug)]
        struct CutlassRef(u32);
        #[derive(Clone, PartialEq, Eq, Debug)]
        struct TensorRtRef(&'static str);

        register(&extras, CutlassRef(7));
        register(&extras, TensorRtRef("trt"));

        assert_eq!(lookup::<CutlassRef>(&extras), Some(CutlassRef(7)));
        assert_eq!(lookup::<TensorRtRef>(&extras), Some(TensorRtRef("trt")));
        // Unregistered type returns None.
        #[derive(Clone)]
        struct Unknown;
        assert!(lookup::<Unknown>(&extras).is_none());
        assert_eq!(extras.read().len(), 2);

        // Re-registering the same type overwrites.
        register(&extras, CutlassRef(99));
        assert_eq!(lookup::<CutlassRef>(&extras), Some(CutlassRef(99)));
        assert_eq!(extras.read().len(), 2);
    }

    /// End-to-end exercise of the actual `KernelChildren` API by going
    /// through a mock `DeviceActor` so we have a real `BlasActor` ref
    /// to seed the struct.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn kernel_children_extras_via_snapshot() {
        let sys = ActorSystem::create("kc_extras", Config::empty())
            .await
            .unwrap();
        let dev = sys
            .actor_of(DeviceActor::props(DeviceConfig::mock(0)), "dev0")
            .unwrap();

        // Wait for ContextReady by repeatedly probing SnapshotChildren
        // (up to 50 probes, sleeping 20 ms between unsuccessful ones).
        let mut snap: Option<KernelChildren> = None;
        for _ in 0..50 {
            let (tx, rx) = oneshot::channel();
            dev.tell(DeviceMsg::SnapshotChildren { reply: tx });
            if let Ok(Some(c)) = rx.await {
                snap = Some(c);
                break;
            }
            tokio::time::sleep(Duration::from_millis(20)).await;
        }
        let children = snap.expect("KernelChildren snapshot should arrive in mock mode");
        assert_eq!(children.extras_len(), 0);

        #[derive(Clone, Debug, PartialEq, Eq)]
        struct FakeCutlassRef(u64);
        children.register_extra(FakeCutlassRef(42));
        assert_eq!(children.extras_len(), 1);
        assert_eq!(children.extra::<FakeCutlassRef>(), Some(FakeCutlassRef(42)));
        // Clones share the same extras map (Arc<RwLock<…>> inside).
        let cloned = children.clone();
        assert_eq!(cloned.extras_len(), 1);
        assert_eq!(cloned.extra::<FakeCutlassRef>(), Some(FakeCutlassRef(42)));

        sys.terminate().await;
    }

    /// Smoke test — DeviceActor in mock mode must accept an Allocate
    /// request and always reply. In mock mode that reply is the synthetic
    /// `GpuError::Unrecoverable` documented in ContextActor::handle. This
    /// exercises the whole spawn / ContextReady / drain_pending plumbing
    /// without touching cudarc at runtime.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn pending_work_drains_on_context_ready() {
        let sys = ActorSystem::create("test", Config::empty()).await.unwrap();
        let dev = sys
            .actor_of(DeviceActor::props(DeviceConfig::mock(0)), "dev0")
            .unwrap();

        // Send Allocate immediately after spawning so it most likely lands
        // before ContextReady and has to be queued, then drained. The reply
        // gets a generous timeout.
        let (tx, rx) = oneshot::channel();
        dev.tell(DeviceMsg::Allocate { len: 16, reply: tx });
        let res = tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .expect("alloc reply should arrive within timeout")
            .expect("oneshot dropped");
        // In mock mode the allocation returns the synthetic error
        // documented in ContextActor::handle. Receiving it at all proves
        // the plumbing under test delivered the request and the reply.
        assert!(matches!(res, Err(GpuError::Unrecoverable(_))));

        sys.terminate().await;
    }

    /// Phase 0.4: the typed `DeviceMsg::alloc::<T>` constructor should
    /// build an `AllocReq<T>`-shaped boxed dispatcher and round-trip
    /// through the actor pipeline, replying with the same kind of
    /// error the legacy variant produces in mock mode.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn alloc_dispatch_via_typed_constructor() {
        let sys = ActorSystem::create("test", Config::empty()).await.unwrap();
        let dev = sys
            .actor_of(DeviceActor::props(DeviceConfig::mock(0)), "dev1")
            .unwrap();

        let (tx, rx) = oneshot::channel::<Result<GpuRef<f32>, GpuError>>();
        dev.tell(DeviceMsg::alloc::<f32>(64, tx));
        let res = tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .expect("alloc reply within timeout")
            .expect("oneshot dropped");
        assert!(matches!(res, Err(GpuError::Unrecoverable(_))));

        sys.terminate().await;
    }

    /// Phase 0.4: every `*Dispatch` trait carries a runtime dtype tag
    /// reflecting the concrete `T: CudaDtype`. We never go through the
    /// actor system here — boxing is enough to verify dispatch.
    #[test]
    fn alloc_dispatch_dtype_kind_correct() {
        // f32
        let (tx, _rx) = oneshot::channel::<Result<GpuRef<f32>, GpuError>>();
        let boxed: Box<dyn AllocDispatch> = Box::new(AllocReq::<f32> { len: 4, reply: tx });
        assert_eq!(boxed.dtype(), DType::F32);
        assert_eq!(boxed.len(), 4);

        // i32
        let (tx, _rx) = oneshot::channel::<Result<GpuRef<i32>, GpuError>>();
        let boxed: Box<dyn AllocDispatch> = Box::new(AllocReq::<i32> { len: 7, reply: tx });
        assert_eq!(boxed.dtype(), DType::I32);

        // u8
        let (tx, _rx) = oneshot::channel::<Result<GpuRef<u8>, GpuError>>();
        let boxed: Box<dyn AllocDispatch> = Box::new(AllocReq::<u8> { len: 1, reply: tx });
        assert_eq!(boxed.dtype(), DType::U8);
    }

    /// Phase 0.4: legacy `DeviceMsg::AllocateF32` constructor is
    /// `#[deprecated]` but still compiles and routes correctly into
    /// the new generic pipeline.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn deprecated_allocate_f32_still_works() {
        let sys = ActorSystem::create("test", Config::empty()).await.unwrap();
        let dev = sys
            .actor_of(DeviceActor::props(DeviceConfig::mock(0)), "dev2")
            .unwrap();

        let (tx, rx) = oneshot::channel::<Result<GpuRef<f32>, GpuError>>();
        // NOTE: explicitly using the deprecated variant. The
        // `#[allow(deprecated)]` on the mod silences the warning.
        dev.tell(DeviceMsg::AllocateF32 { len: 8, reply: tx });
        let res = tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .expect("alloc reply within timeout")
            .expect("oneshot dropped");
        assert!(matches!(res, Err(GpuError::Unrecoverable(_))));

        sys.terminate().await;
    }

    /// Phase 0.4: the `CopyToHostDispatch` trait carries a runtime
    /// dtype tag. We exercise this with a stub dispatcher (no real
    /// `GpuRef<T>` involved — that would require a live CudaContext)
    /// and confirm the boxed dtype reports `T::KIND`.
    #[test]
    fn copy_to_host_typed() {
        struct Stub<T: CudaDtype>(std::marker::PhantomData<T>);
        impl<T: CudaDtype> CopyToHostDispatch for Stub<T> {
            fn dtype(&self) -> DType {
                T::KIND
            }
            fn run(
                self: Box<Self>,
                _stream: Arc<cudarc::driver::CudaStream>,
                _completion: Arc<dyn crate::completion::CompletionStrategy>,
            ) {
                // never invoked in unit tests
            }
        }

        let boxed: Box<dyn CopyToHostDispatch> = Box::new(Stub::<f32>(std::marker::PhantomData));
        assert_eq!(boxed.dtype(), DType::F32);
        let boxed: Box<dyn CopyToHostDispatch> = Box::new(Stub::<i32>(std::marker::PhantomData));
        assert_eq!(boxed.dtype(), DType::I32);

        // Smoke: wrapping a stub in DeviceMsg::CopyToHost by hand (no typed
        // constructor involved) keeps its dtype tag reachable through the
        // variant.
        let msg = DeviceMsg::CopyToHost(Box::new(Stub::<u32>(std::marker::PhantomData)));
        match msg {
            DeviceMsg::CopyToHost(b) => assert_eq!(b.dtype(), DType::U32),
            _ => panic!("expected CopyToHost variant"),
        }
    }

    /// Phase 0.4: H2D mirror of `copy_to_host_typed`.
    #[test]
    fn copy_from_host_typed() {
        struct Stub<T: CudaDtype>(std::marker::PhantomData<T>);
        impl<T: CudaDtype> CopyFromHostDispatch for Stub<T> {
            fn dtype(&self) -> DType {
                T::KIND
            }
            fn run(
                self: Box<Self>,
                _stream: Arc<cudarc::driver::CudaStream>,
                _completion: Arc<dyn crate::completion::CompletionStrategy>,
            ) {
                // never invoked in unit tests
            }
        }

        let boxed: Box<dyn CopyFromHostDispatch> = Box::new(Stub::<u8>(std::marker::PhantomData));
        assert_eq!(boxed.dtype(), DType::U8);

        // Hand-built variant keeps the stub's dtype tag reachable.
        let msg = DeviceMsg::CopyFromHost(Box::new(Stub::<f64>(std::marker::PhantomData)));
        match msg {
            DeviceMsg::CopyFromHost(b) => assert_eq!(b.dtype(), DType::F64),
            _ => panic!("expected CopyFromHost variant"),
        }
    }
}