Skip to main content

atomr_accel_cuda/kernel/cudnn/
conv.rs

1//! Convolution requests for the cuDNN actor (Phase 2 frontend graph
2//! API).
3//!
4//! Three op families:
5//!
6//! * [`ConvFwdRequest<T>`] — y = conv(x, w) [+ optional bias + optional
7//!   activation, when fused via `epilogue`].
8//! * [`ConvBwdDataRequest<T>`] — dx = conv_bwd_data(w, dy).
9//! * [`ConvBwdFilterRequest<T>`] — dw = conv_bwd_filter(x, dy).
10//!
11//! Each supports 1D / 2D / 3D, NCHW + NHWC packed layouts (or fully
12//! strided), arbitrary group count, and arbitrary dilation.
13
14#![allow(dead_code)]
15
16use std::marker::PhantomData;
17
18use tokio::sync::oneshot;
19
20use crate::dtype::CudnnSupported;
21use crate::error::GpuError;
22use crate::gpu_ref::GpuRef;
23use crate::kernel::cudnn::activation::ActivationKind;
24use crate::kernel::cudnn::graph::{
25    DtypeTag, OpSpec, OperationGraphSpec, PointwiseMode, TensorLayout, TensorSpec,
26};
27use crate::kernel::dispatch::{CudnnDispatch, CudnnDispatchCtx};
28
/// Dimension-generic convolution descriptor parameters.
///
/// The `symmetric_*` constructors fill every per-dimension vector with
/// exactly `spatial_dims` entries; the struct itself does not enforce
/// that the vector lengths agree with `spatial_dims`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ConvDescParams {
    /// Number of spatial dimensions (1, 2, or 3).
    pub spatial_dims: usize,
    /// Padding applied before the data, one entry per spatial dim.
    pub pre_padding: Vec<i64>,
    /// Padding applied after the data, one entry per spatial dim.
    pub post_padding: Vec<i64>,
    /// Filter stride, one entry per spatial dim.
    pub stride: Vec<i64>,
    /// Filter dilation, one entry per spatial dim.
    pub dilation: Vec<i64>,
    /// Group count (≥ 1).
    pub groups: i64,
}

impl ConvDescParams {
    /// Shared constructor behind the `symmetric_*` helpers: the same
    /// `pad` / `stride` / `dilation` value is repeated for each of the
    /// `spatial_dims` axes, padding is identical on both sides, and the
    /// group count starts at 1.
    fn symmetric(spatial_dims: usize, pad: i64, stride: i64, dilation: i64) -> Self {
        Self {
            spatial_dims,
            pre_padding: vec![pad; spatial_dims],
            post_padding: vec![pad; spatial_dims],
            stride: vec![stride; spatial_dims],
            dilation: vec![dilation; spatial_dims],
            groups: 1,
        }
    }

    /// 2D descriptor using the same padding (leading and trailing),
    /// stride, and dilation on both spatial axes.
    pub fn symmetric_2d(pad: i64, stride: i64, dilation: i64) -> Self {
        Self::symmetric(2, pad, stride, dilation)
    }

    /// 1D descriptor with equal leading and trailing padding.
    pub fn symmetric_1d(pad: i64, stride: i64, dilation: i64) -> Self {
        Self::symmetric(1, pad, stride, dilation)
    }

    /// 3D descriptor using the same padding (leading and trailing),
    /// stride, and dilation on all three spatial axes.
    pub fn symmetric_3d(pad: i64, stride: i64, dilation: i64) -> Self {
        Self::symmetric(3, pad, stride, dilation)
    }

