1#![allow(dead_code)]
8
9use std::marker::PhantomData;
10
11use tokio::sync::oneshot;
12
13use crate::dtype::CudnnSupported;
14use crate::error::GpuError;
15use crate::gpu_ref::GpuRef;
16use crate::kernel::cudnn::conv::dtype_tag;
17use crate::kernel::cudnn::graph::{
18 DtypeTag, OpSpec, OperationGraphSpec, PointwiseMode, TensorLayout, TensorSpec,
19};
20use crate::kernel::dispatch::{CudnnDispatch, CudnnDispatchCtx};
21
22#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
24pub enum ActivationKind {
25 Relu,
26 Sigmoid,
27 Tanh,
28 Gelu,
29 GeluApprox,
30 Swish,
31 Elu,
32 Softplus,
33 Identity,
34}
35
36impl ActivationKind {
37 pub fn pointwise_mode(self) -> PointwiseMode {
39 match self {
40 ActivationKind::Relu => PointwiseMode::Relu,
41 ActivationKind::Sigmoid => PointwiseMode::Sigmoid,
42 ActivationKind::Tanh => PointwiseMode::Tanh,
43 ActivationKind::Gelu => PointwiseMode::Gelu,
44 ActivationKind::GeluApprox => PointwiseMode::GeluApprox,
45 ActivationKind::Swish => PointwiseMode::Swish,
46 ActivationKind::Elu => PointwiseMode::Elu,
47 ActivationKind::Softplus => PointwiseMode::Softplus,
48 ActivationKind::Identity => PointwiseMode::Identity,
49 }
50 }
51
52 #[cfg(feature = "cudnn")]
56 pub fn cudnn_legacy_mode(self) -> cudarc::cudnn::sys::cudnnActivationMode_t {
57 use cudarc::cudnn::sys::cudnnActivationMode_t::*;
58 match self {
59 ActivationKind::Relu | ActivationKind::Identity => CUDNN_ACTIVATION_RELU,
60 ActivationKind::Sigmoid => CUDNN_ACTIVATION_SIGMOID,
61 ActivationKind::Tanh => CUDNN_ACTIVATION_TANH,
62 ActivationKind::Elu => CUDNN_ACTIVATION_ELU,
63 ActivationKind::Swish => CUDNN_ACTIVATION_SWISH,
64 ActivationKind::Gelu | ActivationKind::GeluApprox => CUDNN_ACTIVATION_RELU,
65 ActivationKind::Softplus => CUDNN_ACTIVATION_RELU,
66 }
67 }
68}
69
70pub struct ActivationFwdRequest<T: CudnnSupported> {
72 pub kind: ActivationKind,
73 pub x: GpuRef<T>,
74 pub y: GpuRef<T>,
75 pub dims: Vec<i64>,
76 pub layout: TensorLayout,
77 pub alpha: T::Scalar,
78 pub beta: T::Scalar,
79 pub reply: oneshot::Sender<Result<(), GpuError>>,
80 pub _ty: PhantomData<T>,
81}
82
83impl<T: CudnnSupported> ActivationFwdRequest<T> {
84 pub fn graph_spec(&self) -> OperationGraphSpec {
85 let dt = dtype_tag::<T>();
86 let mut g = OperationGraphSpec::new("activation_fwd");
87 let x_uid = g.add_tensor(TensorSpec::new(1, dt, self.dims.clone(), self.layout));
88 let y_uid = g.add_tensor(TensorSpec::new(2, dt, self.dims.clone(), self.layout));
89 g.add_op(OpSpec::Pointwise {
90 mode: self.kind.pointwise_mode(),
91 x: x_uid,
92 b: None,
93 y: y_uid,
94 compute_dtype: dt,
95 alpha1: 1.0,
96 alpha2: 0.0,
97 });
98 g
99 }
100}
101
102impl<T: CudnnSupported> CudnnDispatch for ActivationFwdRequest<T> {
103 fn dtype_name(&self) -> &'static str {
104 T::NAME
105 }
106 fn op_kind(&self) -> &'static str {
107 "activation_fwd"
108 }
109 fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
110 let _ = self.reply.send(Err(GpuError::LibraryError {
111 lib: "cudnn",
112 msg: "ActivationFwdRequest dispatch requires the v9 frontend graph builder; \
113 skeleton entry point only"
114 .to_string(),
115 }));
116 }
117}
118
/// Selects the normalization mode of a softmax request.
///
/// NOTE(review): the variant names suggest cuDNN's
/// `CUDNN_SOFTMAX_MODE_INSTANCE` / `CUDNN_SOFTMAX_MODE_CHANNEL`, but no mapping
/// to cuDNN is visible in this file — confirm against the dispatcher.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SoftmaxMode {
    /// Presumably normalizes over each whole sample — TODO confirm.
    Instance,
    /// Presumably normalizes along the channel axis — TODO confirm.
    Channel,
}
125
126pub struct SoftmaxFwdRequest<T: CudnnSupported> {
128 pub mode: SoftmaxMode,
129 pub x: GpuRef<T>,
130 pub y: GpuRef<T>,
131 pub dims: Vec<i64>,
132 pub layout: TensorLayout,
133 pub alpha: T::Scalar,
134 pub beta: T::Scalar,
135 pub reply: oneshot::Sender<Result<(), GpuError>>,
136 pub _ty: PhantomData<T>,
137}
138
139impl<T: CudnnSupported> CudnnDispatch for SoftmaxFwdRequest<T> {
140 fn dtype_name(&self) -> &'static str {
141 T::NAME
142 }
143 fn op_kind(&self) -> &'static str {
144 "softmax_fwd"
145 }
146 fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
147 let _ = self.reply.send(Err(GpuError::LibraryError {
148 lib: "cudnn",
149 msg: "SoftmaxFwdRequest dispatch requires the v9 frontend graph builder; \
150 skeleton entry point only"
151 .to_string(),
152 }));
153 }
154}
155
156pub struct DropoutFwdRequest<T: CudnnSupported> {
159 pub x: GpuRef<T>,
160 pub y: GpuRef<T>,
161 pub mask: GpuRef<u8>,
162 pub dims: Vec<i64>,
163 pub layout: TensorLayout,
164 pub probability: f32,
165 pub seed: u64,
166 pub reply: oneshot::Sender<Result<(), GpuError>>,
167 pub _ty: PhantomData<T>,
168}
169
170impl<T: CudnnSupported> CudnnDispatch for DropoutFwdRequest<T> {
171 fn dtype_name(&self) -> &'static str {
172 T::NAME
173 }
174 fn op_kind(&self) -> &'static str {
175 "dropout_fwd"
176 }
177 fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
178 let _ = self.reply.send(Err(GpuError::LibraryError {
179 lib: "cudnn",
180 msg: "DropoutFwdRequest dispatch requires the v9 frontend graph builder; \
181 skeleton entry point only"
182 .to_string(),
183 }));
184 }
185}
186
/// Integer-quantized LRN hyper-parameters.
///
/// The floating-point coefficients are stored as fixed-point integers so the
/// struct can derive `Eq` and `Hash`.
///
/// Despite the `_milli` suffix, each coefficient is scaled by `1_000_000`
/// (e.g. `alpha = 0.0001` is stored as `100`). The field names are kept
/// unchanged for compatibility with existing users.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct LrnParams {
    /// The `n` parameter (presumably the LRN window size — confirm with the dispatcher).
    pub n: u32,
    /// `alpha` multiplied by `1_000_000` and rounded to the nearest integer.
    pub alpha_milli: i64,
    /// `beta` multiplied by `1_000_000` and rounded to the nearest integer.
    pub beta_milli: i64,
    /// `k` multiplied by `1_000_000` and rounded to the nearest integer.
    pub k_milli: i64,
}

