1#![allow(dead_code)]
15
16use std::marker::PhantomData;
17
18use tokio::sync::oneshot;
19
20use crate::dtype::CudnnSupported;
21use crate::error::GpuError;
22use crate::gpu_ref::GpuRef;
23use crate::kernel::cudnn::activation::ActivationKind;
24use crate::kernel::cudnn::graph::{
25 DtypeTag, OpSpec, OperationGraphSpec, PointwiseMode, TensorLayout, TensorSpec,
26};
27use crate::kernel::dispatch::{CudnnDispatch, CudnnDispatchCtx};
28
/// Shape-independent settings of a convolution.
///
/// Every `Vec` field carries one entry per spatial axis. The `symmetric_*`
/// constructors always produce vectors whose length equals `spatial_dims`;
/// values built by hand are not checked for that.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ConvDescParams {
    /// Number of spatial axes the convolution runs over.
    pub spatial_dims: usize,
    /// Leading padding, one entry per spatial axis.
    pub pre_padding: Vec<i64>,
    /// Trailing padding, one entry per spatial axis.
    pub post_padding: Vec<i64>,
    /// Filter stride, one entry per spatial axis.
    pub stride: Vec<i64>,
    /// Filter dilation, one entry per spatial axis.
    pub dilation: Vec<i64>,
    /// Group count; the `symmetric_*` constructors set it to `1`.
    pub groups: i64,
}

impl ConvDescParams {
    /// Shared implementation of the `symmetric_*` constructors: repeats the
    /// same `pad`, `stride` and `dilation` on each of `spatial_dims` axes and
    /// leaves `groups` at `1`.
    fn symmetric_nd(spatial_dims: usize, pad: i64, stride: i64, dilation: i64) -> Self {
        Self {
            spatial_dims,
            pre_padding: vec![pad; spatial_dims],
            post_padding: vec![pad; spatial_dims],
            stride: vec![stride; spatial_dims],
            dilation: vec![dilation; spatial_dims],
            groups: 1,
        }
    }

    /// Two spatial axes, identical `pad` / `stride` / `dilation` on both,
    /// `pad` applied both before and after, `groups == 1`.
    #[must_use]
    pub fn symmetric_2d(pad: i64, stride: i64, dilation: i64) -> Self {
        Self::symmetric_nd(2, pad, stride, dilation)
    }

    /// One spatial axis, `pad` applied both before and after, `groups == 1`.
    #[must_use]
    pub fn symmetric_1d(pad: i64, stride: i64, dilation: i64) -> Self {
        Self::symmetric_nd(1, pad, stride, dilation)
    }

    /// Three spatial axes, identical `pad` / `stride` / `dilation` on each,
    /// `pad` applied both before and after, `groups == 1`.
    #[must_use]
    pub fn symmetric_3d(pad: i64, stride: i64, dilation: i64) -> Self {
        Self::symmetric_nd(3, pad, stride, dilation)
    }

