Skip to main content

atomr_accel_cuda/kernel/cudnn/
activation.rs

//! Activation, softmax, dropout, and LRN requests for the cuDNN actor.
//!
//! Activation set: `Relu`, `Sigmoid`, `Tanh` (existing) plus `Gelu`,
//! `GeluApprox`, `Swish`, `Elu`, `Softplus`, `Identity`. Each activation is
//! expressed as a pointwise op in the cuDNN v9 frontend graph.
6
7#![allow(dead_code)]
8
9use std::marker::PhantomData;
10
11use tokio::sync::oneshot;
12
13use crate::dtype::CudnnSupported;
14use crate::error::GpuError;
15use crate::gpu_ref::GpuRef;
16use crate::kernel::cudnn::conv::dtype_tag;
17use crate::kernel::cudnn::graph::{
18    DtypeTag, OpSpec, OperationGraphSpec, PointwiseMode, TensorLayout, TensorSpec,
19};
20use crate::kernel::dispatch::{CudnnDispatch, CudnnDispatchCtx};
21
/// Activation function tag.
///
/// Each variant has a same-named v9 frontend `PointwiseMode`, obtained via
/// `ActivationKind::pointwise_mode`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ActivationKind {
    /// Rectified linear unit.
    Relu,
    /// Logistic sigmoid.
    Sigmoid,
    /// Hyperbolic tangent.
    Tanh,
    /// GELU (non-approximated variant).
    Gelu,
    /// Approximate GELU.
    GeluApprox,
    /// Swish.
    Swish,
    /// Exponential linear unit.
    Elu,
    /// Softplus.
    Softplus,
    /// Pass-through activation.
    Identity,
}
35
36impl ActivationKind {
37    /// Map to the v9 frontend `PointwiseMode`.
38    pub fn pointwise_mode(self) -> PointwiseMode {
39        match self {
40            ActivationKind::Relu => PointwiseMode::Relu,
41            ActivationKind::Sigmoid => PointwiseMode::Sigmoid,
42            ActivationKind::Tanh => PointwiseMode::Tanh,
43            ActivationKind::Gelu => PointwiseMode::Gelu,
44            ActivationKind::GeluApprox => PointwiseMode::GeluApprox,
45            ActivationKind::Swish => PointwiseMode::Swish,
46            ActivationKind::Elu => PointwiseMode::Elu,
47            ActivationKind::Softplus => PointwiseMode::Softplus,
48            ActivationKind::Identity => PointwiseMode::Identity,
49        }
50    }
51
52    /// Map to the legacy v7 `cudnnActivationMode_t` for the back-compat
53    /// dispatch path. Approximate / parametric activations fall back
54    /// to the plain `Relu`/`Sigmoid` etc. equivalent.
55    #[cfg(feature = "cudnn")]
56    pub fn cudnn_legacy_mode(self) -> cudarc::cudnn::sys::cudnnActivationMode_t {
57        use cudarc::cudnn::sys::cudnnActivationMode_t::*;
58        match self {
59            ActivationKind::Relu | ActivationKind::Identity => CUDNN_ACTIVATION_RELU,
60            ActivationKind::Sigmoid => CUDNN_ACTIVATION_SIGMOID,
61            ActivationKind::Tanh => CUDNN_ACTIVATION_TANH,
62            ActivationKind::Elu => CUDNN_ACTIVATION_ELU,
63            ActivationKind::Swish => CUDNN_ACTIVATION_SWISH,
64            ActivationKind::Gelu | ActivationKind::GeluApprox => CUDNN_ACTIVATION_RELU,
65            ActivationKind::Softplus => CUDNN_ACTIVATION_RELU,
66        }
67    }
68}
69
70/// Activation forward request. dims are the raw tensor dims.
71pub struct ActivationFwdRequest<T: CudnnSupported> {
72    pub kind: ActivationKind,
73    pub x: GpuRef<T>,
74    pub y: GpuRef<T>,
75    pub dims: Vec<i64>,
76    pub layout: TensorLayout,
77    pub alpha: T::Scalar,
78    pub beta: T::Scalar,
79    pub reply: oneshot::Sender<Result<(), GpuError>>,
80    pub _ty: PhantomData<T>,
81}
82
83impl<T: CudnnSupported> ActivationFwdRequest<T> {
84    pub fn graph_spec(&self) -> OperationGraphSpec {
85        let dt = dtype_tag::<T>();
86        let mut g = OperationGraphSpec::new("activation_fwd");
87        let x_uid = g.add_tensor(TensorSpec::new(1, dt, self.dims.clone(), self.layout));
88        let y_uid = g.add_tensor(TensorSpec::new(2, dt, self.dims.clone(), self.layout));
89        g.add_op(OpSpec::Pointwise {
90            mode: self.kind.pointwise_mode(),
91            x: x_uid,
92            b: None,
93            y: y_uid,
94            compute_dtype: dt,
95            alpha1: 1.0,
96            alpha2: 0.0,
97        });
98        g
99    }
100}
101
102impl<T: CudnnSupported> CudnnDispatch for ActivationFwdRequest<T> {
103    fn dtype_name(&self) -> &'static str {
104        T::NAME
105    }
106    fn op_kind(&self) -> &'static str {
107        "activation_fwd"
108    }
109    fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
110        let _ = self.reply.send(Err(GpuError::LibraryError {
111            lib: "cudnn",
112            msg: "ActivationFwdRequest dispatch requires the v9 frontend graph builder; \
113                  skeleton entry point only"
114                .to_string(),
115        }));
116    }
117}
118
/// Softmax mode: instance-wise or channel-wise normalisation.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum SoftmaxMode {
    /// Instance-wise softmax (presumably mirrors `CUDNN_SOFTMAX_MODE_INSTANCE`).
    Instance,
    /// Channel-wise softmax (presumably mirrors `CUDNN_SOFTMAX_MODE_CHANNEL`).
    Channel,
}
125
126/// Softmax forward request.
127pub struct SoftmaxFwdRequest<T: CudnnSupported> {
128    pub mode: SoftmaxMode,
129    pub x: GpuRef<T>,
130    pub y: GpuRef<T>,
131    pub dims: Vec<i64>,
132    pub layout: TensorLayout,
133    pub alpha: T::Scalar,
134    pub beta: T::Scalar,
135    pub reply: oneshot::Sender<Result<(), GpuError>>,
136    pub _ty: PhantomData<T>,
137}
138
139impl<T: CudnnSupported> CudnnDispatch for SoftmaxFwdRequest<T> {
140    fn dtype_name(&self) -> &'static str {
141        T::NAME
142    }
143    fn op_kind(&self) -> &'static str {
144        "softmax_fwd"
145    }
146    fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
147        let _ = self.reply.send(Err(GpuError::LibraryError {
148            lib: "cudnn",
149            msg: "SoftmaxFwdRequest dispatch requires the v9 frontend graph builder; \
150                  skeleton entry point only"
151                .to_string(),
152        }));
153    }
154}
155
156/// Dropout forward request: produces `y = x * mask / (1 - p)` and
157/// records the mask state for backward.
158pub struct DropoutFwdRequest<T: CudnnSupported> {
159    pub x: GpuRef<T>,
160    pub y: GpuRef<T>,
161    pub mask: GpuRef<u8>,
162    pub dims: Vec<i64>,
163    pub layout: TensorLayout,
164    pub probability: f32,
165    pub seed: u64,
166    pub reply: oneshot::Sender<Result<(), GpuError>>,
167    pub _ty: PhantomData<T>,
168}
169
170impl<T: CudnnSupported> CudnnDispatch for DropoutFwdRequest<T> {
171    fn dtype_name(&self) -> &'static str {
172        T::NAME
173    }
174    fn op_kind(&self) -> &'static str {
175        "dropout_fwd"
176    }
177    fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
178        let _ = self.reply.send(Err(GpuError::LibraryError {
179            lib: "cudnn",
180            msg: "DropoutFwdRequest dispatch requires the v9 frontend graph builder; \
181                  skeleton entry point only"
182                .to_string(),
183        }));
184    }
185}
186
/// Local-response-normalisation parameters.
///
/// The floating-point `alpha`, `beta` and `k` coefficients are stored as
/// integers scaled by 1_000_000. Despite the `_milli` suffix, the fields hold
/// millionths. Integer fields let the struct derive `Eq` and `Hash`, which
/// `f64` fields could not.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct LrnParams {
    /// LRN `n` parameter (presumably the normalisation window size — confirm).
    pub n: u32,
    /// `alpha * 1_000_000`, rounded to the nearest integer.
    pub alpha_milli: i64,
    /// `beta * 1_000_000`, rounded to the nearest integer.
    pub beta_milli: i64,
    /// `k * 1_000_000`, rounded to the nearest integer.
    pub k_milli: i64,
}

