1#![allow(dead_code)]
7
8use std::marker::PhantomData;
9
10use tokio::sync::oneshot;
11
12use crate::dtype::CudnnSupported;
13use crate::error::GpuError;
14use crate::gpu_ref::GpuRef;
15use crate::kernel::cudnn::conv::dtype_tag;
16use crate::kernel::cudnn::graph::{
17 DtypeTag, OpSpec, OperationGraphSpec, PoolKind, TensorLayout, TensorSpec,
18};
19use crate::kernel::dispatch::{CudnnDispatch, CudnnDispatchCtx};
20
/// Pooling reduction requested by the caller.
///
/// Three modes are distinguishable at this level. Note that
/// [`PoolMode::fwd`] and [`PoolMode::bwd`] translate both averaging
/// variants to the same `PoolKind`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PoolMode {
    /// Max pooling.
    Max,
    /// Average pooling.
    Avg,
    /// Average pooling, padding-excluding flavour.
    ///
    /// NOTE(review): presumably padded elements are left out of the
    /// divisor; confirm against the graph builder that consumes this.
    AvgExcludePadding,
}
29
30impl PoolMode {
31 pub fn fwd(self) -> PoolKind {
32 match self {
33 PoolMode::Max => PoolKind::MaxFwd,
34 PoolMode::Avg | PoolMode::AvgExcludePadding => PoolKind::AvgFwd,
35 }
36 }
37 pub fn bwd(self) -> PoolKind {
38 match self {
39 PoolMode::Max => PoolKind::MaxBwd,
40 PoolMode::Avg | PoolMode::AvgExcludePadding => PoolKind::AvgBwd,
41 }
42 }
43}
44
45#[derive(Debug, Clone, PartialEq, Eq, Hash)]
47pub struct PoolParams {
48 pub mode: PoolMode,
49 pub window: Vec<i64>,
51 pub pre_padding: Vec<i64>,
52 pub post_padding: Vec<i64>,
53 pub stride: Vec<i64>,
54}
55
56impl PoolParams {
57 pub fn pool_2d(mode: PoolMode, kernel: i64, stride: i64, padding: i64) -> Self {
59 Self {
60 mode,
61 window: vec![kernel, kernel],
62 pre_padding: vec![padding, padding],
63 post_padding: vec![padding, padding],
64 stride: vec![stride, stride],
65 }
66 }
67}
68
69pub struct PoolFwdRequest<T: CudnnSupported> {
71 pub x: GpuRef<T>,
72 pub y: GpuRef<T>,
73 pub x_dims: Vec<i64>,
74 pub y_dims: Vec<i64>,
75 pub layout: TensorLayout,
76 pub params: PoolParams,
77 pub alpha: T::Scalar,
78 pub beta: T::Scalar,
79 pub reply: oneshot::Sender<Result<(), GpuError>>,
80 pub _ty: PhantomData<T>,
81}
82
83impl<T: CudnnSupported> PoolFwdRequest<T> {
84 pub fn graph_spec(&self) -> OperationGraphSpec {
85 build_pool_fwd_graph(
86 dtype_tag::<T>(),
87 &self.x_dims,
88 &self.y_dims,
89 self.layout,
90 &self.params,
91 )
92 }
93}
94
95impl<T: CudnnSupported> CudnnDispatch for PoolFwdRequest<T> {
96 fn dtype_name(&self) -> &'static str {
97 T::NAME
98 }
99 fn op_kind(&self) -> &'static str {
100 "pool_fwd"
101 }
102 fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
103 let _ = self.reply.send(Err(GpuError::LibraryError {
104 lib: "cudnn",
105 msg: "PoolFwdRequest dispatch requires the v9 frontend graph builder; \
106 skeleton entry point only"
107 .to_string(),
108 }));
109 }
110}
111
112pub struct PoolBwdRequest<T: CudnnSupported> {
114 pub x: GpuRef<T>,
115 pub y: GpuRef<T>,
116 pub dy: GpuRef<T>,
117 pub dx: GpuRef<T>,
118 pub x_dims: Vec<i64>,
119 pub y_dims: Vec<i64>,
120 pub layout: TensorLayout,
121 pub params: PoolParams,
122 pub alpha: T::Scalar,
123 pub beta: T::Scalar,
124 pub reply: oneshot::Sender<Result<(), GpuError>>,
125 pub _ty: PhantomData<T>,
126}
127
128impl<T: CudnnSupported> PoolBwdRequest<T> {
129 pub fn graph_spec(&self) -> OperationGraphSpec {
130 build_pool_bwd_graph(
131 dtype_tag::<T>(),
132 &self.x_dims,
133 &self.y_dims,
134 self.layout,
135 &self.params,
136 )
137 }
138}
139
140impl<T: CudnnSupported> CudnnDispatch for PoolBwdRequest<T> {
141 fn dtype_name(&self) -> &'static str {
142 T::NAME
143 }
144 fn op_kind(&self) -> &'static str {
145 "pool_bwd"
146 }
147 fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
148 let _ = self.reply.send(Err(GpuError::LibraryError {
149 lib: "cudnn",
150 msg: "PoolBwdRequest dispatch requires the v9 frontend graph builder; \
151 skeleton entry point only"
152 .to_string(),
153 }));
154 }
155}
156
157pub fn build_pool_fwd_graph(
158 dtype: DtypeTag,
159 x_dims: &[i64],
160 y_dims: &[i64],
161 layout: TensorLayout,
162 p: &PoolParams,
163) -> OperationGraphSpec {
164 let mut g = OperationGraphSpec::new("pool_fwd");
165 let x_uid = g.add_tensor(TensorSpec::new(1, dtype, x_dims.to_vec(), layout));
166 let y_uid = g.add_tensor(TensorSpec::new(2, dtype, y_dims.to_vec(), layout));
167 g.add_op(OpSpec::PoolFwd {
168 kind: p.mode.fwd(),
169 x: x_uid,
170 y: y_uid,
171 window: p.window.clone(),
172 pre_padding: p.pre_padding.clone(),
173 post_padding: p.post_padding.clone(),
174 stride: p.stride.clone(),
175 compute_dtype: dtype,
176 });
177 g
178}
179
180pub fn build_pool_bwd_graph(
181 dtype: DtypeTag,
182 x_dims: &[i64],
183 y_dims: &[i64],
184 layout: TensorLayout,
185 p: &PoolParams,
186) -> OperationGraphSpec {
187 let mut g = OperationGraphSpec::new("pool_bwd");
188 let x_uid = g.add_tensor(TensorSpec::new(1, dtype, x_dims.to_vec(), layout));
189 let y_uid = g.add_tensor(TensorSpec::new(2, dtype, y_dims.to_vec(), layout));
190 let dy_uid = g.add_tensor(TensorSpec::new(3, dtype, y_dims.to_vec(), layout));
191 let dx_uid = g.add_tensor(TensorSpec::new(4, dtype, x_dims.to_vec(), layout));
192 g.add_op(OpSpec::PoolBwd {
193 kind: p.mode.bwd(),
194 dy: dy_uid,
195 x: x_uid,
196 y: y_uid,
197 dx: dx_uid,
198 window: p.window.clone(),
199 pre_padding: p.pre_padding.clone(),
200 post_padding: p.post_padding.clone(),
201 stride: p.stride.clone(),
202 compute_dtype: dtype,
203 });
204 g
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds forward and backward graphs for max pooling and a forward
    /// graph for average pooling, then checks the recorded `PoolKind` and
    /// the backward graph's tensor count.
    #[test]
    fn pool_fwd_bwd_round_trip() {
        let x_dims: [i64; 4] = [1, 16, 8, 8];
        let y_dims: [i64; 4] = [1, 16, 4, 4];
        let layout = TensorLayout::NchwPacked;

        // Max pooling, forward.
        let max_params = PoolParams::pool_2d(PoolMode::Max, 2, 2, 0);
        let max_fwd = build_pool_fwd_graph(DtypeTag::F32, &x_dims, &y_dims, layout, &max_params);
        if let OpSpec::PoolFwd { kind, .. } = &max_fwd.ops[0] {
            assert_eq!(*kind, PoolKind::MaxFwd);
        } else {
            panic!("wrong op");
        }

        // Max pooling, backward: x, y, dy and dx must all be registered.
        let max_bwd = build_pool_bwd_graph(DtypeTag::F32, &x_dims, &y_dims, layout, &max_params);
        assert_eq!(max_bwd.tensors.len(), 4);
        if let OpSpec::PoolBwd { kind, .. } = &max_bwd.ops[0] {
            assert_eq!(*kind, PoolKind::MaxBwd);
        } else {
            panic!("wrong op");
        }

        // Average pooling, forward.
        let avg_params = PoolParams::pool_2d(PoolMode::Avg, 2, 2, 0);
        let avg_fwd = build_pool_fwd_graph(DtypeTag::F32, &x_dims, &y_dims, layout, &avg_params);
        if let OpSpec::PoolFwd { kind, .. } = &avg_fwd.ops[0] {
            assert_eq!(*kind, PoolKind::AvgFwd);
        } else {
            panic!("wrong op");
        }
    }
}