atomr_accel_cuda/graph/record/
cudnn.rs1#![cfg(feature = "cudnn")]
16
17use crate::error::GpuError;
18use crate::gpu_ref::GpuRef;
19use crate::graph::{GraphOpRecord, GraphRecordCtx};
20use crate::kernel::{ActivationKind, ConvParams};
21
22pub struct ConvForwardOp {
24 pub x: GpuRef<f32>,
25 pub x_dims: [i32; 4],
26 pub w: GpuRef<f32>,
27 pub w_dims: [i32; 4],
28 pub y: GpuRef<f32>,
29 pub y_dims: [i32; 4],
30 pub conv: ConvParams,
31 pub alpha: f32,
32 pub beta: f32,
33}
34
35pub struct ActivationOp {
37 pub kind: ActivationKind,
38 pub x: GpuRef<f32>,
39 pub y: GpuRef<f32>,
40 pub dims: [i32; 4],
41 pub alpha: f32,
42 pub beta: f32,
43}
44
45pub struct SoftmaxOp {
47 pub x: GpuRef<f32>,
48 pub y: GpuRef<f32>,
49 pub dims: [i32; 4],
50 pub alpha: f32,
51 pub beta: f32,
52}
53
54impl GraphOpRecord for ConvForwardOp {
55 fn record(&self, ctx: &GraphRecordCtx<'_>) -> Result<(), GpuError> {
56 validate_dims(&self.x_dims, "conv: x_dims")?;
58 validate_dims(&self.w_dims, "conv: w_dims")?;
59 validate_dims(&self.y_dims, "conv: y_dims")?;
60 let _ = self.x.access()?;
61 let _ = self.w.access()?;
62 let _ = self.y.access()?;
63 let _ = ctx;
71 Err(GpuError::Unrecoverable(
72 "graph::record::cudnn::ConvForward: cuDNN capture-mode \
73 entry not yet wired (Phase 3 surface only)"
74 .into(),
75 ))
76 }
77}
78
79impl GraphOpRecord for ActivationOp {
80 fn record(&self, ctx: &GraphRecordCtx<'_>) -> Result<(), GpuError> {
81 validate_dims(&self.dims, "activation: dims")?;
82 let _ = self.x.access()?;
83 let _ = self.y.access()?;
84 let _ = ctx;
85 Err(GpuError::Unrecoverable(
86 "graph::record::cudnn::Activation: cuDNN capture-mode \
87 entry not yet wired"
88 .into(),
89 ))
90 }
91}
92
93impl GraphOpRecord for SoftmaxOp {
94 fn record(&self, ctx: &GraphRecordCtx<'_>) -> Result<(), GpuError> {
95 validate_dims(&self.dims, "softmax: dims")?;
96 let _ = self.x.access()?;
97 let _ = self.y.access()?;
98 let _ = ctx;
99 Err(GpuError::Unrecoverable(
100 "graph::record::cudnn::Softmax: cuDNN capture-mode \
101 entry not yet wired"
102 .into(),
103 ))
104 }
105}
106
107fn validate_dims(d: &[i32; 4], who: &str) -> Result<(), GpuError> {
108 if d.iter().any(|&x| x <= 0) {
109 Err(GpuError::Unrecoverable(format!(
110 "{who}: non-positive dim in {d:?}"
111 )))
112 } else {
113 Ok(())
114 }
115}
116
#[cfg(test)]
mod tests {
    //! Unit tests for the cuDNN graph-record surface.
    //!
    //! None of the ops can be recorded yet (every `record` ends in a
    //! "not yet wired" error), and building a live `GpuRef` needs a device. So
    //! these tests cover the shared `validate_dims` guard, plus construction of
    //! the mock record context.

    use super::*;
    use crate::graph::MockGraphRecordCtx;
    use cudarc::driver::sys as driver_sys;

    /// Extracts the message from an expected `GpuError::Unrecoverable`.
    // The catch-all arm is unreachable if `GpuError` has a single variant.
    #[allow(unreachable_patterns)]
    fn unrecoverable_msg(r: Result<(), GpuError>) -> String {
        match r {
            Err(GpuError::Unrecoverable(msg)) => msg,
            Ok(()) => panic!("expected validate_dims to fail"),
            Err(_) => panic!("expected GpuError::Unrecoverable"),
        }
    }

    #[test]
    fn conv_op_records() {
        // A null CUgraph is enough to build the mock context; nothing is
        // recorded into it.
        let null_graph: driver_sys::CUgraph = std::ptr::null_mut();
        let mock = MockGraphRecordCtx::new(null_graph);
        let _ctx = mock.as_ctx();

        assert!(validate_dims(&[1, 1, 1, 1], "ok").is_ok());
        assert!(validate_dims(&[0, 1, 1, 1], "bad").is_err());
    }

    #[test]
    fn activation_op_records() {
        assert!(validate_dims(&[1, 2, 3, 4], "ok").is_ok());
        // A non-positive entry must be rejected in every position.
        for pos in 0..4 {
            let mut dims = [1, 2, 3, 4];
            dims[pos] = -1;
            assert!(
                validate_dims(&dims, "bad").is_err(),
                "negative dim at position {pos} was accepted"
            );
        }
    }

    #[test]
    fn softmax_op_records() {
        assert!(validate_dims(&[2, 4, 1, 1], "ok").is_ok());
        // The error carries the caller's label and echoes the offending dims.
        assert_eq!(
            unrecoverable_msg(validate_dims(&[2, 0, 1, 1], "softmax: dims")),
            "softmax: dims: non-positive dim in [2, 0, 1, 1]"
        );
    }
}