    /// Returns `self` with `groups` replaced by `g`.
    ///
    /// The value is stored as given; no range check is performed here.
    #[must_use]
    pub fn with_groups(mut self, g: i64) -> Self {
        self.groups = g;
        self
    }
}
88
/// Selects which pointwise stages `build_conv_fwd_graph` appends after the
/// forward-convolution op.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EpilogueKind {
    /// Convolution only; no extra tensors or ops are added.
    None,
    /// Convolution followed by a pointwise add of a bias tensor.
    Bias,
    /// Convolution, bias add, then one pointwise op whose mode comes from the
    /// wrapped [`ActivationKind`].
    BiasActivation(ActivationKind),
}
102
103pub(crate) fn dtype_tag<T: CudnnSupported>() -> DtypeTag {
104 match T::NAME {
105 "f32" => DtypeTag::F32,
106 "f64" => DtypeTag::F64,
107 "f16" => DtypeTag::F16,
108 "bf16" => DtypeTag::Bf16,
109 "i8" => DtypeTag::I8,
110 other => panic!("unsupported cuDNN dtype name: {other}"),
111 }
112}
113
114pub fn build_conv_fwd_graph(
118 dtype: DtypeTag,
119 x_dims: &[i64],
120 w_dims: &[i64],
121 y_dims: &[i64],
122 conv: &ConvDescParams,
123 layout: TensorLayout,
124 epilogue: EpilogueKind,
125) -> OperationGraphSpec {
126 let mut g = OperationGraphSpec::new("conv_fwd");
127 let x_uid = g.add_tensor(TensorSpec::new(1, dtype, x_dims.to_vec(), layout));
128 let w_uid = g.add_tensor(TensorSpec::new(2, dtype, w_dims.to_vec(), layout));
129 let y_uid = g.add_tensor(TensorSpec::new(3, dtype, y_dims.to_vec(), layout));
130 g.add_op(OpSpec::ConvFwd {
131 x: x_uid,
132 w: w_uid,
133 y: y_uid,
134 spatial_dims: conv.spatial_dims,
135 pre_padding: conv.pre_padding.clone(),
136 post_padding: conv.post_padding.clone(),
137 stride: conv.stride.clone(),
138 dilation: conv.dilation.clone(),
139 compute_dtype: dtype,
140 alpha: 1.0,
141 beta: 0.0,
142 });
143 match epilogue {
144 EpilogueKind::None => {}
145 EpilogueKind::Bias => {
146 let b_uid = g.add_tensor(TensorSpec::new(4, dtype, bias_dims(y_dims), layout));
147 let yb_uid = g.add_tensor(TensorSpec::new(5, dtype, y_dims.to_vec(), layout));
148 g.add_op(OpSpec::Pointwise {
149 mode: PointwiseMode::Add,
150 x: y_uid,
151 b: Some(b_uid),
152 y: yb_uid,
153 compute_dtype: dtype,
154 alpha1: 1.0,
155 alpha2: 1.0,
156 });
157 }
158 EpilogueKind::BiasActivation(act) => {
159 let b_uid = g.add_tensor(TensorSpec::new(4, dtype, bias_dims(y_dims), layout));
160 let yb_uid = g.add_tensor(TensorSpec::new(5, dtype, y_dims.to_vec(), layout));
161 g.add_op(OpSpec::Pointwise {
162 mode: PointwiseMode::Add,
163 x: y_uid,
164 b: Some(b_uid),
165 y: yb_uid,
166 compute_dtype: dtype,
167 alpha1: 1.0,
168 alpha2: 1.0,
169 });
170 let act_out = g.add_tensor(TensorSpec::new(6, dtype, y_dims.to_vec(), layout));
171 g.add_op(OpSpec::Pointwise {
172 mode: act.pointwise_mode(),
173 x: yb_uid,
174 b: None,
175 y: act_out,
176 compute_dtype: dtype,
177 alpha1: 1.0,
178 alpha2: 0.0,
179 });
180 }
181 }
182 g
183}
184
185pub fn build_conv_bwd_data_graph(
187 dtype: DtypeTag,
188 dy_dims: &[i64],
189 w_dims: &[i64],
190 dx_dims: &[i64],
191 conv: &ConvDescParams,
192 layout: TensorLayout,
193) -> OperationGraphSpec {
194 let mut g = OperationGraphSpec::new("conv_bwd_data");
195 let dy_uid = g.add_tensor(TensorSpec::new(1, dtype, dy_dims.to_vec(), layout));
196 let w_uid = g.add_tensor(TensorSpec::new(2, dtype, w_dims.to_vec(), layout));
197 let dx_uid = g.add_tensor(TensorSpec::new(3, dtype, dx_dims.to_vec(), layout));
198 g.add_op(OpSpec::ConvBwdData {
199 dy: dy_uid,
200 w: w_uid,
201 dx: dx_uid,
202 spatial_dims: conv.spatial_dims,
203 pre_padding: conv.pre_padding.clone(),
204 post_padding: conv.post_padding.clone(),
205 stride: conv.stride.clone(),
206 dilation: conv.dilation.clone(),
207 compute_dtype: dtype,
208 alpha: 1.0,
209 beta: 0.0,
210 });
211 g
212}
213
214pub fn build_conv_bwd_filter_graph(
216 dtype: DtypeTag,
217 x_dims: &[i64],
218 dy_dims: &[i64],
219 dw_dims: &[i64],
220 conv: &ConvDescParams,
221 layout: TensorLayout,
222) -> OperationGraphSpec {
223 let mut g = OperationGraphSpec::new("conv_bwd_filter");
224 let x_uid = g.add_tensor(TensorSpec::new(1, dtype, x_dims.to_vec(), layout));
225 let dy_uid = g.add_tensor(TensorSpec::new(2, dtype, dy_dims.to_vec(), layout));
226 let dw_uid = g.add_tensor(TensorSpec::new(3, dtype, dw_dims.to_vec(), layout));
227 g.add_op(OpSpec::ConvBwdFilter {
228 x: x_uid,
229 dy: dy_uid,
230 dw: dw_uid,
231 spatial_dims: conv.spatial_dims,
232 pre_padding: conv.pre_padding.clone(),
233 post_padding: conv.post_padding.clone(),
234 stride: conv.stride.clone(),
235 dilation: conv.dilation.clone(),
236 compute_dtype: dtype,
237 alpha: 1.0,
238 beta: 0.0,
239 });
240 g
241}
242
/// Derives the dims of a bias tensor from the output dims: same rank as
/// `y_dims`, extent `1` on every axis except axis 1, which keeps `y_dims[1]`.
///
/// Inputs of rank 0 or 1 have no axis 1 and therefore yield all ones (or an
/// empty vector).
fn bias_dims(y_dims: &[i64]) -> Vec<i64> {
    y_dims
        .iter()
        .enumerate()
        .map(|(axis, &extent)| if axis == 1 { extent } else { 1 })
        .collect()
}
253
/// A forward-convolution request handed to the cuDNN dispatcher.
///
/// Bundles the device buffers, their dims, the convolution settings, and a
/// one-shot channel on which the outcome is reported back to the submitter.
pub struct ConvFwdRequest<T: CudnnSupported> {
    /// Input buffer.
    pub x: GpuRef<T>,
    /// Dims of `x`.
    pub x_dims: Vec<i64>,
    /// Filter buffer.
    pub w: GpuRef<T>,
    /// Dims of `w`.
    pub w_dims: Vec<i64>,
    /// Output buffer.
    pub y: GpuRef<T>,
    /// Dims of `y`.
    pub y_dims: Vec<i64>,
    /// Optional bias buffer; presumably expected to be `Some` whenever
    /// `epilogue` is not `EpilogueKind::None` — TODO confirm, nothing in this
    /// file enforces it.
    pub bias: Option<GpuRef<T>>,
    /// Convolution settings (padding, stride, dilation, groups).
    pub conv: ConvDescParams,
    /// Layout applied to every tensor in the graph.
    pub layout: TensorLayout,
    /// Pointwise stages fused after the convolution.
    pub epilogue: EpilogueKind,
    /// Scaling factor. NOTE(review): `graph_spec` does not read this field;
    /// the generated graph hard-codes `alpha = 1.0`.
    pub alpha: T::Scalar,
    /// Scaling factor. NOTE(review): `graph_spec` does not read this field;
    /// the generated graph hard-codes `beta = 0.0`.
    pub beta: T::Scalar,
    /// Channel on which `dispatch` reports success or failure.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
    /// Zero-sized marker for the element type `T`.
    pub _ty: PhantomData<T>,
}
274
275impl<T: CudnnSupported> ConvFwdRequest<T> {
276 pub fn graph_spec(&self) -> OperationGraphSpec {
277 build_conv_fwd_graph(
278 dtype_tag::<T>(),
279 &self.x_dims,
280 &self.w_dims,
281 &self.y_dims,
282 &self.conv,
283 self.layout,
284 self.epilogue,
285 )
286 }
287}
288
289impl<T: CudnnSupported> CudnnDispatch for ConvFwdRequest<T> {
290 fn dtype_name(&self) -> &'static str {
291 T::NAME
292 }
293 fn op_kind(&self) -> &'static str {
294 "conv_fwd"
295 }
296 fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
297 let _ = self.reply.send(Err(GpuError::LibraryError {
298 lib: "cudnn",
299 msg: "ConvFwdRequest dispatch requires the v9 frontend graph builder \
300 (cudnnBackendCreateDescriptor path); skeleton entry point only"
301 .to_string(),
302 }));
303 }
304}
305
/// A backward-data convolution request handed to the cuDNN dispatcher.
///
/// Bundles the device buffers, their dims, the convolution settings, and a
/// one-shot channel on which the outcome is reported back to the submitter.
pub struct ConvBwdDataRequest<T: CudnnSupported> {
    /// Output-gradient buffer.
    pub dy: GpuRef<T>,
    /// Dims of `dy`.
    pub dy_dims: Vec<i64>,
    /// Filter buffer.
    pub w: GpuRef<T>,
    /// Dims of `w`.
    pub w_dims: Vec<i64>,
    /// Input-gradient buffer.
    pub dx: GpuRef<T>,
    /// Dims of `dx`.
    pub dx_dims: Vec<i64>,
    /// Convolution settings (padding, stride, dilation, groups).
    pub conv: ConvDescParams,
    /// Layout applied to every tensor in the graph.
    pub layout: TensorLayout,
    /// Scaling factor. NOTE(review): `graph_spec` does not read this field;
    /// the generated graph hard-codes `alpha = 1.0`.
    pub alpha: T::Scalar,
    /// Scaling factor. NOTE(review): `graph_spec` does not read this field;
    /// the generated graph hard-codes `beta = 0.0`.
    pub beta: T::Scalar,
    /// Channel on which `dispatch` reports success or failure.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
    /// Zero-sized marker for the element type `T`.
    pub _ty: PhantomData<T>,
}
321
322impl<T: CudnnSupported> ConvBwdDataRequest<T> {
323 pub fn graph_spec(&self) -> OperationGraphSpec {
324 build_conv_bwd_data_graph(
325 dtype_tag::<T>(),
326 &self.dy_dims,
327 &self.w_dims,
328 &self.dx_dims,
329 &self.conv,
330 self.layout,
331 )
332 }
333}
334
335impl<T: CudnnSupported> CudnnDispatch for ConvBwdDataRequest<T> {
336 fn dtype_name(&self) -> &'static str {
337 T::NAME
338 }
339 fn op_kind(&self) -> &'static str {
340 "conv_bwd_data"
341 }
342 fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
343 let _ = self.reply.send(Err(GpuError::LibraryError {
344 lib: "cudnn",
345 msg: "ConvBwdDataRequest dispatch requires the v9 frontend graph builder; \
346 skeleton entry point only"
347 .to_string(),
348 }));
349 }
350}
351
/// A backward-filter convolution request handed to the cuDNN dispatcher.
///
/// Bundles the device buffers, their dims, the convolution settings, and a
/// one-shot channel on which the outcome is reported back to the submitter.
pub struct ConvBwdFilterRequest<T: CudnnSupported> {
    /// Input buffer.
    pub x: GpuRef<T>,
    /// Dims of `x`.
    pub x_dims: Vec<i64>,
    /// Output-gradient buffer.
    pub dy: GpuRef<T>,
    /// Dims of `dy`.
    pub dy_dims: Vec<i64>,
    /// Filter-gradient buffer.
    pub dw: GpuRef<T>,
    /// Dims of `dw`.
    pub dw_dims: Vec<i64>,
    /// Convolution settings (padding, stride, dilation, groups).
    pub conv: ConvDescParams,
    /// Layout applied to every tensor in the graph.
    pub layout: TensorLayout,
    /// Scaling factor. NOTE(review): `graph_spec` does not read this field;
    /// the generated graph hard-codes `alpha = 1.0`.
    pub alpha: T::Scalar,
    /// Scaling factor. NOTE(review): `graph_spec` does not read this field;
    /// the generated graph hard-codes `beta = 0.0`.
    pub beta: T::Scalar,
    /// Channel on which `dispatch` reports success or failure.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
    /// Zero-sized marker for the element type `T`.
    pub _ty: PhantomData<T>,
}
367
368impl<T: CudnnSupported> ConvBwdFilterRequest<T> {
369 pub fn graph_spec(&self) -> OperationGraphSpec {
370 build_conv_bwd_filter_graph(
371 dtype_tag::<T>(),
372 &self.x_dims,
373 &self.dy_dims,
374 &self.dw_dims,
375 &self.conv,
376 self.layout,
377 )
378 }
379}
380
381impl<T: CudnnSupported> CudnnDispatch for ConvBwdFilterRequest<T> {
382 fn dtype_name(&self) -> &'static str {
383 T::NAME
384 }
385 fn op_kind(&self) -> &'static str {
386 "conv_bwd_filter"
387 }
388 fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
389 let _ = self.reply.send(Err(GpuError::LibraryError {
390 lib: "cudnn",
391 msg: "ConvBwdFilterRequest dispatch requires the v9 frontend graph builder; \
392 skeleton entry point only"
393 .to_string(),
394 }));
395 }
396}
397
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::cudnn::graph::cache_key;

    // Shapes shared by every 2-D case below.
    const X_DIMS: [i64; 4] = [1, 3, 8, 8];
    const W_DIMS: [i64; 4] = [16, 3, 3, 3];
    const Y_DIMS: [i64; 4] = [1, 16, 6, 6];

    /// 2-D convolution with zero padding, unit stride and unit dilation.
    fn plain_2d() -> ConvDescParams {
        ConvDescParams::symmetric_2d(0, 1, 1)
    }

    /// Forward graph over the shared shapes with the given dtype, layout and
    /// epilogue.
    fn sample_fwd_graph(
        dt: DtypeTag,
        layout: TensorLayout,
        epilogue: EpilogueKind,
    ) -> OperationGraphSpec {
        build_conv_fwd_graph(dt, &X_DIMS, &W_DIMS, &Y_DIMS, &plain_2d(), layout, epilogue)
    }

    /// Checks graph size, cache-key fields, signature stability across two
    /// identical builds, and the dtype's printable name.
    fn round_trip_fwd(dt: DtypeTag, dt_name: &'static str, layout: TensorLayout) {
        let first = sample_fwd_graph(dt, layout, EpilogueKind::None);
        assert_eq!(first.tensors.len(), 3);
        assert_eq!(first.ops.len(), 1);

        let key = cache_key("conv_fwd", dt, &first);
        assert_eq!(key.op_kind, "conv_fwd");
        assert_eq!(key.dtype, dt);

        let second = sample_fwd_graph(dt, layout, EpilogueKind::None);
        assert_eq!(first.signature(), second.signature());
        assert_eq!(dt.name(), dt_name);
    }

    #[test]
    fn conv_fwd_request_round_trip_f32_f64_f16_bf16() {
        let cases = [
            (DtypeTag::F32, "f32", TensorLayout::NchwPacked),
            (DtypeTag::F64, "f64", TensorLayout::NchwPacked),
            (DtypeTag::F16, "f16", TensorLayout::NchwPacked),
            (DtypeTag::Bf16, "bf16", TensorLayout::NchwPacked),
            (DtypeTag::F32, "f32", TensorLayout::NhwcPacked),
        ];
        for (dt, dt_name, layout) in cases {
            round_trip_fwd(dt, dt_name, layout);
        }
    }

    #[test]
    fn conv_bwd_data_filter_request_round_trip() {
        let data_graph = build_conv_bwd_data_graph(
            DtypeTag::F32,
            &Y_DIMS,
            &W_DIMS,
            &X_DIMS,
            &plain_2d(),
            TensorLayout::NchwPacked,
        );
        assert_eq!(data_graph.ops.len(), 1);
        if let OpSpec::ConvBwdData { spatial_dims, .. } = &data_graph.ops[0] {
            assert_eq!(*spatial_dims, 2);
        } else {
            panic!("wrong op");
        }

        let filter_graph = build_conv_bwd_filter_graph(
            DtypeTag::F32,
            &X_DIMS,
            &Y_DIMS,
            &W_DIMS,
            &plain_2d(),
            TensorLayout::NchwPacked,
        );
        if let OpSpec::ConvBwdFilter { spatial_dims, .. } = &filter_graph.ops[0] {
            assert_eq!(*spatial_dims, 2);
        } else {
            panic!("wrong op");
        }
    }

    #[test]
    fn nchw_vs_nhwc_layout_handled() {
        let g_nchw =
            sample_fwd_graph(DtypeTag::F32, TensorLayout::NchwPacked, EpilogueKind::None);
        let g_nhwc =
            sample_fwd_graph(DtypeTag::F32, TensorLayout::NhwcPacked, EpilogueKind::None);

        assert_ne!(g_nchw.signature(), g_nhwc.signature());
        assert_eq!(g_nchw.tensors[0].strides, vec![192, 64, 8, 1]);
        assert_ne!(g_nhwc.tensors[0].strides, g_nchw.tensors[0].strides);
    }

    #[test]
    fn conv_fwd_with_bias_activation_epilogue() {
        let g = sample_fwd_graph(
            DtypeTag::F32,
            TensorLayout::NhwcPacked,
            EpilogueKind::BiasActivation(ActivationKind::Relu),
        );
        // conv + bias add + activation, over x, w, y, bias, bias-add output
        // and activation output.
        assert_eq!(g.ops.len(), 3);
        assert_eq!(g.tensors.len(), 6);
    }

    #[test]
    fn conv_1d_and_3d_descriptor_params() {
        let one_d = ConvDescParams::symmetric_1d(1, 1, 1);
        assert_eq!(one_d.spatial_dims, 1);
        assert_eq!(one_d.stride.len(), 1);

        let three_d = ConvDescParams::symmetric_3d(1, 2, 1);
        assert_eq!(three_d.spatial_dims, 3);
        assert_eq!(three_d.stride, vec![2, 2, 2]);
    }

    #[test]
    fn conv_grouped() {
        let grouped = plain_2d().with_groups(8);
        assert_eq!(grouped.groups, 8);
    }
}