impl LrnParams {
    /// Fixed-point scale applied to every floating-point coefficient.
    const SCALE: f64 = 1_000_000.0;

    /// Builds the quantized parameter set from floating-point coefficients.
    ///
    /// Each coefficient is scaled by `1_000_000` and rounded to the nearest
    /// integer (ties away from zero). Rounding, rather than truncating,
    /// matters because many decimal inputs are not exactly representable in
    /// `f64`: `0.5003 * 1_000_000.0` evaluates to just under `500300`, which a
    /// plain cast would store as `500299`.
    ///
    /// Non-finite or out-of-range inputs follow Rust's float-to-int `as`
    /// semantics: `NaN` becomes `0` and overflowing values saturate at
    /// `i64::MIN` / `i64::MAX`.
    pub fn new(n: u32, alpha: f64, beta: f64, k: f64) -> Self {
        Self {
            n,
            alpha_milli: Self::quantize(alpha),
            beta_milli: Self::quantize(beta),
            k_milli: Self::quantize(k),
        }
    }

    /// Converts one coefficient to its fixed-point representation.
    fn quantize(value: f64) -> i64 {
        // Round first so values a hair below an integer are not truncated down.
        (value * Self::SCALE).round() as i64
    }
}
206
207pub struct LrnFwdRequest<T: CudnnSupported> {
209 pub x: GpuRef<T>,
210 pub y: GpuRef<T>,
211 pub dims: Vec<i64>,
212 pub layout: TensorLayout,
213 pub params: LrnParams,
214 pub alpha: T::Scalar,
215 pub beta: T::Scalar,
216 pub reply: oneshot::Sender<Result<(), GpuError>>,
217 pub _ty: PhantomData<T>,
218}
219
220impl<T: CudnnSupported> CudnnDispatch for LrnFwdRequest<T> {
221 fn dtype_name(&self) -> &'static str {
222 T::NAME
223 }
224 fn op_kind(&self) -> &'static str {
225 "lrn_fwd"
226 }
227 fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
228 let _ = self.reply.send(Err(GpuError::LibraryError {
229 lib: "cudnn",
230 msg: "LrnFwdRequest dispatch requires the v9 frontend graph builder; \
231 skeleton entry point only"
232 .to_string(),
233 }));
234 }
235}
236
237pub fn build_activation_fwd_graph(
239 dtype: DtypeTag,
240 dims: &[i64],
241 layout: TensorLayout,
242 kind: ActivationKind,
243) -> OperationGraphSpec {
244 let mut g = OperationGraphSpec::new("activation_fwd");
245 let x_uid = g.add_tensor(TensorSpec::new(1, dtype, dims.to_vec(), layout));
246 let y_uid = g.add_tensor(TensorSpec::new(2, dtype, dims.to_vec(), layout));
247 g.add_op(OpSpec::Pointwise {
248 mode: kind.pointwise_mode(),
249 x: x_uid,
250 b: None,
251 y: y_uid,
252 compute_dtype: dtype,
253 alpha1: 1.0,
254 alpha2: 0.0,
255 });
256 g
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `ActivationKind` variant maps to the same-named `PointwiseMode`.
    #[test]
    fn activation_kinds_have_pointwise_mode() {
        // Table covers all nine variants so a wrong mapping cannot slip through.
        let expected = [
            (ActivationKind::Relu, PointwiseMode::Relu),
            (ActivationKind::Sigmoid, PointwiseMode::Sigmoid),
            (ActivationKind::Tanh, PointwiseMode::Tanh),
            (ActivationKind::Gelu, PointwiseMode::Gelu),
            (ActivationKind::GeluApprox, PointwiseMode::GeluApprox),
            (ActivationKind::Swish, PointwiseMode::Swish),
            (ActivationKind::Elu, PointwiseMode::Elu),
            (ActivationKind::Softplus, PointwiseMode::Softplus),
            (ActivationKind::Identity, PointwiseMode::Identity),
        ];
        for (kind, mode) in expected {
            assert_eq!(
                kind.pointwise_mode(),
                mode,
                "unexpected pointwise mode for {:?}",
                kind
            );
        }
    }

    /// The activation graph consists of two tensors joined by a single op.
    #[test]
    fn activation_fwd_graph_builds() {
        let g = build_activation_fwd_graph(
            DtypeTag::F32,
            &[1, 3, 8, 8],
            TensorLayout::NchwPacked,
            ActivationKind::Gelu,
        );
        assert_eq!(g.tensors.len(), 2);
        assert_eq!(g.ops.len(), 1);
    }

    /// Coefficients are stored scaled by 1_000_000; `n` is passed through as-is.
    #[test]
    fn lrn_params_quantization() {
        let p = LrnParams::new(5, 0.0001, 0.75, 1.0);
        assert_eq!(p.n, 5);
        assert_eq!(p.alpha_milli, 100);
        assert_eq!(p.beta_milli, 750_000);
        assert_eq!(p.k_milli, 1_000_000);
    }
}