Skip to main content

atomr_accel_cuda/kernel/cudnn/
pool.rs

1//! Pooling requests (max / avg, fwd + bwd) for the cuDNN actor.
2//!
3//! Routes through `CUDNN_BACKEND_OPERATION_RESAMPLE_FWD/BWD_DESCRIPTOR`
4//! in the v9 frontend graph.
5
6#![allow(dead_code)]
7
8use std::marker::PhantomData;
9
10use tokio::sync::oneshot;
11
12use crate::dtype::CudnnSupported;
13use crate::error::GpuError;
14use crate::gpu_ref::GpuRef;
15use crate::kernel::cudnn::conv::dtype_tag;
16use crate::kernel::cudnn::graph::{
17    DtypeTag, OpSpec, OperationGraphSpec, PoolKind, TensorLayout, TensorSpec,
18};
19use crate::kernel::dispatch::{CudnnDispatch, CudnnDispatchCtx};
20
21/// Pooling op kind — `kind == PoolKind::*Bwd` selects the backward
22/// pass.
23#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
24pub enum PoolMode {
25    Max,
26    Avg,
27    AvgExcludePadding,
28}
29
30impl PoolMode {
31    pub fn fwd(self) -> PoolKind {
32        match self {
33            PoolMode::Max => PoolKind::MaxFwd,
34            PoolMode::Avg | PoolMode::AvgExcludePadding => PoolKind::AvgFwd,
35        }
36    }
37    pub fn bwd(self) -> PoolKind {
38        match self {
39            PoolMode::Max => PoolKind::MaxBwd,
40            PoolMode::Avg | PoolMode::AvgExcludePadding => PoolKind::AvgBwd,
41        }
42    }
43}
44
45/// Pooling parameter struct, dim-generic.
46#[derive(Debug, Clone, PartialEq, Eq, Hash)]
47pub struct PoolParams {
48    pub mode: PoolMode,
49    /// Window per spatial dim.
50    pub window: Vec<i64>,
51    pub pre_padding: Vec<i64>,
52    pub post_padding: Vec<i64>,
53    pub stride: Vec<i64>,
54}
55
56impl PoolParams {
57    /// 2D pooling helper.
58    pub fn pool_2d(mode: PoolMode, kernel: i64, stride: i64, padding: i64) -> Self {
59        Self {
60            mode,
61            window: vec![kernel, kernel],
62            pre_padding: vec![padding, padding],
63            post_padding: vec![padding, padding],
64            stride: vec![stride, stride],
65        }
66    }
67}
68
69/// Forward pooling.
70pub struct PoolFwdRequest<T: CudnnSupported> {
71    pub x: GpuRef<T>,
72    pub y: GpuRef<T>,
73    pub x_dims: Vec<i64>,
74    pub y_dims: Vec<i64>,
75    pub layout: TensorLayout,
76    pub params: PoolParams,
77    pub alpha: T::Scalar,
78    pub beta: T::Scalar,
79    pub reply: oneshot::Sender<Result<(), GpuError>>,
80    pub _ty: PhantomData<T>,
81}
82
83impl<T: CudnnSupported> PoolFwdRequest<T> {
84    pub fn graph_spec(&self) -> OperationGraphSpec {
85        build_pool_fwd_graph(
86            dtype_tag::<T>(),
87            &self.x_dims,
88            &self.y_dims,
89            self.layout,
90            &self.params,
91        )
92    }
93}
94
95impl<T: CudnnSupported> CudnnDispatch for PoolFwdRequest<T> {
96    fn dtype_name(&self) -> &'static str {
97        T::NAME
98    }
99    fn op_kind(&self) -> &'static str {
100        "pool_fwd"
101    }
102    fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
103        let _ = self.reply.send(Err(GpuError::LibraryError {
104            lib: "cudnn",
105            msg: "PoolFwdRequest dispatch requires the v9 frontend graph builder; \
106                  skeleton entry point only"
107                .to_string(),
108        }));
109    }
110}
111
112/// Backward pooling.
113pub struct PoolBwdRequest<T: CudnnSupported> {
114    pub x: GpuRef<T>,
115    pub y: GpuRef<T>,
116    pub dy: GpuRef<T>,
117    pub dx: GpuRef<T>,
118    pub x_dims: Vec<i64>,
119    pub y_dims: Vec<i64>,
120    pub layout: TensorLayout,
121    pub params: PoolParams,
122    pub alpha: T::Scalar,
123    pub beta: T::Scalar,
124    pub reply: oneshot::Sender<Result<(), GpuError>>,
125    pub _ty: PhantomData<T>,
126}
127
128impl<T: CudnnSupported> PoolBwdRequest<T> {
129    pub fn graph_spec(&self) -> OperationGraphSpec {
130        build_pool_bwd_graph(
131            dtype_tag::<T>(),
132            &self.x_dims,
133            &self.y_dims,
134            self.layout,
135            &self.params,
136        )
137    }
138}
139
140impl<T: CudnnSupported> CudnnDispatch for PoolBwdRequest<T> {
141    fn dtype_name(&self) -> &'static str {
142        T::NAME
143    }
144    fn op_kind(&self) -> &'static str {
145        "pool_bwd"
146    }
147    fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
148        let _ = self.reply.send(Err(GpuError::LibraryError {
149            lib: "cudnn",
150            msg: "PoolBwdRequest dispatch requires the v9 frontend graph builder; \
151                  skeleton entry point only"
152                .to_string(),
153        }));
154    }
155}
156
157pub fn build_pool_fwd_graph(
158    dtype: DtypeTag,
159    x_dims: &[i64],
160    y_dims: &[i64],
161    layout: TensorLayout,
162    p: &PoolParams,
163) -> OperationGraphSpec {
164    let mut g = OperationGraphSpec::new("pool_fwd");
165    let x_uid = g.add_tensor(TensorSpec::new(1, dtype, x_dims.to_vec(), layout));
166    let y_uid = g.add_tensor(TensorSpec::new(2, dtype, y_dims.to_vec(), layout));
167    g.add_op(OpSpec::PoolFwd {
168        kind: p.mode.fwd(),
169        x: x_uid,
170        y: y_uid,
171        window: p.window.clone(),
172        pre_padding: p.pre_padding.clone(),
173        post_padding: p.post_padding.clone(),
174        stride: p.stride.clone(),
175        compute_dtype: dtype,
176    });
177    g
178}
179
180pub fn build_pool_bwd_graph(
181    dtype: DtypeTag,
182    x_dims: &[i64],
183    y_dims: &[i64],
184    layout: TensorLayout,
185    p: &PoolParams,
186) -> OperationGraphSpec {
187    let mut g = OperationGraphSpec::new("pool_bwd");
188    let x_uid = g.add_tensor(TensorSpec::new(1, dtype, x_dims.to_vec(), layout));
189    let y_uid = g.add_tensor(TensorSpec::new(2, dtype, y_dims.to_vec(), layout));
190    let dy_uid = g.add_tensor(TensorSpec::new(3, dtype, y_dims.to_vec(), layout));
191    let dx_uid = g.add_tensor(TensorSpec::new(4, dtype, x_dims.to_vec(), layout));
192    g.add_op(OpSpec::PoolBwd {
193        kind: p.mode.bwd(),
194        dy: dy_uid,
195        x: x_uid,
196        y: y_uid,
197        dx: dx_uid,
198        window: p.window.clone(),
199        pre_padding: p.pre_padding.clone(),
200        post_padding: p.post_padding.clone(),
201        stride: p.stride.clone(),
202        compute_dtype: dtype,
203    });
204    g
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// Input dims shared by every case (8x8 spatial extent).
    const X_DIMS: [i64; 4] = [1, 16, 8, 8];
    /// Output dims for a 2x2 window with stride 2 and no padding.
    const Y_DIMS: [i64; 4] = [1, 16, 4, 4];

    /// Builds a forward graph for `mode` and checks the op kind.
    fn assert_fwd_kind(mode: PoolMode, expected: PoolKind) {
        let params = PoolParams::pool_2d(mode, 2, 2, 0);
        let graph = build_pool_fwd_graph(
            DtypeTag::F32,
            &X_DIMS,
            &Y_DIMS,
            TensorLayout::NchwPacked,
            &params,
        );
        match &graph.ops[0] {
            OpSpec::PoolFwd { kind, .. } => assert_eq!(*kind, expected),
            _ => panic!("wrong op"),
        }
    }

    /// Builds a backward graph for `mode` and checks the tensor count
    /// (x, y, dy, dx) and the op kind.
    fn assert_bwd_kind(mode: PoolMode, expected: PoolKind) {
        let params = PoolParams::pool_2d(mode, 2, 2, 0);
        let graph = build_pool_bwd_graph(
            DtypeTag::F32,
            &X_DIMS,
            &Y_DIMS,
            TensorLayout::NchwPacked,
            &params,
        );
        assert_eq!(graph.tensors.len(), 4);
        match &graph.ops[0] {
            OpSpec::PoolBwd { kind, .. } => assert_eq!(*kind, expected),
            _ => panic!("wrong op"),
        }
    }

    #[test]
    fn pool_fwd_bwd_round_trip() {
        // Max pool round-trip.
        assert_fwd_kind(PoolMode::Max, PoolKind::MaxFwd);
        assert_bwd_kind(PoolMode::Max, PoolKind::MaxBwd);

        // Avg pool round-trip; the backward half was previously unchecked.
        assert_fwd_kind(PoolMode::Avg, PoolKind::AvgFwd);
        assert_bwd_kind(PoolMode::Avg, PoolKind::AvgBwd);
    }
}
252}