    /// Builder-style setter: returns the descriptor with `groups`
    /// replaced by `g`. The value is stored as given; it is not
    /// validated here.
    pub fn with_groups(self, g: i64) -> Self {
        Self { groups: g, ..self }
    }
}
88
/// Optional fused epilogue tail attached to conv-fwd.
///
/// The bias is represented as an opaque marker on the spec layer (the
/// graph builder records "there is a bias of this dtype + shape"); the
/// concrete `GpuRef<T>` lives on [`ConvFwdRequest`] proper.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EpilogueKind {
    /// No epilogue: the spec graph contains only the convolution op.
    None,
    /// Add bias broadcast across spatial dims.
    Bias,
    /// Bias add followed by the given pointwise activation.
    BiasActivation(ActivationKind),
}
102
103pub(crate) fn dtype_tag<T: CudnnSupported>() -> DtypeTag {
104    match T::NAME {
105        "f32" => DtypeTag::F32,
106        "f64" => DtypeTag::F64,
107        "f16" => DtypeTag::F16,
108        "bf16" => DtypeTag::Bf16,
109        "i8" => DtypeTag::I8,
110        other => panic!("unsupported cuDNN dtype name: {other}"),
111    }
112}
113
114/// Build the spec-side conv-fwd op graph, parameterised by dtype +
115/// dims. Independent of `GpuRef` so callers (tests, plan-cache lookup)
116/// can build it without owning device buffers.
117pub fn build_conv_fwd_graph(
118    dtype: DtypeTag,
119    x_dims: &[i64],
120    w_dims: &[i64],
121    y_dims: &[i64],
122    conv: &ConvDescParams,
123    layout: TensorLayout,
124    epilogue: EpilogueKind,
125) -> OperationGraphSpec {
126    let mut g = OperationGraphSpec::new("conv_fwd");
127    let x_uid = g.add_tensor(TensorSpec::new(1, dtype, x_dims.to_vec(), layout));
128    let w_uid = g.add_tensor(TensorSpec::new(2, dtype, w_dims.to_vec(), layout));
129    let y_uid = g.add_tensor(TensorSpec::new(3, dtype, y_dims.to_vec(), layout));
130    g.add_op(OpSpec::ConvFwd {
131        x: x_uid,
132        w: w_uid,
133        y: y_uid,
134        spatial_dims: conv.spatial_dims,
135        pre_padding: conv.pre_padding.clone(),
136        post_padding: conv.post_padding.clone(),
137        stride: conv.stride.clone(),
138        dilation: conv.dilation.clone(),
139        compute_dtype: dtype,
140        alpha: 1.0,
141        beta: 0.0,
142    });
143    match epilogue {
144        EpilogueKind::None => {}
145        EpilogueKind::Bias => {
146            let b_uid = g.add_tensor(TensorSpec::new(4, dtype, bias_dims(y_dims), layout));
147            let yb_uid = g.add_tensor(TensorSpec::new(5, dtype, y_dims.to_vec(), layout));
148            g.add_op(OpSpec::Pointwise {
149                mode: PointwiseMode::Add,
150                x: y_uid,
151                b: Some(b_uid),
152                y: yb_uid,
153                compute_dtype: dtype,
154                alpha1: 1.0,
155                alpha2: 1.0,
156            });
157        }
158        EpilogueKind::BiasActivation(act) => {
159            let b_uid = g.add_tensor(TensorSpec::new(4, dtype, bias_dims(y_dims), layout));
160            let yb_uid = g.add_tensor(TensorSpec::new(5, dtype, y_dims.to_vec(), layout));
161            g.add_op(OpSpec::Pointwise {
162                mode: PointwiseMode::Add,
163                x: y_uid,
164                b: Some(b_uid),
165                y: yb_uid,
166                compute_dtype: dtype,
167                alpha1: 1.0,
168                alpha2: 1.0,
169            });
170            let act_out = g.add_tensor(TensorSpec::new(6, dtype, y_dims.to_vec(), layout));
171            g.add_op(OpSpec::Pointwise {
172                mode: act.pointwise_mode(),
173                x: yb_uid,
174                b: None,
175                y: act_out,
176                compute_dtype: dtype,
177                alpha1: 1.0,
178                alpha2: 0.0,
179            });
180        }
181    }
182    g
183}
184
185/// Build the spec-side conv-bwd-data op graph.
186pub fn build_conv_bwd_data_graph(
187    dtype: DtypeTag,
188    dy_dims: &[i64],
189    w_dims: &[i64],
190    dx_dims: &[i64],
191    conv: &ConvDescParams,
192    layout: TensorLayout,
193) -> OperationGraphSpec {
194    let mut g = OperationGraphSpec::new("conv_bwd_data");
195    let dy_uid = g.add_tensor(TensorSpec::new(1, dtype, dy_dims.to_vec(), layout));
196    let w_uid = g.add_tensor(TensorSpec::new(2, dtype, w_dims.to_vec(), layout));
197    let dx_uid = g.add_tensor(TensorSpec::new(3, dtype, dx_dims.to_vec(), layout));
198    g.add_op(OpSpec::ConvBwdData {
199        dy: dy_uid,
200        w: w_uid,
201        dx: dx_uid,
202        spatial_dims: conv.spatial_dims,
203        pre_padding: conv.pre_padding.clone(),
204        post_padding: conv.post_padding.clone(),
205        stride: conv.stride.clone(),
206        dilation: conv.dilation.clone(),
207        compute_dtype: dtype,
208        alpha: 1.0,
209        beta: 0.0,
210    });
211    g
212}
213
214/// Build the spec-side conv-bwd-filter op graph.
215pub fn build_conv_bwd_filter_graph(
216    dtype: DtypeTag,
217    x_dims: &[i64],
218    dy_dims: &[i64],
219    dw_dims: &[i64],
220    conv: &ConvDescParams,
221    layout: TensorLayout,
222) -> OperationGraphSpec {
223    let mut g = OperationGraphSpec::new("conv_bwd_filter");
224    let x_uid = g.add_tensor(TensorSpec::new(1, dtype, x_dims.to_vec(), layout));
225    let dy_uid = g.add_tensor(TensorSpec::new(2, dtype, dy_dims.to_vec(), layout));
226    let dw_uid = g.add_tensor(TensorSpec::new(3, dtype, dw_dims.to_vec(), layout));
227    g.add_op(OpSpec::ConvBwdFilter {
228        x: x_uid,
229        dy: dy_uid,
230        dw: dw_uid,
231        spatial_dims: conv.spatial_dims,
232        pre_padding: conv.pre_padding.clone(),
233        post_padding: conv.post_padding.clone(),
234        stride: conv.stride.clone(),
235        dilation: conv.dilation.clone(),
236        compute_dtype: dtype,
237        alpha: 1.0,
238        beta: 0.0,
239    });
240    g
241}
242
/// Dim-vector for a bias tensor that broadcasts against `y_dims`.
///
/// Every extent becomes 1 except index 1 (the channel axis), which
/// keeps the extent from `y_dims`, giving `[1, C, 1, 1...]`. cuDNN
/// bias tensors use this shape for channel-first and channel-last
/// layouts alike — layout is expressed through strides, not dims.
/// Inputs of rank 0 or 1 have no index 1 and come back as all ones.
fn bias_dims(y_dims: &[i64]) -> Vec<i64> {
    y_dims
        .iter()
        .enumerate()
        .map(|(axis, &extent)| if axis == 1 { extent } else { 1 })
        .collect()
}
253
254// ----- Request types -------------------------------------------------
255
/// Forward convolution: `y = alpha * conv(x, w) + beta * y`,
/// optionally with a fused bias / activation tail.
///
/// NOTE(review): `graph_spec` goes through `build_conv_fwd_graph`,
/// which fixes the op's `alpha` / `beta` at `1.0` / `0.0`; the `alpha`
/// and `beta` fields here are not part of the spec graph — confirm
/// where they are meant to be applied.
pub struct ConvFwdRequest<T: CudnnSupported> {
    /// Input activation buffer.
    pub x: GpuRef<T>,
    /// Dims of `x`, passed to the graph builder as-is.
    pub x_dims: Vec<i64>,
    /// Filter buffer.
    pub w: GpuRef<T>,
    /// Dims of `w`, passed to the graph builder as-is.
    pub w_dims: Vec<i64>,
    /// Output buffer.
    pub y: GpuRef<T>,
    /// Dims of `y`; also determines the bias shape when an epilogue is
    /// requested.
    pub y_dims: Vec<i64>,
    /// Bias buffer for the fused epilogue. Presumably expected to be
    /// `Some` whenever `epilogue` is not `EpilogueKind::None`; nothing
    /// in this type enforces that — TODO confirm.
    pub bias: Option<GpuRef<T>>,
    /// Padding / stride / dilation / group parameters.
    pub conv: ConvDescParams,
    /// Layout applied to every tensor in the spec graph.
    pub layout: TensorLayout,
    /// Which fused tail (if any) follows the convolution.
    pub epilogue: EpilogueKind,
    /// Scale on the convolution result (see the formula above).
    pub alpha: T::Scalar,
    /// Scale on the prior contents of `y` (see the formula above).
    pub beta: T::Scalar,
    /// One-shot channel on which the outcome is reported.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
    /// Zero-sized marker tying the request to `T`.
    pub _ty: PhantomData<T>,
}
274
275impl<T: CudnnSupported> ConvFwdRequest<T> {
276    pub fn graph_spec(&self) -> OperationGraphSpec {
277        build_conv_fwd_graph(
278            dtype_tag::<T>(),
279            &self.x_dims,
280            &self.w_dims,
281            &self.y_dims,
282            &self.conv,
283            self.layout,
284            self.epilogue,
285        )
286    }
287}
288
289impl<T: CudnnSupported> CudnnDispatch for ConvFwdRequest<T> {
290    fn dtype_name(&self) -> &'static str {
291        T::NAME
292    }
293    fn op_kind(&self) -> &'static str {
294        "conv_fwd"
295    }
296    fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
297        let _ = self.reply.send(Err(GpuError::LibraryError {
298            lib: "cudnn",
299            msg: "ConvFwdRequest dispatch requires the v9 frontend graph builder \
300                  (cudnnBackendCreateDescriptor path); skeleton entry point only"
301                .to_string(),
302        }));
303    }
304}
305
/// Backward-data convolution: `dx = alpha * conv_bwd_data(w, dy) + beta * dx`.
///
/// NOTE(review): `graph_spec` goes through `build_conv_bwd_data_graph`,
/// which fixes the op's `alpha` / `beta` at `1.0` / `0.0`; the `alpha`
/// and `beta` fields here are not part of the spec graph — confirm
/// where they are meant to be applied.
pub struct ConvBwdDataRequest<T: CudnnSupported> {
    /// Output-gradient buffer.
    pub dy: GpuRef<T>,
    /// Dims of `dy`, passed to the graph builder as-is.
    pub dy_dims: Vec<i64>,
    /// Filter buffer.
    pub w: GpuRef<T>,
    /// Dims of `w`, passed to the graph builder as-is.
    pub w_dims: Vec<i64>,
    /// Input-gradient buffer (the result).
    pub dx: GpuRef<T>,
    /// Dims of `dx`, passed to the graph builder as-is.
    pub dx_dims: Vec<i64>,
    /// Padding / stride / dilation / group parameters.
    pub conv: ConvDescParams,
    /// Layout applied to every tensor in the spec graph.
    pub layout: TensorLayout,
    /// Scale on the computed gradient (see the formula above).
    pub alpha: T::Scalar,
    /// Scale on the prior contents of `dx` (see the formula above).
    pub beta: T::Scalar,
    /// One-shot channel on which the outcome is reported.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
    /// Zero-sized marker tying the request to `T`.
    pub _ty: PhantomData<T>,
}
321
322impl<T: CudnnSupported> ConvBwdDataRequest<T> {
323    pub fn graph_spec(&self) -> OperationGraphSpec {
324        build_conv_bwd_data_graph(
325            dtype_tag::<T>(),
326            &self.dy_dims,
327            &self.w_dims,
328            &self.dx_dims,
329            &self.conv,
330            self.layout,
331        )
332    }
333}
334
335impl<T: CudnnSupported> CudnnDispatch for ConvBwdDataRequest<T> {
336    fn dtype_name(&self) -> &'static str {
337        T::NAME
338    }
339    fn op_kind(&self) -> &'static str {
340        "conv_bwd_data"
341    }
342    fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
343        let _ = self.reply.send(Err(GpuError::LibraryError {
344            lib: "cudnn",
345            msg: "ConvBwdDataRequest dispatch requires the v9 frontend graph builder; \
346                  skeleton entry point only"
347                .to_string(),
348        }));
349    }
350}
351
/// Backward-filter convolution: `dw = alpha * conv_bwd_filter(x, dy) + beta * dw`.
///
/// NOTE(review): `graph_spec` goes through
/// `build_conv_bwd_filter_graph`, which fixes the op's `alpha` / `beta`
/// at `1.0` / `0.0`; the `alpha` and `beta` fields here are not part of
/// the spec graph — confirm where they are meant to be applied.
pub struct ConvBwdFilterRequest<T: CudnnSupported> {
    /// Input activation buffer.
    pub x: GpuRef<T>,
    /// Dims of `x`, passed to the graph builder as-is.
    pub x_dims: Vec<i64>,
    /// Output-gradient buffer.
    pub dy: GpuRef<T>,
    /// Dims of `dy`, passed to the graph builder as-is.
    pub dy_dims: Vec<i64>,
    /// Filter-gradient buffer (the result).
    pub dw: GpuRef<T>,
    /// Dims of `dw`, passed to the graph builder as-is.
    pub dw_dims: Vec<i64>,
    /// Padding / stride / dilation / group parameters.
    pub conv: ConvDescParams,
    /// Layout applied to every tensor in the spec graph.
    pub layout: TensorLayout,
    /// Scale on the computed gradient (see the formula above).
    pub alpha: T::Scalar,
    /// Scale on the prior contents of `dw` (see the formula above).
    pub beta: T::Scalar,
    /// One-shot channel on which the outcome is reported.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
    /// Zero-sized marker tying the request to `T`.
    pub _ty: PhantomData<T>,
}
367
368impl<T: CudnnSupported> ConvBwdFilterRequest<T> {
369    pub fn graph_spec(&self) -> OperationGraphSpec {
370        build_conv_bwd_filter_graph(
371            dtype_tag::<T>(),
372            &self.x_dims,
373            &self.dy_dims,
374            &self.dw_dims,
375            &self.conv,
376            self.layout,
377        )
378    }
379}
380
381impl<T: CudnnSupported> CudnnDispatch for ConvBwdFilterRequest<T> {
382    fn dtype_name(&self) -> &'static str {
383        T::NAME
384    }
385    fn op_kind(&self) -> &'static str {
386        "conv_bwd_filter"
387    }
388    fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
389        let _ = self.reply.send(Err(GpuError::LibraryError {
390            lib: "cudnn",
391            msg: "ConvBwdFilterRequest dispatch requires the v9 frontend graph builder; \
392                  skeleton entry point only"
393                .to_string(),
394        }));
395    }
396}
397
// Unit tests for the spec-side graph builders and `ConvDescParams`
// helpers. Nothing here touches device buffers or the dispatch path.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::cudnn::graph::cache_key;

    /// Builds a plain (no-epilogue) 2D conv-fwd graph for `dt` / `layout`
    /// and checks: tensor and op counts, the cache-key fields, that an
    /// identical rebuild has the same signature, and that `dt.name()`
    /// equals `dt_name`.
    fn round_trip_fwd(dt: DtypeTag, dt_name: &'static str, layout: TensorLayout) {
        let g = build_conv_fwd_graph(
            dt,
            &[1, 3, 8, 8],
            &[16, 3, 3, 3],
            &[1, 16, 6, 6],
            &ConvDescParams::symmetric_2d(0, 1, 1),
            layout,
            EpilogueKind::None,
        );
        // x, w, y and the single conv op.
        assert_eq!(g.tensors.len(), 3);
        assert_eq!(g.ops.len(), 1);
        let key = cache_key("conv_fwd", dt, &g);
        assert_eq!(key.op_kind, "conv_fwd");
        assert_eq!(key.dtype, dt);
        // Re-building from the same inputs yields the same signature.
        let g2 = build_conv_fwd_graph(
            dt,
            &[1, 3, 8, 8],
            &[16, 3, 3, 3],
            &[1, 16, 6, 6],
            &ConvDescParams::symmetric_2d(0, 1, 1),
            layout,
            EpilogueKind::None,
        );
        assert_eq!(g.signature(), g2.signature());
        assert_eq!(dt.name(), dt_name);
    }

    /// Runs the forward round-trip for each floating dtype tag.
    #[test]
    fn conv_fwd_request_round_trip_f32_f64_f16_bf16() {
        round_trip_fwd(DtypeTag::F32, "f32", TensorLayout::NchwPacked);
        round_trip_fwd(DtypeTag::F64, "f64", TensorLayout::NchwPacked);
        round_trip_fwd(DtypeTag::F16, "f16", TensorLayout::NchwPacked);
        round_trip_fwd(DtypeTag::Bf16, "bf16", TensorLayout::NchwPacked);
        // Also run NHWC for f32 to exercise the layout path.
        round_trip_fwd(DtypeTag::F32, "f32", TensorLayout::NhwcPacked);
    }

    /// Both backward builders emit the matching op variant and carry
    /// `spatial_dims` through from the descriptor.
    #[test]
    fn conv_bwd_data_filter_request_round_trip() {
        let g = build_conv_bwd_data_graph(
            DtypeTag::F32,
            &[1, 16, 6, 6],
            &[16, 3, 3, 3],
            &[1, 3, 8, 8],
            &ConvDescParams::symmetric_2d(0, 1, 1),
            TensorLayout::NchwPacked,
        );
        assert_eq!(g.ops.len(), 1);
        match &g.ops[0] {
            OpSpec::ConvBwdData { spatial_dims, .. } => assert_eq!(*spatial_dims, 2),
            _ => panic!("wrong op"),
        }

        let g = build_conv_bwd_filter_graph(
            DtypeTag::F32,
            &[1, 3, 8, 8],
            &[1, 16, 6, 6],
            &[16, 3, 3, 3],
            &ConvDescParams::symmetric_2d(0, 1, 1),
            TensorLayout::NchwPacked,
        );
        match &g.ops[0] {
            OpSpec::ConvBwdFilter { spatial_dims, .. } => assert_eq!(*spatial_dims, 2),
            _ => panic!("wrong op"),
        }
    }

    /// Same dims under the two packed layouts must give different
    /// signatures and different strides.
    #[test]
    fn nchw_vs_nhwc_layout_handled() {
        let g_nchw = build_conv_fwd_graph(
            DtypeTag::F32,
            &[1, 3, 8, 8],
            &[16, 3, 3, 3],
            &[1, 16, 6, 6],
            &ConvDescParams::symmetric_2d(0, 1, 1),
            TensorLayout::NchwPacked,
            EpilogueKind::None,
        );
        let g_nhwc = build_conv_fwd_graph(
            DtypeTag::F32,
            &[1, 3, 8, 8],
            &[16, 3, 3, 3],
            &[1, 16, 6, 6],
            &ConvDescParams::symmetric_2d(0, 1, 1),
            TensorLayout::NhwcPacked,
            EpilogueKind::None,
        );
        assert_ne!(g_nchw.signature(), g_nhwc.signature());
        // Packed NCHW strides for [1, 3, 8, 8]: 3*8*8, 8*8, 8, 1.
        assert_eq!(g_nchw.tensors[0].strides, vec![192, 64, 8, 1]);
        assert_ne!(g_nhwc.tensors[0].strides, g_nchw.tensors[0].strides);
    }

    /// The bias + activation epilogue appends two pointwise ops and
    /// three tensors to the base conv graph.
    #[test]
    fn conv_fwd_with_bias_activation_epilogue() {
        let g = build_conv_fwd_graph(
            DtypeTag::F32,
            &[1, 3, 8, 8],
            &[16, 3, 3, 3],
            &[1, 16, 6, 6],
            &ConvDescParams::symmetric_2d(0, 1, 1),
            TensorLayout::NhwcPacked,
            EpilogueKind::BiasActivation(ActivationKind::Relu),
        );
        // conv + bias-add + activation
        assert_eq!(g.ops.len(), 3);
        // x, w, y, bias, bias-add output, activation output.
        assert_eq!(g.tensors.len(), 6);
    }

    /// The 1D and 3D helpers set `spatial_dims` and size the
    /// per-dimension vectors accordingly.
    #[test]
    fn conv_1d_and_3d_descriptor_params() {
        let p1 = ConvDescParams::symmetric_1d(1, 1, 1);
        assert_eq!(p1.spatial_dims, 1);
        assert_eq!(p1.stride.len(), 1);
        let p3 = ConvDescParams::symmetric_3d(1, 2, 1);
        assert_eq!(p3.spatial_dims, 3);
        assert_eq!(p3.stride, vec![2, 2, 2]);
    }

    /// `with_groups` overrides the default group count of 1.
    #[test]
    fn conv_grouped() {
        let p = ConvDescParams::symmetric_2d(0, 1, 1).with_groups(8);
        assert_eq!(p.groups, 8);
    }
}