1#![allow(dead_code)]
6
7use std::marker::PhantomData;
8
9use tokio::sync::oneshot;
10
11use crate::dtype::CudnnSupported;
12use crate::error::GpuError;
13use crate::gpu_ref::GpuRef;
14use crate::kernel::cudnn::conv::dtype_tag;
15use crate::kernel::cudnn::graph::{
16 DtypeTag, NormMode, NormPhase, OpSpec, OperationGraphSpec, TensorLayout, TensorSpec,
17};
18use crate::kernel::dispatch::{CudnnDispatch, CudnnDispatchCtx};
19
/// Request to run a batch-normalisation forward pass through cuDNN.
///
/// `graph_spec` describes the forward graph by forwarding `phase`, `dims`,
/// `layout`, `epsilon` and `exp_avg_factor` to `build_norm_fwd_graph`. The
/// `CudnnDispatch` implementation in this file currently answers every request
/// on `reply` with a `GpuError::LibraryError` (skeleton entry point).
pub struct BatchNormRequest<T: CudnnSupported> {
    /// Normalisation phase forwarded to the graph builder.
    pub phase: NormPhase,
    /// Forward input (the `x` operand of the norm op in the graph).
    pub x: GpuRef<T>,
    /// Forward output (the `y` operand of the norm op in the graph).
    pub y: GpuRef<T>,
    /// Per-channel scale; the graph builder sizes it from `dims[1]`.
    pub scale: GpuRef<T>,
    /// Per-channel bias; the graph builder sizes it from `dims[1]`.
    pub bias: GpuRef<T>,
    /// Optional running-mean buffer. Not referenced by `graph_spec`.
    /// NOTE(review): presumably updated during training phases — confirm.
    pub running_mean: Option<GpuRef<T>>,
    /// Optional running-variance buffer. Not referenced by `graph_spec`.
    pub running_var: Option<GpuRef<T>>,
    /// Optional buffer for the batch mean. Not referenced by `graph_spec`.
    /// NOTE(review): presumably consumed by the backward pass — confirm.
    pub saved_mean: Option<GpuRef<T>>,
    /// Optional buffer for the batch variance. Not referenced by `graph_spec`.
    pub saved_var: Option<GpuRef<T>>,
    /// Shape of `x`/`y`; `dims[1]` is treated as the channel extent.
    pub dims: Vec<i64>,
    /// Memory layout applied to every tensor in the graph.
    pub layout: TensorLayout,
    /// Epsilon forwarded to the norm op.
    pub epsilon: f64,
    /// Exponential-average factor forwarded to the norm op.
    pub exp_avg_factor: f64,
    /// Completion channel on which the dispatcher reports the outcome.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
    /// Type marker.
    /// NOTE(review): redundant, since `T` already appears in the `GpuRef<T>`
    /// fields; kept so existing struct literals keep compiling.
    pub _ty: PhantomData<T>,
}
39
40impl<T: CudnnSupported> BatchNormRequest<T> {
41 pub fn graph_spec(&self) -> OperationGraphSpec {
42 build_norm_fwd_graph(
43 NormMode::BatchNorm,
44 self.phase,
45 dtype_tag::<T>(),
46 &self.dims,
47 self.layout,
48 self.epsilon,
49 self.exp_avg_factor,
50 )
51 }
52}
53
54impl<T: CudnnSupported> CudnnDispatch for BatchNormRequest<T> {
55 fn dtype_name(&self) -> &'static str {
56 T::NAME
57 }
58 fn op_kind(&self) -> &'static str {
59 "batchnorm"
60 }
61 fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
62 let _ = self.reply.send(Err(GpuError::LibraryError {
63 lib: "cudnn",
64 msg: "BatchNormRequest dispatch requires the v9 frontend graph builder; \
65 skeleton entry point only"
66 .to_string(),
67 }));
68 }
69}
70
/// Request to run a layer-normalisation forward pass through cuDNN.
///
/// `graph_spec` always builds a `NormPhase::Training` graph with an
/// exponential-average factor of `0.0`. The `CudnnDispatch` implementation in
/// this file currently answers every request on `reply` with an error.
pub struct LayerNormRequest<T: CudnnSupported> {
    /// Forward input (the `x` operand of the norm op in the graph).
    pub x: GpuRef<T>,
    /// Forward output (the `y` operand of the norm op in the graph).
    pub y: GpuRef<T>,
    /// Affine scale; the graph builder sizes it from `dims[1]`.
    pub scale: GpuRef<T>,
    /// Affine bias; the graph builder sizes it from `dims[1]`.
    pub bias: GpuRef<T>,
    /// Optional buffer for the computed mean. Not referenced by `graph_spec`.
    pub saved_mean: Option<GpuRef<T>>,
    /// Optional buffer for the computed variance. Not referenced by `graph_spec`.
    pub saved_var: Option<GpuRef<T>>,
    /// Shape of `x`/`y`.
    pub dims: Vec<i64>,
    /// Axes to normalise over.
    /// NOTE(review): not consumed by `graph_spec`, so requests differing only
    /// here produce identical graph specs — confirm the intended convention.
    pub norm_axes: Vec<i64>,
    /// Memory layout applied to every tensor in the graph.
    pub layout: TensorLayout,
    /// Epsilon forwarded to the norm op.
    pub epsilon: f64,
    /// Completion channel on which the dispatcher reports the outcome.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
    /// Type marker (redundant with the `GpuRef<T>` fields; kept for source
    /// compatibility with struct-literal callers).
    pub _ty: PhantomData<T>,
}
87
88impl<T: CudnnSupported> LayerNormRequest<T> {
89 pub fn graph_spec(&self) -> OperationGraphSpec {
90 build_norm_fwd_graph(
91 NormMode::LayerNorm,
92 NormPhase::Training,
93 dtype_tag::<T>(),
94 &self.dims,
95 self.layout,
96 self.epsilon,
97 0.0,
98 )
99 }
100}
101
102impl<T: CudnnSupported> CudnnDispatch for LayerNormRequest<T> {
103 fn dtype_name(&self) -> &'static str {
104 T::NAME
105 }
106 fn op_kind(&self) -> &'static str {
107 "layernorm"
108 }
109 fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
110 let _ = self.reply.send(Err(GpuError::LibraryError {
111 lib: "cudnn",
112 msg: "LayerNormRequest dispatch requires the v9 frontend graph builder; \
113 skeleton entry point only"
114 .to_string(),
115 }));
116 }
117}
118
/// Request to run an instance-normalisation forward pass through cuDNN.
///
/// `graph_spec` always builds a `NormPhase::Training` graph with an
/// exponential-average factor of `0.0`. The `CudnnDispatch` implementation in
/// this file currently answers every request on `reply` with an error.
pub struct InstanceNormRequest<T: CudnnSupported> {
    /// Forward input (the `x` operand of the norm op in the graph).
    pub x: GpuRef<T>,
    /// Forward output (the `y` operand of the norm op in the graph).
    pub y: GpuRef<T>,
    /// Per-channel scale; the graph builder sizes it from `dims[1]`.
    pub scale: GpuRef<T>,
    /// Per-channel bias; the graph builder sizes it from `dims[1]`.
    pub bias: GpuRef<T>,
    /// Optional buffer for the computed mean. Not referenced by `graph_spec`.
    pub saved_mean: Option<GpuRef<T>>,
    /// Optional buffer for the computed variance. Not referenced by `graph_spec`.
    pub saved_var: Option<GpuRef<T>>,
    /// Shape of `x`/`y`; `dims[1]` is treated as the channel extent.
    pub dims: Vec<i64>,
    /// Memory layout applied to every tensor in the graph.
    pub layout: TensorLayout,
    /// Epsilon forwarded to the norm op.
    pub epsilon: f64,
    /// Completion channel on which the dispatcher reports the outcome.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
    /// Type marker (redundant with the `GpuRef<T>` fields; kept for source
    /// compatibility with struct-literal callers).
    pub _ty: PhantomData<T>,
}
133
134impl<T: CudnnSupported> InstanceNormRequest<T> {
135 pub fn graph_spec(&self) -> OperationGraphSpec {
136 build_norm_fwd_graph(
137 NormMode::InstanceNorm,
138 NormPhase::Training,
139 dtype_tag::<T>(),
140 &self.dims,
141 self.layout,
142 self.epsilon,
143 0.0,
144 )
145 }
146}
147
148impl<T: CudnnSupported> CudnnDispatch for InstanceNormRequest<T> {
149 fn dtype_name(&self) -> &'static str {
150 T::NAME
151 }
152 fn op_kind(&self) -> &'static str {
153 "instancenorm"
154 }
155 fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
156 let _ = self.reply.send(Err(GpuError::LibraryError {
157 lib: "cudnn",
158 msg: "InstanceNormRequest dispatch requires the v9 frontend graph builder; \
159 skeleton entry point only"
160 .to_string(),
161 }));
162 }
163}
164
/// Request to run a group-normalisation forward pass through cuDNN.
///
/// Unlike the batch/layer/instance requests in this file, this type has no
/// `graph_spec` method. The `CudnnDispatch` implementation in this file
/// currently answers every request on `reply` with an error.
pub struct GroupNormRequest<T: CudnnSupported> {
    /// Forward input.
    pub x: GpuRef<T>,
    /// Forward output.
    pub y: GpuRef<T>,
    /// Affine scale.
    pub scale: GpuRef<T>,
    /// Affine bias.
    pub bias: GpuRef<T>,
    /// Optional buffer for the computed mean.
    pub saved_mean: Option<GpuRef<T>>,
    /// Optional buffer for the computed variance.
    pub saved_var: Option<GpuRef<T>>,
    /// Shape of `x`/`y`.
    pub dims: Vec<i64>,
    /// Number of channel groups.
    /// NOTE(review): `build_norm_fwd_graph` has no parameter for this, so a
    /// future `graph_spec` must carry it some other way to avoid spec collisions.
    pub groups: u32,
    /// Memory layout of the tensors.
    pub layout: TensorLayout,
    /// Numerical-stability epsilon.
    pub epsilon: f64,
    /// Completion channel on which the dispatcher reports the outcome.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
    /// Type marker (redundant with the `GpuRef<T>` fields; kept for source
    /// compatibility with struct-literal callers).
    pub _ty: PhantomData<T>,
}
181
182impl<T: CudnnSupported> CudnnDispatch for GroupNormRequest<T> {
183 fn dtype_name(&self) -> &'static str {
184 T::NAME
185 }
186 fn op_kind(&self) -> &'static str {
187 "groupnorm"
188 }
189 fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
190 let _ = self.reply.send(Err(GpuError::LibraryError {
191 lib: "cudnn",
192 msg: "GroupNormRequest dispatch requires the v9 frontend graph builder; \
193 skeleton entry point only"
194 .to_string(),
195 }));
196 }
197}
198
/// Request to run a normalisation backward pass through cuDNN.
///
/// A single request type covers every `NormMode`; `op_kind` derives a
/// `<mode>_bwd` identifier from `mode`. The `CudnnDispatch` implementation in
/// this file currently answers every request on `reply` with an error.
pub struct NormBwdRequest<T: CudnnSupported> {
    /// Which normalisation family this backward pass belongs to.
    pub mode: NormMode,
    /// Forward input.
    pub x: GpuRef<T>,
    /// Incoming gradient with respect to the forward output.
    pub dy: GpuRef<T>,
    /// Affine scale used in the forward pass.
    pub scale: GpuRef<T>,
    /// Mean statistics.
    /// NOTE(review): presumably the forward pass's saved mean — confirm.
    pub mean: GpuRef<T>,
    /// Variance statistics.
    /// NOTE(review): presumably the forward pass's saved variance — confirm.
    pub var: GpuRef<T>,
    /// Output gradient with respect to `x`.
    pub dx: GpuRef<T>,
    /// Output gradient with respect to `scale`.
    pub dscale: GpuRef<T>,
    /// Output gradient with respect to the bias.
    pub dbias: GpuRef<T>,
    /// Shape of `x`/`dy`/`dx`.
    pub dims: Vec<i64>,
    /// Memory layout of the tensors.
    pub layout: TensorLayout,
    /// Numerical-stability epsilon.
    pub epsilon: f64,
    /// Completion channel on which the dispatcher reports the outcome.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
    /// Type marker (redundant with the `GpuRef<T>` fields; kept for source
    /// compatibility with struct-literal callers).
    pub _ty: PhantomData<T>,
}
217
218impl<T: CudnnSupported> CudnnDispatch for NormBwdRequest<T> {
219 fn dtype_name(&self) -> &'static str {
220 T::NAME
221 }
222 fn op_kind(&self) -> &'static str {
223 match self.mode {
224 NormMode::BatchNorm => "batchnorm_bwd",
225 NormMode::LayerNorm => "layernorm_bwd",
226 NormMode::InstanceNorm => "instancenorm_bwd",
227 NormMode::GroupNorm => "groupnorm_bwd",
228 NormMode::RmsNorm => "rmsnorm_bwd",
229 }
230 }
231 fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
232 let _ = self.reply.send(Err(GpuError::LibraryError {
233 lib: "cudnn",
234 msg: "NormBwdRequest dispatch requires the v9 frontend graph builder; \
235 skeleton entry point only"
236 .to_string(),
237 }));
238 }
239}
240
241pub fn build_norm_fwd_graph(
243 mode: NormMode,
244 phase: NormPhase,
245 dtype: DtypeTag,
246 dims: &[i64],
247 layout: TensorLayout,
248 epsilon: f64,
249 exp_avg_factor: f64,
250) -> OperationGraphSpec {
251 let mut g = OperationGraphSpec::new("norm_fwd");
252 let x_uid = g.add_tensor(TensorSpec::new(1, dtype, dims.to_vec(), layout));
253 let scale_uid = g.add_tensor(TensorSpec::new(2, dtype, vec![1, dims[1], 1, 1], layout));
254 let bias_uid = g.add_tensor(TensorSpec::new(3, dtype, vec![1, dims[1], 1, 1], layout));
255 let y_uid = g.add_tensor(TensorSpec::new(4, dtype, dims.to_vec(), layout));
256 g.add_op(OpSpec::NormFwd {
257 mode,
258 phase,
259 x: x_uid,
260 scale: scale_uid,
261 bias: bias_uid,
262 mean: None,
263 var: None,
264 y: y_uid,
265 compute_dtype: dtype,
266 epsilon,
267 exp_avg_factor,
268 });
269 g
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an f32, NCHW-packed `[2, 3, 4, 4]` forward-norm graph with
    /// `epsilon = 1e-5`. Only the arguments under test vary.
    fn fwd(mode: NormMode, phase: NormPhase, exp_avg_factor: f64) -> OperationGraphSpec {
        build_norm_fwd_graph(
            mode,
            phase,
            DtypeTag::F32,
            &[2, 3, 4, 4],
            TensorLayout::NchwPacked,
            1e-5,
            exp_avg_factor,
        )
    }

    #[test]
    fn batchnorm_layernorm_instancenorm_round_trip() {
        // Batch norm produces exactly one NormFwd op carrying the requested mode/phase.
        let bn = fwd(NormMode::BatchNorm, NormPhase::Training, 0.1);
        assert_eq!(bn.ops.len(), 1);
        if let OpSpec::NormFwd { mode, phase, .. } = &bn.ops[0] {
            assert_eq!(*mode, NormMode::BatchNorm);
            assert_eq!(*phase, NormPhase::Training);
        } else {
            panic!("wrong op");
        }

        // Changing the mode must change the signature.
        let ln = fwd(NormMode::LayerNorm, NormPhase::Training, 0.0);
        assert_ne!(bn.signature(), ln.signature());

        let in_ = fwd(NormMode::InstanceNorm, NormPhase::Training, 0.0);
        assert_ne!(ln.signature(), in_.signature());

        // Changing only the phase must change the signature too.
        let bn_persist = fwd(NormMode::BatchNorm, NormPhase::PersistentTraining, 0.1);
        assert_ne!(bn.signature(), bn_persist.signature());
    }
}