impl LrnParams {
    /// Fixed-point scale applied to the float coefficients (millionths).
    const SCALE: f64 = 1_000_000.0;

    /// Quantize one coefficient to millionths, rounding to nearest.
    ///
    /// A plain `as i64` truncates toward zero. For example, `1.005` is stored
    /// as `1.00499999…`, so `1.005 * 1e6` truncates to `1_004_999`; rounding
    /// first yields the intended `1_005_000`. Non-finite inputs follow Rust's
    /// saturating float-to-int cast (NaN → 0, ±inf → `i64::MIN`/`i64::MAX`).
    fn quantize(value: f64) -> i64 {
        (value * Self::SCALE).round() as i64
    }

    /// Build LRN parameters from float coefficients, quantizing each to
    /// millionths.
    pub fn new(n: u32, alpha: f64, beta: f64, k: f64) -> Self {
        Self {
            n,
            alpha_milli: Self::quantize(alpha),
            beta_milli: Self::quantize(beta),
            k_milli: Self::quantize(k),
        }
    }
}
206
207/// LRN forward request.
208pub struct LrnFwdRequest<T: CudnnSupported> {
209    pub x: GpuRef<T>,
210    pub y: GpuRef<T>,
211    pub dims: Vec<i64>,
212    pub layout: TensorLayout,
213    pub params: LrnParams,
214    pub alpha: T::Scalar,
215    pub beta: T::Scalar,
216    pub reply: oneshot::Sender<Result<(), GpuError>>,
217    pub _ty: PhantomData<T>,
218}
219
220impl<T: CudnnSupported> CudnnDispatch for LrnFwdRequest<T> {
221    fn dtype_name(&self) -> &'static str {
222        T::NAME
223    }
224    fn op_kind(&self) -> &'static str {
225        "lrn_fwd"
226    }
227    fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
228        let _ = self.reply.send(Err(GpuError::LibraryError {
229            lib: "cudnn",
230            msg: "LrnFwdRequest dispatch requires the v9 frontend graph builder; \
231                  skeleton entry point only"
232                .to_string(),
233        }));
234    }
235}
236
237/// Build the spec-side activation-fwd op graph.
238pub fn build_activation_fwd_graph(
239    dtype: DtypeTag,
240    dims: &[i64],
241    layout: TensorLayout,
242    kind: ActivationKind,
243) -> OperationGraphSpec {
244    let mut g = OperationGraphSpec::new("activation_fwd");
245    let x_uid = g.add_tensor(TensorSpec::new(1, dtype, dims.to_vec(), layout));
246    let y_uid = g.add_tensor(TensorSpec::new(2, dtype, dims.to_vec(), layout));
247    g.add_op(OpSpec::Pointwise {
248        mode: kind.pointwise_mode(),
249        x: x_uid,
250        b: None,
251        y: y_uid,
252        compute_dtype: dtype,
253        alpha1: 1.0,
254        alpha2: 0.0,
255    });
256    g
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// Activation kinds resolve to the pointwise mode of the same name.
    #[test]
    fn activation_kinds_have_pointwise_mode() {
        let expectations = [
            (ActivationKind::Relu, PointwiseMode::Relu),
            (ActivationKind::Gelu, PointwiseMode::Gelu),
            (ActivationKind::Swish, PointwiseMode::Swish),
            (ActivationKind::Softplus, PointwiseMode::Softplus),
            (ActivationKind::Elu, PointwiseMode::Elu),
            (ActivationKind::Identity, PointwiseMode::Identity),
        ];
        for (kind, expected) in expectations {
            assert_eq!(kind.pointwise_mode(), expected, "kind = {:?}", kind);
        }
    }

    /// The activation graph holds the `x`/`y` tensors and a single op.
    #[test]
    fn activation_fwd_graph_builds() {
        let dims: [i64; 4] = [1, 3, 8, 8];
        let graph = build_activation_fwd_graph(
            DtypeTag::F32,
            &dims,
            TensorLayout::NchwPacked,
            ActivationKind::Gelu,
        );
        assert_eq!(graph.tensors.len(), 2, "expected x and y tensors");
        assert_eq!(graph.ops.len(), 1, "expected a single pointwise op");
    }

    /// Coefficients are stored as millionths.
    #[test]
    fn lrn_params_quantization() {
        let LrnParams {
            n,
            alpha_milli,
            beta_milli,
            ..
        } = LrnParams::new(5, 0.0001, 0.75, 1.0);
        assert_eq!(n, 5);
        assert_eq!(alpha_milli, 100);
        assert_eq!(beta_milli, 750_000);
    }
}