atomr_accel_cuda/device/
alloc_dispatch.rs1use std::sync::Arc;
25
26use cudarc::driver::CudaStream;
27use tokio::sync::oneshot;
28
29use crate::completion::CompletionStrategy;
30use crate::dtype::{CudaDtype, DType};
31use crate::error::GpuError;
32use crate::gpu_ref::GpuRef;
33
34use super::alloc_msg::HostBuf;
35use super::state::DeviceState;
36
37pub trait AllocDispatch: Send + 'static {
43 fn dtype(&self) -> DType;
45
46 fn len(&self) -> usize;
48
49 fn run(
57 self: Box<Self>,
58 stream: Option<&Arc<CudaStream>>,
59 state: &Arc<DeviceState>,
60 mock_mode: bool,
61 );
62}
63
64pub struct AllocReq<T: CudaDtype> {
67 pub len: usize,
68 pub reply: oneshot::Sender<Result<GpuRef<T>, GpuError>>,
69}
70
71impl<T: CudaDtype> AllocDispatch for AllocReq<T> {
72 fn dtype(&self) -> DType {
73 T::KIND
74 }
75
76 fn len(&self) -> usize {
77 self.len
78 }
79
80 fn run(
81 self: Box<Self>,
82 stream: Option<&Arc<CudaStream>>,
83 state: &Arc<DeviceState>,
84 mock_mode: bool,
85 ) {
86 let AllocReq { len, reply } = *self;
87 if mock_mode {
88 let _ = reply.send(Err(GpuError::Unrecoverable(
89 "alloc not supported in mock mode".into(),
90 )));
91 return;
92 }
93 let Some(stream) = stream else {
94 let _ = reply.send(Err(GpuError::GpuRefStale("context not ready")));
95 return;
96 };
97 match stream.alloc_zeros::<T>(len) {
98 Ok(slice) => {
99 let _ = reply.send(Ok(GpuRef::<T>::new(Arc::new(slice), state)));
100 }
101 Err(e) => {
102 let _ = reply.send(Err(GpuError::OutOfMemory(format!("alloc {len}: {e}"))));
103 }
104 }
105 }
106}
107
108pub trait CopyToHostDispatch: Send + 'static {
111 fn dtype(&self) -> DType;
112 fn run(self: Box<Self>, stream: Arc<CudaStream>, completion: Arc<dyn CompletionStrategy>);
113}
114
115pub struct CopyToHostReq<T: CudaDtype> {
116 pub src: GpuRef<T>,
117 pub dst: HostBuf<T>,
118 pub reply: oneshot::Sender<Result<HostBuf<T>, GpuError>>,
119}
120
121impl<T: CudaDtype> CopyToHostDispatch for CopyToHostReq<T> {
122 fn dtype(&self) -> DType {
123 T::KIND
124 }
125
126 fn run(self: Box<Self>, stream: Arc<CudaStream>, completion: Arc<dyn CompletionStrategy>) {
127 let CopyToHostReq { src, dst, reply } = *self;
128 super::context_actor::run_copy_to_host(src, dst, stream, completion, reply);
129 }
130}
131
132pub trait CopyFromHostDispatch: Send + 'static {
134 fn dtype(&self) -> DType;
135 fn run(self: Box<Self>, stream: Arc<CudaStream>, completion: Arc<dyn CompletionStrategy>);
136}
137
138pub struct CopyFromHostReq<T: CudaDtype> {
139 pub src: HostBuf<T>,
140 pub dst: GpuRef<T>,
141 pub reply: oneshot::Sender<Result<HostBuf<T>, GpuError>>,
142}
143
144impl<T: CudaDtype> CopyFromHostDispatch for CopyFromHostReq<T> {
145 fn dtype(&self) -> DType {
146 T::KIND
147 }
148
149 fn run(self: Box<Self>, stream: Arc<CudaStream>, completion: Arc<dyn CompletionStrategy>) {
150 let CopyFromHostReq { src, dst, reply } = *self;
151 super::context_actor::run_copy_from_host(src, dst, stream, completion, reply);
152 }
153}