// atomr_accel_cuda/kernel/cudnn/norm.rs

//! Normalisation requests for the cuDNN actor: BatchNorm (training +
//! inference + persistent), LayerNorm, InstanceNorm, GroupNorm,
//! RMSNorm.
4
5#![allow(dead_code)]
6
7use std::marker::PhantomData;
8
9use tokio::sync::oneshot;
10
11use crate::dtype::CudnnSupported;
12use crate::error::GpuError;
13use crate::gpu_ref::GpuRef;
14use crate::kernel::cudnn::conv::dtype_tag;
15use crate::kernel::cudnn::graph::{
16    DtypeTag, NormMode, NormPhase, OpSpec, OperationGraphSpec, TensorLayout, TensorSpec,
17};
18use crate::kernel::dispatch::{CudnnDispatch, CudnnDispatchCtx};
19
20/// BatchNorm — training-mode running-mean/var update + per-channel
21/// scale + bias.
22pub struct BatchNormRequest<T: CudnnSupported> {
23    pub phase: NormPhase,
24    pub x: GpuRef<T>,
25    pub y: GpuRef<T>,
26    pub scale: GpuRef<T>,
27    pub bias: GpuRef<T>,
28    pub running_mean: Option<GpuRef<T>>,
29    pub running_var: Option<GpuRef<T>>,
30    pub saved_mean: Option<GpuRef<T>>,
31    pub saved_var: Option<GpuRef<T>>,
32    pub dims: Vec<i64>,
33    pub layout: TensorLayout,
34    pub epsilon: f64,
35    pub exp_avg_factor: f64,
36    pub reply: oneshot::Sender<Result<(), GpuError>>,
37    pub _ty: PhantomData<T>,
38}
39
40impl<T: CudnnSupported> BatchNormRequest<T> {
41    pub fn graph_spec(&self) -> OperationGraphSpec {
42        build_norm_fwd_graph(
43            NormMode::BatchNorm,
44            self.phase,
45            dtype_tag::<T>(),
46            &self.dims,
47            self.layout,
48            self.epsilon,
49            self.exp_avg_factor,
50        )
51    }
52}
53
54impl<T: CudnnSupported> CudnnDispatch for BatchNormRequest<T> {
55    fn dtype_name(&self) -> &'static str {
56        T::NAME
57    }
58    fn op_kind(&self) -> &'static str {
59        "batchnorm"
60    }
61    fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
62        let _ = self.reply.send(Err(GpuError::LibraryError {
63            lib: "cudnn",
64            msg: "BatchNormRequest dispatch requires the v9 frontend graph builder; \
65                  skeleton entry point only"
66                .to_string(),
67        }));
68    }
69}
70
71/// LayerNorm — normalises across the trailing axes, scale + bias
72/// applied per-feature.
73pub struct LayerNormRequest<T: CudnnSupported> {
74    pub x: GpuRef<T>,
75    pub y: GpuRef<T>,
76    pub scale: GpuRef<T>,
77    pub bias: GpuRef<T>,
78    pub saved_mean: Option<GpuRef<T>>,
79    pub saved_var: Option<GpuRef<T>>,
80    pub dims: Vec<i64>,
81    pub norm_axes: Vec<i64>,
82    pub layout: TensorLayout,
83    pub epsilon: f64,
84    pub reply: oneshot::Sender<Result<(), GpuError>>,
85    pub _ty: PhantomData<T>,
86}
87
88impl<T: CudnnSupported> LayerNormRequest<T> {
89    pub fn graph_spec(&self) -> OperationGraphSpec {
90        build_norm_fwd_graph(
91            NormMode::LayerNorm,
92            NormPhase::Training,
93            dtype_tag::<T>(),
94            &self.dims,
95            self.layout,
96            self.epsilon,
97            0.0,
98        )
99    }
100}
101
102impl<T: CudnnSupported> CudnnDispatch for LayerNormRequest<T> {
103    fn dtype_name(&self) -> &'static str {
104        T::NAME
105    }
106    fn op_kind(&self) -> &'static str {
107        "layernorm"
108    }
109    fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
110        let _ = self.reply.send(Err(GpuError::LibraryError {
111            lib: "cudnn",
112            msg: "LayerNormRequest dispatch requires the v9 frontend graph builder; \
113                  skeleton entry point only"
114                .to_string(),
115        }));
116    }
117}
118
119/// InstanceNorm — normalises per-sample, per-channel.
120pub struct InstanceNormRequest<T: CudnnSupported> {
121    pub x: GpuRef<T>,
122    pub y: GpuRef<T>,
123    pub scale: GpuRef<T>,
124    pub bias: GpuRef<T>,
125    pub saved_mean: Option<GpuRef<T>>,
126    pub saved_var: Option<GpuRef<T>>,
127    pub dims: Vec<i64>,
128    pub layout: TensorLayout,
129    pub epsilon: f64,
130    pub reply: oneshot::Sender<Result<(), GpuError>>,
131    pub _ty: PhantomData<T>,
132}
133
134impl<T: CudnnSupported> InstanceNormRequest<T> {
135    pub fn graph_spec(&self) -> OperationGraphSpec {
136        build_norm_fwd_graph(
137            NormMode::InstanceNorm,
138            NormPhase::Training,
139            dtype_tag::<T>(),
140            &self.dims,
141            self.layout,
142            self.epsilon,
143            0.0,
144        )
145    }
146}
147
148impl<T: CudnnSupported> CudnnDispatch for InstanceNormRequest<T> {
149    fn dtype_name(&self) -> &'static str {
150        T::NAME
151    }
152    fn op_kind(&self) -> &'static str {
153        "instancenorm"
154    }
155    fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
156        let _ = self.reply.send(Err(GpuError::LibraryError {
157            lib: "cudnn",
158            msg: "InstanceNormRequest dispatch requires the v9 frontend graph builder; \
159                  skeleton entry point only"
160                .to_string(),
161        }));
162    }
163}
164
165/// GroupNorm — generalisation of BatchNorm/LayerNorm/InstanceNorm
166/// with `groups` channel partitions.
167pub struct GroupNormRequest<T: CudnnSupported> {
168    pub x: GpuRef<T>,
169    pub y: GpuRef<T>,
170    pub scale: GpuRef<T>,
171    pub bias: GpuRef<T>,
172    pub saved_mean: Option<GpuRef<T>>,
173    pub saved_var: Option<GpuRef<T>>,
174    pub dims: Vec<i64>,
175    pub groups: u32,
176    pub layout: TensorLayout,
177    pub epsilon: f64,
178    pub reply: oneshot::Sender<Result<(), GpuError>>,
179    pub _ty: PhantomData<T>,
180}
181
182impl<T: CudnnSupported> CudnnDispatch for GroupNormRequest<T> {
183    fn dtype_name(&self) -> &'static str {
184        T::NAME
185    }
186    fn op_kind(&self) -> &'static str {
187        "groupnorm"
188    }
189    fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
190        let _ = self.reply.send(Err(GpuError::LibraryError {
191            lib: "cudnn",
192            msg: "GroupNormRequest dispatch requires the v9 frontend graph builder; \
193                  skeleton entry point only"
194                .to_string(),
195        }));
196    }
197}
198
199/// Norm backward request, applies to BatchNorm / LayerNorm /
200/// InstanceNorm / GroupNorm uniformly via `mode`.
201pub struct NormBwdRequest<T: CudnnSupported> {
202    pub mode: NormMode,
203    pub x: GpuRef<T>,
204    pub dy: GpuRef<T>,
205    pub scale: GpuRef<T>,
206    pub mean: GpuRef<T>,
207    pub var: GpuRef<T>,
208    pub dx: GpuRef<T>,
209    pub dscale: GpuRef<T>,
210    pub dbias: GpuRef<T>,
211    pub dims: Vec<i64>,
212    pub layout: TensorLayout,
213    pub epsilon: f64,
214    pub reply: oneshot::Sender<Result<(), GpuError>>,
215    pub _ty: PhantomData<T>,
216}
217
218impl<T: CudnnSupported> CudnnDispatch for NormBwdRequest<T> {
219    fn dtype_name(&self) -> &'static str {
220        T::NAME
221    }
222    fn op_kind(&self) -> &'static str {
223        match self.mode {
224            NormMode::BatchNorm => "batchnorm_bwd",
225            NormMode::LayerNorm => "layernorm_bwd",
226            NormMode::InstanceNorm => "instancenorm_bwd",
227            NormMode::GroupNorm => "groupnorm_bwd",
228            NormMode::RmsNorm => "rmsnorm_bwd",
229        }
230    }
231    fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
232        let _ = self.reply.send(Err(GpuError::LibraryError {
233            lib: "cudnn",
234            msg: "NormBwdRequest dispatch requires the v9 frontend graph builder; \
235                  skeleton entry point only"
236                .to_string(),
237        }));
238    }
239}
240
241/// Build the spec-side norm-fwd op graph.
242pub fn build_norm_fwd_graph(
243    mode: NormMode,
244    phase: NormPhase,
245    dtype: DtypeTag,
246    dims: &[i64],
247    layout: TensorLayout,
248    epsilon: f64,
249    exp_avg_factor: f64,
250) -> OperationGraphSpec {
251    let mut g = OperationGraphSpec::new("norm_fwd");
252    let x_uid = g.add_tensor(TensorSpec::new(1, dtype, dims.to_vec(), layout));
253    let scale_uid = g.add_tensor(TensorSpec::new(2, dtype, vec![1, dims[1], 1, 1], layout));
254    let bias_uid = g.add_tensor(TensorSpec::new(3, dtype, vec![1, dims[1], 1, 1], layout));
255    let y_uid = g.add_tensor(TensorSpec::new(4, dtype, dims.to_vec(), layout));
256    g.add_op(OpSpec::NormFwd {
257        mode,
258        phase,
259        x: x_uid,
260        scale: scale_uid,
261        bias: bias_uid,
262        mean: None,
263        var: None,
264        y: y_uid,
265        compute_dtype: dtype,
266        epsilon,
267        exp_avg_factor,
268    });
269    g
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an F32, NCHW-packed `[2, 3, 4, 4]` norm-forward graph with
    /// epsilon `1e-5`; only the mode, phase and averaging factor vary.
    fn fwd(mode: NormMode, phase: NormPhase, exp_avg_factor: f64) -> OperationGraphSpec {
        build_norm_fwd_graph(
            mode,
            phase,
            DtypeTag::F32,
            &[2, 3, 4, 4],
            TensorLayout::NchwPacked,
            1e-5,
            exp_avg_factor,
        )
    }

    #[test]
    fn batchnorm_layernorm_instancenorm_round_trip() {
        let bn = fwd(NormMode::BatchNorm, NormPhase::Training, 0.1);
        assert_eq!(bn.ops.len(), 1);
        if let OpSpec::NormFwd { mode, phase, .. } = &bn.ops[0] {
            assert_eq!(*mode, NormMode::BatchNorm);
            assert_eq!(*phase, NormPhase::Training);
        } else {
            panic!("wrong op");
        }

        // Different modes must produce different graph signatures.
        let ln = fwd(NormMode::LayerNorm, NormPhase::Training, 0.0);
        assert_ne!(bn.signature(), ln.signature());

        let inst = fwd(NormMode::InstanceNorm, NormPhase::Training, 0.0);
        assert_ne!(ln.signature(), inst.signature());

        // Persistent batchnorm has its own phase signature.
        let bn_persist = fwd(NormMode::BatchNorm, NormPhase::PersistentTraining, 0.1);
        assert_ne!(bn.signature(), bn_persist.signature());
    }
}