Skip to main content

atomr_accel_cuda/device/
alloc_dispatch.rs

1//! Boxed-dispatch payloads for the Phase 0.4 generic alloc/copy
2//! variants of [`DeviceMsg`] and [`ContextMsg`].
3//!
4//! The previous design carried one enum variant per dtype (`AllocateF32`,
5//! `AllocateF64`, `AllocateI8`, …) which scaled poorly: each new dtype
6//! doubled the alloc surface and tripled the copy surface. This module
7//! replaces that fan-out with three boxed-trait-object dispatchers:
8//!
9//! - [`AllocDispatch`] — typed buffer allocation
10//! - [`CopyToHostDispatch`] — D2H async copy
11//! - [`CopyFromHostDispatch`] — H2D async copy
12//!
13//! Concrete request structs (`AllocReq<T>`, `CopyToHostReq<T>`,
14//! `CopyFromHostReq<T>`) implement the matching trait and ride inside a
15//! single `Box<dyn …>`. The DeviceActor's `handle` arm forwards them to
16//! the ContextActor verbatim — the typed `T: CudaDtype` parameter is
17//! preserved through the box, so `GpuRef<T>` keeps its static dtype on
18//! the receiving side.
19//!
20//! See `device::device_actor` for the legacy `#[deprecated]` enum
21//! variants kept for back-compat, and `device::context_actor` for the
22//! receiving side that calls `.run(...)` on the boxed dispatcher.
23
24use std::sync::Arc;
25
26use cudarc::driver::CudaStream;
27use tokio::sync::oneshot;
28
29use crate::completion::CompletionStrategy;
30use crate::dtype::{CudaDtype, DType};
31use crate::error::GpuError;
32use crate::gpu_ref::GpuRef;
33
34use super::alloc_msg::HostBuf;
35use super::state::DeviceState;
36
37/// Trait object the DeviceActor stashes inside a single
38/// [`crate::device::DeviceMsg::Alloc`] variant. The actual `T` is
39/// erased at the box boundary; the receiving ContextActor calls
40/// [`AllocDispatch::run`] which downcasts back into the concrete
41/// `AllocReq<T>` and performs a typed `stream.alloc_zeros::<T>(len)`.
42pub trait AllocDispatch: Send + 'static {
43    /// Concrete dtype carried by this dispatcher.
44    fn dtype(&self) -> DType;
45
46    /// Element count being allocated.
47    fn len(&self) -> usize;
48
49    /// Execute the allocation against the given context's primary
50    /// stream and reply on the embedded `oneshot` channel.
51    ///
52    /// `mock_mode == true` makes this match the legacy mock-mode
53    /// behaviour of [`super::context_actor::ContextActor::alloc`] —
54    /// it always replies with `GpuError::Unrecoverable("alloc not
55    /// supported in mock mode")` so existing tests are unaffected.
56    fn run(
57        self: Box<Self>,
58        stream: Option<&Arc<CudaStream>>,
59        state: &Arc<DeviceState>,
60        mock_mode: bool,
61    );
62}
63
64/// Concrete typed allocation request. Held inside a
65/// `Box<dyn AllocDispatch>` while in flight.
66pub struct AllocReq<T: CudaDtype> {
67    pub len: usize,
68    pub reply: oneshot::Sender<Result<GpuRef<T>, GpuError>>,
69}
70
71impl<T: CudaDtype> AllocDispatch for AllocReq<T> {
72    fn dtype(&self) -> DType {
73        T::KIND
74    }
75
76    fn len(&self) -> usize {
77        self.len
78    }
79
80    fn run(
81        self: Box<Self>,
82        stream: Option<&Arc<CudaStream>>,
83        state: &Arc<DeviceState>,
84        mock_mode: bool,
85    ) {
86        let AllocReq { len, reply } = *self;
87        if mock_mode {
88            let _ = reply.send(Err(GpuError::Unrecoverable(
89                "alloc not supported in mock mode".into(),
90            )));
91            return;
92        }
93        let Some(stream) = stream else {
94            let _ = reply.send(Err(GpuError::GpuRefStale("context not ready")));
95            return;
96        };
97        match stream.alloc_zeros::<T>(len) {
98            Ok(slice) => {
99                let _ = reply.send(Ok(GpuRef::<T>::new(Arc::new(slice), state)));
100            }
101            Err(e) => {
102                let _ = reply.send(Err(GpuError::OutOfMemory(format!("alloc {len}: {e}"))));
103            }
104        }
105    }
106}
107
108/// Trait object behind [`crate::device::DeviceMsg::CopyToHost`].
109/// Carries the typed source `GpuRef<T>` plus host destination buffer.
110pub trait CopyToHostDispatch: Send + 'static {
111    fn dtype(&self) -> DType;
112    fn run(self: Box<Self>, stream: Arc<CudaStream>, completion: Arc<dyn CompletionStrategy>);
113}
114
115pub struct CopyToHostReq<T: CudaDtype> {
116    pub src: GpuRef<T>,
117    pub dst: HostBuf<T>,
118    pub reply: oneshot::Sender<Result<HostBuf<T>, GpuError>>,
119}
120
121impl<T: CudaDtype> CopyToHostDispatch for CopyToHostReq<T> {
122    fn dtype(&self) -> DType {
123        T::KIND
124    }
125
126    fn run(self: Box<Self>, stream: Arc<CudaStream>, completion: Arc<dyn CompletionStrategy>) {
127        let CopyToHostReq { src, dst, reply } = *self;
128        super::context_actor::run_copy_to_host(src, dst, stream, completion, reply);
129    }
130}
131
132/// Trait object behind [`crate::device::DeviceMsg::CopyFromHost`].
133pub trait CopyFromHostDispatch: Send + 'static {
134    fn dtype(&self) -> DType;
135    fn run(self: Box<Self>, stream: Arc<CudaStream>, completion: Arc<dyn CompletionStrategy>);
136}
137
138pub struct CopyFromHostReq<T: CudaDtype> {
139    pub src: HostBuf<T>,
140    pub dst: GpuRef<T>,
141    pub reply: oneshot::Sender<Result<HostBuf<T>, GpuError>>,
142}
143
144impl<T: CudaDtype> CopyFromHostDispatch for CopyFromHostReq<T> {
145    fn dtype(&self) -> DType {
146        T::KIND
147    }
148
149    fn run(self: Box<Self>, stream: Arc<CudaStream>, completion: Arc<dyn CompletionStrategy>) {
150        let CopyFromHostReq { src, dst, reply } = *self;
151        super::context_actor::run_copy_from_host(src, dst, stream, completion, reply);
152    }
153}