atomr_accel_cuda/kernel/cudnn/
mod.rs1#![allow(dead_code)]
49
50pub mod activation;
51pub mod attention;
52pub mod conv;
53pub mod graph;
54pub mod norm;
55pub mod pool;
56pub mod rnn;
57
58use std::sync::Arc;
59
60use async_trait::async_trait;
61use atomr_core::actor::{Actor, Context, Props};
62use cudarc::cudnn::Cudnn;
63use cudarc::driver::CudaSlice;
64use parking_lot::Mutex;
65use tokio::sync::oneshot;
66
67use crate::completion::CompletionStrategy;
68use crate::device::DeviceState;
69use crate::error::GpuError;
70use crate::gpu_ref::GpuRef;
71use crate::kernel::dispatch::{CudnnDispatch, CudnnDispatchCtx};
72use crate::stream::StreamAllocator;
73
74pub use activation::{
75 ActivationFwdRequest, ActivationKind, DropoutFwdRequest, LrnFwdRequest, LrnParams,
76 SoftmaxFwdRequest, SoftmaxMode,
77};
78pub use attention::{
79 AttentionMask, AttentionParams, MultiHeadAttnBwdRequest, MultiHeadAttnFwdRequest,
80};
81pub use conv::{
82 ConvBwdDataRequest, ConvBwdFilterRequest, ConvDescParams, ConvFwdRequest, EpilogueKind,
83};
84pub use graph::{
85 cache_key, CachedPlan, DtypeTag, NormMode, NormPhase, OpSpec, OperationGraphSpec, PlanCache,
86 PlanCacheKey, PointwiseMode, PoolKind, ReduceOp, TensorLayout, TensorSpec,
87 DEFAULT_PLAN_CACHE_SIZE,
88};
89pub use norm::{
90 BatchNormRequest, GroupNormRequest, InstanceNormRequest, LayerNormRequest, NormBwdRequest,
91};
92pub use pool::{PoolBwdRequest, PoolFwdRequest, PoolMode, PoolParams};
93pub use rnn::{RnnBwdRequest, RnnDirection, RnnFwdRequest, RnnMode, RnnParams};
94
/// Library tag stored in the `lib` field of every `GpuError::LibraryError`
/// built in this file.
const LIB: &str = "cudnn";
96
/// Geometry of a legacy 2-D convolution, carried by `ConvForwardRequest`.
///
/// Every array holds one value per spatial dimension (two of them).
/// NOTE(review): the element order (H/W vs W/H) is not established in this
/// file — confirm against the code that consumes these values.
#[derive(Debug, Clone, Copy)]
pub struct ConvParams {
    /// Padding per spatial dimension.
    pub pad: [i32; 2],
    /// Stride per spatial dimension.
    pub stride: [i32; 2],
    /// Dilation per spatial dimension.
    pub dilation: [i32; 2],
}
109
110pub struct ConvForwardRequest {
115 pub x: GpuRef<f32>,
116 pub x_dims: [i32; 4],
117 pub w: GpuRef<f32>,
118 pub w_dims: [i32; 4],
119 pub y: GpuRef<f32>,
120 pub y_dims: [i32; 4],
121 pub conv: ConvParams,
122 pub alpha: f32,
123 pub beta: f32,
124 pub reply: oneshot::Sender<Result<(), GpuError>>,
125}
126
127pub struct ActivationRequest {
129 pub kind: ActivationKind,
130 pub x: GpuRef<f32>,
131 pub y: GpuRef<f32>,
132 pub dims: [i32; 4],
133 pub alpha: f32,
134 pub beta: f32,
135 pub reply: oneshot::Sender<Result<(), GpuError>>,
136}
137
138pub struct SoftmaxRequest {
140 pub x: GpuRef<f32>,
141 pub y: GpuRef<f32>,
142 pub dims: [i32; 4],
143 pub alpha: f32,
144 pub beta: f32,
145 pub reply: oneshot::Sender<Result<(), GpuError>>,
146}
147
148pub enum CudnnMsg {
158 Op(Box<dyn CudnnDispatch>),
162
163 #[deprecated(note = "use CudnnMsg::Op with ConvFwdRequest<f32>")]
165 ConvForward(Box<ConvForwardRequest>),
166
167 #[deprecated(note = "use CudnnMsg::Op with ActivationFwdRequest<f32>")]
169 Activation(Box<ActivationRequest>),
170
171 #[deprecated(note = "use CudnnMsg::Op with SoftmaxFwdRequest<f32>")]
173 Softmax(Box<SoftmaxRequest>),
174}
175
176pub struct CudnnActor {
177 inner: CudnnInner,
178}
179
180struct SendCudnn(Arc<Cudnn>);
181unsafe impl Send for SendCudnn {}
182unsafe impl Sync for SendCudnn {}
183
184enum CudnnInner {
185 Real {
186 handle: SendCudnn,
187 stream: Arc<cudarc::driver::CudaStream>,
188 completion: Arc<dyn CompletionStrategy>,
189 plan_cache: Mutex<PlanCache>,
190 workspace: Mutex<Option<CudaSlice<u8>>>,
191 #[allow(dead_code)]
192 state: Arc<DeviceState>,
193 },
194 Mock,
195}
196
197impl CudnnActor {
198 pub fn props(
199 stream: Arc<cudarc::driver::CudaStream>,
200 _allocator: Arc<dyn StreamAllocator>,
201 completion: Arc<dyn CompletionStrategy>,
202 state: Arc<DeviceState>,
203 ) -> Props<Self> {
204 Props::create(move || {
205 let handle = match Cudnn::new(stream.clone()) {
206 Ok(h) => h,
207 Err(e) => panic!("ContextPoisoned: Cudnn::new failed: {e}"),
208 };
209 CudnnActor {
210 inner: CudnnInner::Real {
211 handle: SendCudnn(handle),
212 stream: stream.clone(),
213 completion: completion.clone(),
214 plan_cache: Mutex::new(PlanCache::new(DEFAULT_PLAN_CACHE_SIZE)),
215 workspace: Mutex::new(None),
216 state: state.clone(),
217 },
218 }
219 })
220 }
221
222 pub fn mock_props() -> Props<Self> {
223 Props::create(|| CudnnActor {
224 inner: CudnnInner::Mock,
225 })
226 }
227}
228
229#[async_trait]
230impl Actor for CudnnActor {
231 type Msg = CudnnMsg;
232
233 async fn handle(&mut self, _ctx: &mut Context<Self>, msg: CudnnMsg) {
234 match &self.inner {
235 CudnnInner::Mock => reply_mock(msg),
236 CudnnInner::Real {
237 handle,
238 stream,
239 completion,
240 plan_cache,
241 workspace,
242 ..
243 } => match msg {
244 CudnnMsg::Op(op) => {
245 let ctx = CudnnDispatchCtx {
246 handle: handle.0.clone(),
247 stream: stream.clone(),
248 completion: completion.clone(),
249 plan_cache,
250 workspace,
251 };
252 op.dispatch(&ctx);
253 }
254 #[allow(deprecated)]
255 CudnnMsg::ConvForward(req) => {
256 handle_legacy_conv_fwd(*req);
257 }
258 #[allow(deprecated)]
259 CudnnMsg::Activation(req) => {
260 handle_legacy_activation(*req);
261 }
262 #[allow(deprecated)]
263 CudnnMsg::Softmax(req) => {
264 handle_legacy_softmax(*req);
265 }
266 },
267 }
268 }
269}
270
271fn reply_mock(msg: CudnnMsg) {
272 let err = || GpuError::Unrecoverable("CudnnActor in mock mode".into());
273 match msg {
274 CudnnMsg::Op(_) => {
275 }
282 #[allow(deprecated)]
283 CudnnMsg::ConvForward(r) => {
284 let _ = r.reply.send(Err(err()));
285 }
286 #[allow(deprecated)]
287 CudnnMsg::Activation(r) => {
288 let _ = r.reply.send(Err(err()));
289 }
290 #[allow(deprecated)]
291 CudnnMsg::Softmax(r) => {
292 let _ = r.reply.send(Err(err()));
293 }
294 }
295}
296
297#[allow(deprecated)]
298fn handle_legacy_conv_fwd(req: ConvForwardRequest) {
299 let _ = req.reply.send(Err(GpuError::LibraryError {
304 lib: LIB,
305 msg: "ConvForward (legacy) is deprecated; send CudnnMsg::Op(ConvFwdRequest<f32>) \
306 for v9 frontend dispatch"
307 .to_string(),
308 }));
309}
310
311#[allow(deprecated)]
312fn handle_legacy_activation(req: ActivationRequest) {
313 let _ = req.reply.send(Err(GpuError::LibraryError {
314 lib: LIB,
315 msg: "Activation (legacy) is deprecated; send CudnnMsg::Op(ActivationFwdRequest<f32>)"
316 .to_string(),
317 }));
318}
319
320#[allow(deprecated)]
321fn handle_legacy_softmax(req: SoftmaxRequest) {
322 let _ = req.reply.send(Err(GpuError::LibraryError {
323 lib: LIB,
324 msg: "Softmax (legacy) is deprecated; send CudnnMsg::Op(SoftmaxFwdRequest<f32>)"
325 .to_string(),
326 }));
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal `CudnnDispatch` that replies `Ok(())` when, and only when, it is dispatched.
    struct Probe(oneshot::Sender<Result<(), GpuError>>);

    impl CudnnDispatch for Probe {
        fn dtype_name(&self) -> &'static str {
            "f32"
        }
        fn op_kind(&self) -> &'static str {
            "probe"
        }
        fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
            let _ = self.0.send(Ok(()));
        }
    }

    /// The legacy parameter type still constructs, and `CudnnMsg` still accepts ops.
    #[test]
    #[allow(deprecated)]
    fn deprecated_conv_forward_alias_still_constructs() {
        fn _accepts_legacy(_: &CudnnMsg) {}

        let p = ConvParams {
            pad: [0, 0],
            stride: [1, 1],
            dilation: [1, 1],
        };
        assert_eq!(p.pad, [0, 0]);
        assert_eq!(p.stride, [1, 1]);
        assert_eq!(p.dilation, [1, 1]);

        let (tx, _rx) = oneshot::channel();
        let msg = CudnnMsg::Op(Box::new(Probe(tx)));
        _accepts_legacy(&msg);
    }

    /// In mock mode an `Op` is dropped undispatched. Its reply channel therefore
    /// closes without a value instead of receiving the `Ok(())` that
    /// `Probe::dispatch` would send.
    #[test]
    fn mock_mode_drops_op_without_dispatching() {
        let (tx, mut rx) = oneshot::channel();
        reply_mock(CudnnMsg::Op(Box::new(Probe(tx))));
        assert!(matches!(
            rx.try_recv(),
            Err(oneshot::error::TryRecvError::Closed)
        ));
    }

    #[test]
    fn cudnn_dispatch_is_object_safe() {
        fn _accept(_: Box<dyn CudnnDispatch>) {}
        // Coercing a concrete op to `Box<dyn CudnnDispatch>` proves the trait is object-safe.
        let (tx, _rx) = oneshot::channel();
        _accept(Box::new(Probe(tx)));
    }

    #[test]
    fn plan_cache_default_size_matches_constant() {
        let pc = PlanCache::default();
        assert_eq!(pc.cap(), DEFAULT_PLAN_CACHE_SIZE);
    }
}