Skip to main content

atomr_accel_cuda/kernel/cudnn/
mod.rs

1//! `CudnnActor` — Phase 2 cuDNN slice. Wraps a [`cudarc::cudnn::Cudnn`]
2//! handle and exposes the v9 frontend graph API plus legacy
3//! `ConvForward` / `Activation` / `Softmax` shims for back-compat.
4//!
5//! # Module layout
6//!
7//! ```text
8//! kernel/cudnn/
9//! ├── mod.rs           — CudnnActor, CudnnMsg, CudnnInner, props
10//! ├── graph.rs         — v9 frontend graph spec (TensorSpec, OpSpec,
11//! │                     OperationGraphSpec) + plan cache
12//! ├── conv.rs          — ConvFwdRequest<T>, ConvBwdDataRequest<T>,
13//! │                     ConvBwdFilterRequest<T>
14//! ├── norm.rs          — BatchNormRequest<T>, LayerNormRequest<T>,
15//! │                     InstanceNormRequest<T>, GroupNormRequest<T>,
16//! │                     NormBwdRequest<T>
17//! ├── pool.rs          — PoolFwdRequest<T>, PoolBwdRequest<T>
18//! ├── attention.rs     — MultiHeadAttnFwdRequest<T>,
19//! │                     MultiHeadAttnBwdRequest<T>
20//! ├── rnn.rs           — RnnFwdRequest<T>, RnnBwdRequest<T>
21//! └── activation.rs    — ActivationFwdRequest<T>, SoftmaxFwdRequest<T>,
22//!                       DropoutFwdRequest<T>, LrnFwdRequest<T>
23//! ```
24//!
25//! # Op coverage
26//!
27//! | Family       | Forward | Backward | Notes                                         |
28//! |--------------|:-------:|:--------:|-----------------------------------------------|
29//! | Conv         | ✓       | ✓ data + filter | 1D/2D/3D, NCHW + NHWC, groups, dilation |
30//! | Pool         | ✓       | ✓        | max, avg, avg-exclude-padding                 |
31//! | BatchNorm    | ✓       | ✓        | training, inference, persistent               |
32//! | LayerNorm    | ✓       | ✓        | training, inference                           |
33//! | InstanceNorm | ✓       | ✓        |                                               |
34//! | GroupNorm    | ✓       | ✓        |                                               |
35//! | Activation   | ✓       | (fused with conv epilogue) | relu, sigmoid, tanh, gelu, gelu_approx, swish, elu, softplus, identity |
36//! | Softmax      | ✓       | (planned) | instance + channel mode                      |
37//! | Dropout      | ✓       | (planned) |                                              |
38//! | LRN          | ✓       | (planned) |                                              |
39//! | Attention    | ✓       | ✓        | causal, sliding-window, GQA/MQA, dropout     |
40//! | RNN/LSTM/GRU | ✓       | ✓        | uni + bi, multi-layer, dropout              |
41//!
42//! # Dtype matrix
43//!
44//! Every request type is generic over `T: crate::dtype::CudnnSupported`.
45//! Implementations cover `f32`, `f64`, `i8`, plus `half::f16` and
46//! `half::bf16` under the `f16` feature.
47
48#![allow(dead_code)]
49
50pub mod activation;
51pub mod attention;
52pub mod conv;
53pub mod graph;
54pub mod norm;
55pub mod pool;
56pub mod rnn;
57
58use std::sync::Arc;
59
60use async_trait::async_trait;
61use atomr_core::actor::{Actor, Context, Props};
62use cudarc::cudnn::Cudnn;
63use cudarc::driver::CudaSlice;
64use parking_lot::Mutex;
65use tokio::sync::oneshot;
66
67use crate::completion::CompletionStrategy;
68use crate::device::DeviceState;
69use crate::error::GpuError;
70use crate::gpu_ref::GpuRef;
71use crate::kernel::dispatch::{CudnnDispatch, CudnnDispatchCtx};
72use crate::stream::StreamAllocator;
73
74pub use activation::{
75    ActivationFwdRequest, ActivationKind, DropoutFwdRequest, LrnFwdRequest, LrnParams,
76    SoftmaxFwdRequest, SoftmaxMode,
77};
78pub use attention::{
79    AttentionMask, AttentionParams, MultiHeadAttnBwdRequest, MultiHeadAttnFwdRequest,
80};
81pub use conv::{
82    ConvBwdDataRequest, ConvBwdFilterRequest, ConvDescParams, ConvFwdRequest, EpilogueKind,
83};
84pub use graph::{
85    cache_key, CachedPlan, DtypeTag, NormMode, NormPhase, OpSpec, OperationGraphSpec, PlanCache,
86    PlanCacheKey, PointwiseMode, PoolKind, ReduceOp, TensorLayout, TensorSpec,
87    DEFAULT_PLAN_CACHE_SIZE,
88};
89pub use norm::{
90    BatchNormRequest, GroupNormRequest, InstanceNormRequest, LayerNormRequest, NormBwdRequest,
91};
92pub use pool::{PoolBwdRequest, PoolFwdRequest, PoolMode, PoolParams};
93pub use rnn::{RnnBwdRequest, RnnDirection, RnnFwdRequest, RnnMode, RnnParams};
94
/// Library tag stored in the `lib` field of every `GpuError::LibraryError`
/// this module reports (see the `handle_legacy_*` helpers).
const LIB: &str = "cudnn";
96
97// ----- Legacy back-compat parameter / request types -------------------
98
/// Convolution parameters (cuDNN 2D conv subset).
///
/// **Deprecated** — kept for back-compat with the F2 ConvForward API.
/// New code should construct [`ConvDescParams`] directly.
///
/// Each field holds one value per spatial dimension of a 2D convolution
/// (presumably `[height, width]` order, matching the NCHW legacy request —
/// TODO confirm). The type is a plain value, so it derives equality and
/// hashing in addition to `Copy`, letting callers compare parameter sets or
/// use them as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConvParams {
    /// Zero-padding applied to each spatial dimension.
    pub pad: [i32; 2],
    /// Filter stride along each spatial dimension.
    pub stride: [i32; 2],
    /// Filter dilation along each spatial dimension.
    pub dilation: [i32; 2],
}
109
110/// Legacy F2 ConvForward request (NCHW, f32 only).
111///
112/// **Deprecated** — use [`ConvFwdRequest<T>`] under
113/// [`CudnnMsg::Op`] for new code.
114pub struct ConvForwardRequest {
115    pub x: GpuRef<f32>,
116    pub x_dims: [i32; 4],
117    pub w: GpuRef<f32>,
118    pub w_dims: [i32; 4],
119    pub y: GpuRef<f32>,
120    pub y_dims: [i32; 4],
121    pub conv: ConvParams,
122    pub alpha: f32,
123    pub beta: f32,
124    pub reply: oneshot::Sender<Result<(), GpuError>>,
125}
126
127/// Legacy F2 Activation request.
128pub struct ActivationRequest {
129    pub kind: ActivationKind,
130    pub x: GpuRef<f32>,
131    pub y: GpuRef<f32>,
132    pub dims: [i32; 4],
133    pub alpha: f32,
134    pub beta: f32,
135    pub reply: oneshot::Sender<Result<(), GpuError>>,
136}
137
138/// Legacy F2 Softmax request.
139pub struct SoftmaxRequest {
140    pub x: GpuRef<f32>,
141    pub y: GpuRef<f32>,
142    pub dims: [i32; 4],
143    pub alpha: f32,
144    pub beta: f32,
145    pub reply: oneshot::Sender<Result<(), GpuError>>,
146}
147
148// ----- CudnnMsg + actor ----------------------------------------------
149
150/// Mailbox message for [`CudnnActor`].
151///
152/// Modern callers send `CudnnMsg::Op(Box<dyn CudnnDispatch>)` with a
153/// typed request struct (e.g. `ConvFwdRequest<f16>`). The legacy
154/// `ConvForward` / `Activation` / `Softmax` variants are retained for
155/// back-compat and are slated for removal once downstream users
156/// migrate.
157pub enum CudnnMsg {
158    /// Generic typed cuDNN op (canonical form). The boxed trait object
159    /// carries dtype + op kind for telemetry and dispatches via
160    /// [`CudnnDispatch::dispatch`].
161    Op(Box<dyn CudnnDispatch>),
162
163    /// **Deprecated** — use [`CudnnMsg::Op`] with [`ConvFwdRequest<f32>`].
164    #[deprecated(note = "use CudnnMsg::Op with ConvFwdRequest<f32>")]
165    ConvForward(Box<ConvForwardRequest>),
166
167    /// **Deprecated** — use [`CudnnMsg::Op`] with [`ActivationFwdRequest<f32>`].
168    #[deprecated(note = "use CudnnMsg::Op with ActivationFwdRequest<f32>")]
169    Activation(Box<ActivationRequest>),
170
171    /// **Deprecated** — use [`CudnnMsg::Op`] with [`SoftmaxFwdRequest<f32>`].
172    #[deprecated(note = "use CudnnMsg::Op with SoftmaxFwdRequest<f32>")]
173    Softmax(Box<SoftmaxRequest>),
174}
175
176pub struct CudnnActor {
177    inner: CudnnInner,
178}
179
/// Newtype wrapper around the shared cuDNN handle that carries manual
/// `Send`/`Sync` impls, so the handle can be stored in [`CudnnInner`]
/// (presumably required because the actor type must be `Send` — confirm
/// against `atomr_core::actor::Actor`'s bounds).
struct SendCudnn(Arc<Cudnn>);
// SAFETY: NOTE(review) — this asserts the wrapped handle may move between
// threads. Within this file the handle is only touched from
// `CudnnActor::handle`, which clones the `Arc` into a `CudnnDispatchCtx` for
// one synchronous `dispatch` call. Soundness presumably relies on the actor
// handling one message at a time; confirm that against atomr_core's mailbox
// guarantees, and confirm that no dispatcher retains the cloned `Arc<Cudnn>`
// past `dispatch`.
unsafe impl Send for SendCudnn {}
// SAFETY: NOTE(review) — shared references are only created inside
// `CudnnActor::handle` (see the `Send` note above). The argument again rests
// on messages not being handled concurrently — TODO confirm.
unsafe impl Sync for SendCudnn {}
183
184enum CudnnInner {
185    Real {
186        handle: SendCudnn,
187        stream: Arc<cudarc::driver::CudaStream>,
188        completion: Arc<dyn CompletionStrategy>,
189        plan_cache: Mutex<PlanCache>,
190        workspace: Mutex<Option<CudaSlice<u8>>>,
191        #[allow(dead_code)]
192        state: Arc<DeviceState>,
193    },
194    Mock,
195}
196
197impl CudnnActor {
198    pub fn props(
199        stream: Arc<cudarc::driver::CudaStream>,
200        _allocator: Arc<dyn StreamAllocator>,
201        completion: Arc<dyn CompletionStrategy>,
202        state: Arc<DeviceState>,
203    ) -> Props<Self> {
204        Props::create(move || {
205            let handle = match Cudnn::new(stream.clone()) {
206                Ok(h) => h,
207                Err(e) => panic!("ContextPoisoned: Cudnn::new failed: {e}"),
208            };
209            CudnnActor {
210                inner: CudnnInner::Real {
211                    handle: SendCudnn(handle),
212                    stream: stream.clone(),
213                    completion: completion.clone(),
214                    plan_cache: Mutex::new(PlanCache::new(DEFAULT_PLAN_CACHE_SIZE)),
215                    workspace: Mutex::new(None),
216                    state: state.clone(),
217                },
218            }
219        })
220    }
221
222    pub fn mock_props() -> Props<Self> {
223        Props::create(|| CudnnActor {
224            inner: CudnnInner::Mock,
225        })
226    }
227}
228
229#[async_trait]
230impl Actor for CudnnActor {
231    type Msg = CudnnMsg;
232
233    async fn handle(&mut self, _ctx: &mut Context<Self>, msg: CudnnMsg) {
234        match &self.inner {
235            CudnnInner::Mock => reply_mock(msg),
236            CudnnInner::Real {
237                handle,
238                stream,
239                completion,
240                plan_cache,
241                workspace,
242                ..
243            } => match msg {
244                CudnnMsg::Op(op) => {
245                    let ctx = CudnnDispatchCtx {
246                        handle: handle.0.clone(),
247                        stream: stream.clone(),
248                        completion: completion.clone(),
249                        plan_cache,
250                        workspace,
251                    };
252                    op.dispatch(&ctx);
253                }
254                #[allow(deprecated)]
255                CudnnMsg::ConvForward(req) => {
256                    handle_legacy_conv_fwd(*req);
257                }
258                #[allow(deprecated)]
259                CudnnMsg::Activation(req) => {
260                    handle_legacy_activation(*req);
261                }
262                #[allow(deprecated)]
263                CudnnMsg::Softmax(req) => {
264                    handle_legacy_softmax(*req);
265                }
266            },
267        }
268    }
269}
270
271fn reply_mock(msg: CudnnMsg) {
272    let err = || GpuError::Unrecoverable("CudnnActor in mock mode".into());
273    match msg {
274        CudnnMsg::Op(_) => {
275            // The op's reply channel is owned inside the box; we
276            // dispatch to a no-op variant of the dispatcher to send a
277            // clear error. The dispatch impls all check
278            // `mock`-equivalent state and reply with LibraryError, so
279            // we just drop the box here — that closes any oneshot
280            // senders inside (which the receivers see as Closed).
281        }
282        #[allow(deprecated)]
283        CudnnMsg::ConvForward(r) => {
284            let _ = r.reply.send(Err(err()));
285        }
286        #[allow(deprecated)]
287        CudnnMsg::Activation(r) => {
288            let _ = r.reply.send(Err(err()));
289        }
290        #[allow(deprecated)]
291        CudnnMsg::Softmax(r) => {
292            let _ = r.reply.send(Err(err()));
293        }
294    }
295}
296
297#[allow(deprecated)]
298fn handle_legacy_conv_fwd(req: ConvForwardRequest) {
299    // The legacy launch path lived in cudnn_actor.rs; for the v2
300    // skeleton we reply with a clear migration message. Real callers
301    // should use CudnnMsg::Op(ConvFwdRequest<f32>) which routes
302    // through the v9 frontend graph builder.
303    let _ = req.reply.send(Err(GpuError::LibraryError {
304        lib: LIB,
305        msg: "ConvForward (legacy) is deprecated; send CudnnMsg::Op(ConvFwdRequest<f32>) \
306              for v9 frontend dispatch"
307            .to_string(),
308    }));
309}
310
311#[allow(deprecated)]
312fn handle_legacy_activation(req: ActivationRequest) {
313    let _ = req.reply.send(Err(GpuError::LibraryError {
314        lib: LIB,
315        msg: "Activation (legacy) is deprecated; send CudnnMsg::Op(ActivationFwdRequest<f32>)"
316            .to_string(),
317    }));
318}
319
320#[allow(deprecated)]
321fn handle_legacy_softmax(req: SoftmaxRequest) {
322    let _ = req.reply.send(Err(GpuError::LibraryError {
323        lib: LIB,
324        msg: "Softmax (legacy) is deprecated; send CudnnMsg::Op(SoftmaxFwdRequest<f32>)"
325            .to_string(),
326    }));
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal dispatcher for building `CudnnMsg::Op` values without a GPU.
    /// If it is ever dispatched, it replies `Ok(())`.
    struct Probe(oneshot::Sender<Result<(), GpuError>>);

    impl CudnnDispatch for Probe {
        fn dtype_name(&self) -> &'static str {
            "f32"
        }
        fn op_kind(&self) -> &'static str {
            "probe"
        }
        fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
            let _ = self.0.send(Ok(()));
        }
    }

    /// The deprecated legacy variants still exist with their boxed payload
    /// types, and the legacy `ConvParams` still constructs.
    #[test]
    #[allow(deprecated)]
    fn deprecated_conv_forward_alias_still_constructs() {
        // Building a real `ConvForwardRequest` needs device-backed `GpuRef`s.
        // Instead, pin each legacy variant's payload type at compile time:
        // tuple-variant constructors coerce to plain function pointers.
        let _conv: fn(Box<ConvForwardRequest>) -> CudnnMsg = CudnnMsg::ConvForward;
        let _activation: fn(Box<ActivationRequest>) -> CudnnMsg = CudnnMsg::Activation;
        let _softmax: fn(Box<SoftmaxRequest>) -> CudnnMsg = CudnnMsg::Softmax;

        let p = ConvParams {
            pad: [0, 0],
            stride: [1, 1],
            dilation: [1, 1],
        };
        assert_eq!(p.pad, [0, 0]);
        assert_eq!(p.stride, [1, 1]);
        assert_eq!(p.dilation, [1, 1]);
    }

    /// In mock mode an `Op` is dropped undispatched, which closes its reply
    /// channel instead of answering it.
    #[test]
    fn mock_mode_drops_op_and_closes_reply() {
        let (tx, mut rx) = oneshot::channel();
        reply_mock(CudnnMsg::Op(Box::new(Probe(tx))));
        assert!(matches!(
            rx.try_recv(),
            Err(oneshot::error::TryRecvError::Closed)
        ));
    }

    #[test]
    fn cudnn_dispatch_is_object_safe() {
        // Compile-only check: the trait must be usable as `dyn CudnnDispatch`.
        fn _accept(_: Box<dyn CudnnDispatch>) {}
    }

    #[test]
    fn plan_cache_default_size_matches_constant() {
        let pc = PlanCache::default();
        assert_eq!(pc.cap(), DEFAULT_PLAN_CACHE_SIZE);
    }
}