atomr_accel_cuda/kernel/tensor/
mod.rs1use std::sync::Arc;
17
18use async_trait::async_trait;
19use atomr_core::actor::{Actor, Context, Props};
20use cudarc::cutensor::result as ct_result;
21use cudarc::cutensor::sys as ct_sys;
22use parking_lot::Mutex;
23use tokio::sync::oneshot;
24
25use crate::completion::CompletionStrategy;
26use crate::device::DeviceState;
27use crate::error::GpuError;
28use crate::kernel::dispatch::{TensorDispatch, TensorDispatchCtx, WorkspacePool};
29use crate::stream::StreamAllocator;
30
31#[cfg(feature = "cutensor-autotune")]
32pub mod autotune;
33pub mod compute_desc;
34pub mod contract;
35pub mod elementwise;
36pub mod permute;
37pub mod plan_cache;
38pub mod reduce;
39
40pub use compute_desc::ComputeDesc;
41pub use contract::{ContractRequest, OperandSpec};
42pub use elementwise::{ElementwiseBinaryRequest, ElementwiseTrinaryRequest};
43pub use permute::PermutationRequest;
44pub use plan_cache::{PlanCache, PlanKey, DEFAULT_PLAN_CACHE_SIZE};
45pub use reduce::ReductionRequest;
46
47pub type TensorSpec = OperandSpec<f32>;
51
52pub struct SendHandle(pub ct_sys::cutensorHandle_t);
57unsafe impl Send for SendHandle {}
58unsafe impl Sync for SendHandle {}
59
60pub enum TensorMsg {
64 Op(Box<dyn TensorDispatch>),
66 #[deprecated(note = "use TensorMsg::Op(Box::new(ContractRequest::<f32>::new(...)))")]
70 Contract {
71 a: TensorSpec,
72 b: TensorSpec,
73 c: TensorSpec,
74 alpha: f32,
75 beta: f32,
76 reply: oneshot::Sender<Result<(), GpuError>>,
77 },
78}
79
80pub struct TensorActor {
81 inner: TensorInner,
82}
83
84#[allow(clippy::large_enum_variant)]
85enum TensorInner {
86 Real {
87 ctx: TensorDispatchCtx,
88 #[allow(dead_code)]
89 state: Arc<DeviceState>,
90 },
91 Mock,
92}
93
94impl Drop for TensorInner {
95 fn drop(&mut self) {
96 if let TensorInner::Real { ctx, .. } = self {
97 let h = ctx.handle.lock();
98 unsafe {
99 let _ = ct_result::destroy_handle(h.0);
100 }
101 }
102 }
103}
104
105impl TensorActor {
106 pub fn props(
107 stream: Arc<cudarc::driver::CudaStream>,
108 _allocator: Arc<dyn StreamAllocator>,
109 completion: Arc<dyn CompletionStrategy>,
110 state: Arc<DeviceState>,
111 ) -> Props<Self> {
112 Props::create(move || {
113 let h = match ct_result::create_handle() {
114 Ok(h) => h,
115 Err(e) => panic!("ContextPoisoned: cutensorCreate failed: {e}"),
116 };
117 let ctx = TensorDispatchCtx {
118 handle: Arc::new(Mutex::new(SendHandle(h))),
119 stream: stream.clone(),
120 completion: completion.clone(),
121 plan_cache: Arc::new(PlanCache::with_default_capacity()),
122 workspace: Arc::new(WorkspacePool::new(stream.clone())),
123 };
124 TensorActor {
125 inner: TensorInner::Real {
126 ctx,
127 state: state.clone(),
128 },
129 }
130 })
131 }
132
133 pub fn mock_props() -> Props<Self> {
134 Props::create(|| TensorActor {
135 inner: TensorInner::Mock,
136 })
137 }
138}
139
140#[async_trait]
141impl Actor for TensorActor {
142 type Msg = TensorMsg;
143
144 async fn handle(&mut self, _ctx: &mut Context<Self>, msg: TensorMsg) {
145 match &self.inner {
146 TensorInner::Mock => mock_reply(msg),
147 TensorInner::Real { ctx, .. } => match msg {
148 TensorMsg::Op(req) => req.dispatch(ctx),
149 #[allow(deprecated)]
150 TensorMsg::Contract {
151 a,
152 b,
153 c,
154 alpha,
155 beta,
156 reply,
157 } => {
158 let req = ContractRequest::<f32>::new(a, b, c, alpha, beta, reply);
159 Box::new(req).dispatch(ctx);
160 }
161 },
162 }
163 }
164}
165
166fn mock_reply(msg: TensorMsg) {
167 match msg {
168 TensorMsg::Op(req) => req.fail_mock(),
169 #[allow(deprecated)]
170 TensorMsg::Contract { reply, .. } => {
171 let _ = reply.send(Err(GpuError::Unrecoverable(
172 "TensorActor in mock mode".into(),
173 )));
174 }
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    /// The mock path must answer an `Op` request with an error, and the
    /// deprecated `Contract` variant must still be constructible.
    #[test]
    fn deprecated_contract_alias_still_constructs() {
        let (tx_op, rx_op) = oneshot::channel();
        let mock = MockReq { reply: Some(tx_op) };
        let msg_op = TensorMsg::Op(Box::new(mock));
        mock_reply(msg_op);
        let res = rx_op
            .blocking_recv()
            .expect("Op mock_reply must send a result");
        assert!(matches!(res, Err(GpuError::Unrecoverable(_))));

        legacy_contract_mock_path();
    }

    /// Compile-time check that `TensorMsg::Contract` can still be built from
    /// its six fields. The constructor closure is only type-checked, never
    /// called, so no operand values are needed.
    #[allow(deprecated)]
    fn legacy_contract_mock_path() {
        let _build: fn(
            TensorSpec,
            TensorSpec,
            TensorSpec,
            f32,
            f32,
            oneshot::Sender<Result<(), GpuError>>,
        ) -> TensorMsg = |a, b, c, alpha, beta, reply| TensorMsg::Contract {
            a,
            b,
            c,
            alpha,
            beta,
            reply,
        };
    }

    /// Minimal `TensorDispatch` implementation used to observe `fail_mock`.
    struct MockReq {
        /// Reply channel; an `Option` so `fail_mock` can take it out of the box.
        reply: Option<oneshot::Sender<Result<(), GpuError>>>,
    }

    impl TensorDispatch for MockReq {
        fn op_tag(&self) -> &'static str {
            "mock"
        }

        fn dtype_tag(&self) -> &'static str {
            "mock"
        }

        // Never exercised: the test only drives the mock path.
        fn dispatch(self: Box<Self>, _ctx: &TensorDispatchCtx) {}

        fn fail_mock(mut self: Box<Self>) {
            if let Some(tx) = self.reply.take() {
                let _ = tx.send(Err(GpuError::Unrecoverable(
                    "TensorActor in mock mode".into(),
                )));
            }
        }
    }
}