Skip to main content

atomr_accel_cuda/graph/record/
cudnn.rs

1//! `GraphOpRecord` impls for [`crate::kernel::CudnnActor`] requests.
2//!
3//! Each op holds the same payload as the matching `CudnnMsg` variant
4//! minus the reply channel, plus the cuDNN handle that the actor
5//! ordinarily owns. The graph script caller passes the handle in
6//! once, then submits a series of [`ConvForwardOp`]s without
7//! reaching back into the actor.
8//!
//! For Phase 3 we keep these structs descriptor-shape-only: their
//! `record` method validates the inputs and then always returns
//! `Unrecoverable`, because the cuDNN capture-mode entry point is not
//! wired yet. A future sub-phase will wire the actual cuDNN-record
//! path (cuDNN supports stream-capture as of v8.5).
14
15#![cfg(feature = "cudnn")]
16
17use crate::error::GpuError;
18use crate::gpu_ref::GpuRef;
19use crate::graph::{GraphOpRecord, GraphRecordCtx};
20use crate::kernel::{ActivationKind, ConvParams};
21
22/// Capture-mode op for `CudnnMsg::ConvForward`.
23pub struct ConvForwardOp {
24    pub x: GpuRef<f32>,
25    pub x_dims: [i32; 4],
26    pub w: GpuRef<f32>,
27    pub w_dims: [i32; 4],
28    pub y: GpuRef<f32>,
29    pub y_dims: [i32; 4],
30    pub conv: ConvParams,
31    pub alpha: f32,
32    pub beta: f32,
33}
34
35/// Capture-mode op for `CudnnMsg::Activation`.
36pub struct ActivationOp {
37    pub kind: ActivationKind,
38    pub x: GpuRef<f32>,
39    pub y: GpuRef<f32>,
40    pub dims: [i32; 4],
41    pub alpha: f32,
42    pub beta: f32,
43}
44
45/// Capture-mode op for `CudnnMsg::Softmax`.
46pub struct SoftmaxOp {
47    pub x: GpuRef<f32>,
48    pub y: GpuRef<f32>,
49    pub dims: [i32; 4],
50    pub alpha: f32,
51    pub beta: f32,
52}
53
54impl GraphOpRecord for ConvForwardOp {
55    fn record(&self, ctx: &GraphRecordCtx<'_>) -> Result<(), GpuError> {
56        // Validate first.
57        validate_dims(&self.x_dims, "conv: x_dims")?;
58        validate_dims(&self.w_dims, "conv: w_dims")?;
59        validate_dims(&self.y_dims, "conv: y_dims")?;
60        let _ = self.x.access()?;
61        let _ = self.w.access()?;
62        let _ = self.y.access()?;
63        // Phase 3: cuDNN's stream-capture path is not yet wired
64        // through the existing `CudnnActor` (which uses
65        // `envelope::run_kernel` host-fn completion that's not
66        // capture-safe). Until the actor exposes a capture-safe entry
67        // point, we surface a clear Unrecoverable here. The
68        // descriptor and validation are still enforced so callers
69        // catch shape mismatches early.
70        let _ = ctx;
71        Err(GpuError::Unrecoverable(
72            "graph::record::cudnn::ConvForward: cuDNN capture-mode \
73             entry not yet wired (Phase 3 surface only)"
74                .into(),
75        ))
76    }
77}
78
79impl GraphOpRecord for ActivationOp {
80    fn record(&self, ctx: &GraphRecordCtx<'_>) -> Result<(), GpuError> {
81        validate_dims(&self.dims, "activation: dims")?;
82        let _ = self.x.access()?;
83        let _ = self.y.access()?;
84        let _ = ctx;
85        Err(GpuError::Unrecoverable(
86            "graph::record::cudnn::Activation: cuDNN capture-mode \
87             entry not yet wired"
88                .into(),
89        ))
90    }
91}
92
93impl GraphOpRecord for SoftmaxOp {
94    fn record(&self, ctx: &GraphRecordCtx<'_>) -> Result<(), GpuError> {
95        validate_dims(&self.dims, "softmax: dims")?;
96        let _ = self.x.access()?;
97        let _ = self.y.access()?;
98        let _ = ctx;
99        Err(GpuError::Unrecoverable(
100            "graph::record::cudnn::Softmax: cuDNN capture-mode \
101             entry not yet wired"
102                .into(),
103        ))
104    }
105}
106
107fn validate_dims(d: &[i32; 4], who: &str) -> Result<(), GpuError> {
108    if d.iter().any(|&x| x <= 0) {
109        Err(GpuError::Unrecoverable(format!(
110            "{who}: non-positive dim in {d:?}"
111        )))
112    } else {
113        Ok(())
114    }
115}
116
#[cfg(test)]
mod tests {
    //! Unit tests for the capture-mode op surface.
    //!
    //! `record()` itself is not invoked here: the original notes for
    //! this module state that building a `GpuRef` needs a real device
    //! allocation, which is unavailable in a host-only test run. The
    //! tests therefore pin the dimension-validation behaviour that each
    //! `record()` impl performs first.

    use super::*;
    use crate::graph::MockGraphRecordCtx;
    use cudarc::driver::sys as driver_sys;

    /// Asserts that `dims` is rejected as `GpuError::Unrecoverable`
    /// with exactly the message `validate_dims` formats for `who`.
    fn assert_rejected(dims: [i32; 4], who: &str) {
        match validate_dims(&dims, who) {
            Err(GpuError::Unrecoverable(msg)) => {
                assert_eq!(msg, format!("{who}: non-positive dim in {dims:?}"));
            }
            Err(_) => panic!("{who}: expected GpuError::Unrecoverable"),
            Ok(()) => panic!("{who}: {dims:?} should have been rejected"),
        }
    }

    #[test]
    fn conv_op_records() {
        // Smoke test: a mock recording context can be built around a
        // null graph handle and borrowed as a `GraphRecordCtx`.
        let null_graph: driver_sys::CUgraph = std::ptr::null_mut();
        let mock = MockGraphRecordCtx::new(null_graph);
        let _ctx = mock.as_ctx();

        // Validation surface used by `ConvForwardOp::record`.
        assert!(validate_dims(&[1, 1, 1, 1], "ok").is_ok());
        assert_rejected([0, 1, 1, 1], "conv: x_dims");
    }

    #[test]
    fn activation_op_records() {
        assert!(validate_dims(&[1, 2, 3, 4], "ok").is_ok());
        assert_rejected([1, 2, 3, -1], "activation: dims");
    }

    #[test]
    fn softmax_op_records() {
        assert!(validate_dims(&[2, 4, 1, 1], "ok").is_ok());
        // Rejection case added; the original only checked the accepting path.
        assert_rejected([2, 0, 1, 1], "softmax: dims");
    }
}