Skip to main content

atomr_accel_cuda/kernel/cudnn/
rnn.rs

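//! RNN / LSTM / GRU forward + backward training requests.
//!
//! Intended to route through cuDNN's RNN v8 API (`cudnnRNNForward` /
//! `cudnnRNNBackwardData_v8` / `cudnnRNNBackwardWeights_v8`). The dispatch
//! entry points in this module are currently skeletons that reply with an
//! error instead of launching any cuDNN call.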
5
6#![allow(dead_code)]
7
8use std::marker::PhantomData;
9
10use tokio::sync::oneshot;
11
12use crate::dtype::CudnnSupported;
13use crate::error::GpuError;
14use crate::gpu_ref::GpuRef;
15use crate::kernel::cudnn::conv::dtype_tag;
16use crate::kernel::cudnn::graph::{DtypeTag, OperationGraphSpec, TensorLayout, TensorSpec};
17use crate::kernel::dispatch::{CudnnDispatch, CudnnDispatchCtx};
18
/// Recurrent cell flavour for an RNN request.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RnnMode {
    /// Single-gate RNN cell, labelled `"rnn_relu"` by [`RnnMode::op_kind`].
    Rnn,
    /// Single-gate RNN cell, labelled `"rnn_tanh"` by [`RnnMode::op_kind`].
    RnnTanh,
    /// LSTM cell: four gates plus a separate cell state.
    Lstm,
    /// GRU cell: three gates, no separate cell state.
    Gru,
}

impl RnnMode {
    /// Short lowercase label identifying this cell mode.
    pub fn op_kind(self) -> &'static str {
        match self {
            Self::Lstm => "lstm",
            Self::Gru => "gru",
            Self::RnnTanh => "rnn_tanh",
            Self::Rnn => "rnn_relu",
        }
    }

    /// How many gates (linear projections) a single cell of this mode owns.
    pub fn num_gates(self) -> u32 {
        match self {
            Self::Lstm => 4,
            Self::Gru => 3,
            Self::Rnn | Self::RnnTanh => 1,
        }
    }

    /// `true` only for [`RnnMode::Lstm`], the one mode that carries a cell
    /// state alongside the hidden state.
    pub fn has_cell_state(self) -> bool {
        self == Self::Lstm
    }
}
52
/// Number of directions the recurrence runs over the sequence.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RnnDirection {
    /// One direction; [`RnnDirection::factor`] is `1`.
    Unidirectional,
    /// Two directions; [`RnnDirection::factor`] is `2`.
    Bidirectional,
}

impl RnnDirection {
    /// Direction count, used to scale hidden-state and output widths.
    pub fn factor(self) -> u32 {
        if self == Self::Bidirectional {
            2
        } else {
            1
        }
    }
}
68
69/// RNN parameters.
70#[derive(Debug, Clone, PartialEq)]
71pub struct RnnParams {
72    pub mode: RnnMode,
73    pub direction: RnnDirection,
74    pub num_layers: u32,
75    pub input_size: i64,
76    pub hidden_size: i64,
77    pub seq_length: i64,
78    pub batch_size: i64,
79    pub dropout: f32,
80}
81
82impl RnnParams {
83    pub fn new(
84        mode: RnnMode,
85        direction: RnnDirection,
86        num_layers: u32,
87        input_size: i64,
88        hidden_size: i64,
89        seq_length: i64,
90        batch_size: i64,
91    ) -> Self {
92        Self {
93            mode,
94            direction,
95            num_layers,
96            input_size,
97            hidden_size,
98            seq_length,
99            batch_size,
100            dropout: 0.0,
101        }
102    }
103
104    pub fn with_dropout(mut self, d: f32) -> Self {
105        self.dropout = d;
106        self
107    }
108
109    pub fn output_size(&self) -> i64 {
110        self.hidden_size * self.direction.factor() as i64
111    }
112}
113
114/// RNN forward request.
115pub struct RnnFwdRequest<T: CudnnSupported> {
116    pub x: GpuRef<T>,
117    pub h_in: GpuRef<T>,
118    pub c_in: Option<GpuRef<T>>,
119    pub weights: GpuRef<T>,
120    pub y: GpuRef<T>,
121    pub h_out: GpuRef<T>,
122    pub c_out: Option<GpuRef<T>>,
123    pub layout: TensorLayout,
124    pub params: RnnParams,
125    pub reply: oneshot::Sender<Result<(), GpuError>>,
126    pub _ty: PhantomData<T>,
127}
128
129impl<T: CudnnSupported> RnnFwdRequest<T> {
130    pub fn graph_spec(&self) -> OperationGraphSpec {
131        build_rnn_fwd_spec(dtype_tag::<T>(), &self.params, self.layout)
132    }
133}
134
135impl<T: CudnnSupported> CudnnDispatch for RnnFwdRequest<T> {
136    fn dtype_name(&self) -> &'static str {
137        T::NAME
138    }
139    fn op_kind(&self) -> &'static str {
140        match self.params.mode {
141            RnnMode::Rnn | RnnMode::RnnTanh => "rnn_fwd",
142            RnnMode::Lstm => "lstm_fwd",
143            RnnMode::Gru => "gru_fwd",
144        }
145    }
146    fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
147        let _ = self.reply.send(Err(GpuError::LibraryError {
148            lib: "cudnn",
149            msg: "RnnFwdRequest dispatch requires the v8 RNN API path; \
150                  skeleton entry point only"
151                .to_string(),
152        }));
153    }
154}
155
156/// RNN backward (data + weights) request.
157pub struct RnnBwdRequest<T: CudnnSupported> {
158    pub x: GpuRef<T>,
159    pub y: GpuRef<T>,
160    pub dy: GpuRef<T>,
161    pub h_in: GpuRef<T>,
162    pub c_in: Option<GpuRef<T>>,
163    pub h_out: GpuRef<T>,
164    pub c_out: Option<GpuRef<T>>,
165    pub dh_out: GpuRef<T>,
166    pub dc_out: Option<GpuRef<T>>,
167    pub weights: GpuRef<T>,
168    pub dx: GpuRef<T>,
169    pub dh_in: GpuRef<T>,
170    pub dc_in: Option<GpuRef<T>>,
171    pub dweights: GpuRef<T>,
172    pub layout: TensorLayout,
173    pub params: RnnParams,
174    pub reply: oneshot::Sender<Result<(), GpuError>>,
175    pub _ty: PhantomData<T>,
176}
177
178impl<T: CudnnSupported> CudnnDispatch for RnnBwdRequest<T> {
179    fn dtype_name(&self) -> &'static str {
180        T::NAME
181    }
182    fn op_kind(&self) -> &'static str {
183        match self.params.mode {
184            RnnMode::Rnn | RnnMode::RnnTanh => "rnn_bwd",
185            RnnMode::Lstm => "lstm_bwd",
186            RnnMode::Gru => "gru_bwd",
187        }
188    }
189    fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
190        let _ = self.reply.send(Err(GpuError::LibraryError {
191            lib: "cudnn",
192            msg: "RnnBwdRequest dispatch requires the v8 RNN API path; \
193                  skeleton entry point only"
194                .to_string(),
195        }));
196    }
197}
198
199/// Build a spec-side `OperationGraphSpec` for plan-cache keying.
200/// The actual RNN API does not use the v9 backend graph descriptor
201/// builder — but we keep a stable signature surface so the cache key
202/// machinery is shared.
203pub fn build_rnn_fwd_spec(
204    dtype: DtypeTag,
205    p: &RnnParams,
206    layout: TensorLayout,
207) -> OperationGraphSpec {
208    let mut g = OperationGraphSpec::new("rnn_fwd");
209    let x_dims = vec![p.seq_length, p.batch_size, p.input_size];
210    let h_dims = vec![
211        p.num_layers as i64 * p.direction.factor() as i64,
212        p.batch_size,
213        p.hidden_size,
214    ];
215    let y_dims = vec![p.seq_length, p.batch_size, p.output_size()];
216    g.add_tensor(TensorSpec::new(1, dtype, x_dims, layout));
217    g.add_tensor(TensorSpec::new(2, dtype, h_dims.clone(), layout));
218    g.add_tensor(TensorSpec::new(3, dtype, y_dims, layout));
219    g.add_tensor(TensorSpec::new(4, dtype, h_dims.clone(), layout));
220    if p.mode.has_cell_state() {
221        g.add_tensor(TensorSpec::new(5, dtype, h_dims.clone(), layout));
222        g.add_tensor(TensorSpec::new(6, dtype, h_dims, layout));
223    }
224    // Weights tensor — packed dim depends on (input_size, hidden_size,
225    // num_gates, num_layers, direction). Use a single placeholder
226    // dim; the plan-cache key digests this via the rest of the spec.
227    let weight_dim = p.mode.num_gates() as i64
228        * p.hidden_size
229        * (p.input_size + p.hidden_size + 2)
230        * p.num_layers as i64
231        * p.direction.factor() as i64;
232    g.add_tensor(TensorSpec::new(99, dtype, vec![weight_dim], layout));
233    g
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 2-layer spec for `mode` / `direction` with hidden size 256.
    /// Checks the tensor count and the output width.
    fn round_trip(mode: RnnMode, direction: RnnDirection) {
        let p = RnnParams::new(mode, direction, 2, 128, 256, 32, 8);
        let g = build_rnn_fwd_spec(DtypeTag::F32, &p, TensorLayout::NchwPacked);
        // x, h_in, y, h_out [, c_in, c_out], weights
        let expected = if mode.has_cell_state() { 7 } else { 5 };
        assert_eq!(g.tensors.len(), expected);
        assert_eq!(p.output_size(), 256 * i64::from(direction.factor()));
    }

    const ALL_MODES: [RnnMode; 4] = [RnnMode::Rnn, RnnMode::RnnTanh, RnnMode::Lstm, RnnMode::Gru];

    #[test]
    fn rnn_lstm_gru_request_round_trip() {
        for mode in ALL_MODES {
            round_trip(mode, RnnDirection::Bidirectional);
        }
    }

    #[test]
    fn unidirectional_round_trip() {
        for mode in ALL_MODES {
            round_trip(mode, RnnDirection::Unidirectional);
        }
    }

    #[test]
    fn direction_factors() {
        assert_eq!(RnnDirection::Unidirectional.factor(), 1);
        assert_eq!(RnnDirection::Bidirectional.factor(), 2);
    }

    #[test]
    fn cell_state_only_for_lstm() {
        assert!(!RnnMode::Rnn.has_cell_state());
        assert!(!RnnMode::RnnTanh.has_cell_state());
        assert!(!RnnMode::Gru.has_cell_state());
        assert!(RnnMode::Lstm.has_cell_state());
    }

    #[test]
    fn gate_counts() {
        assert_eq!(RnnMode::Rnn.num_gates(), 1);
        assert_eq!(RnnMode::RnnTanh.num_gates(), 1);
        assert_eq!(RnnMode::Lstm.num_gates(), 4);
        assert_eq!(RnnMode::Gru.num_gates(), 3);
    }

    #[test]
    fn mode_labels_are_distinct() {
        assert_eq!(RnnMode::Rnn.op_kind(), "rnn_relu");
        assert_eq!(RnnMode::RnnTanh.op_kind(), "rnn_tanh");
        assert_eq!(RnnMode::Lstm.op_kind(), "lstm");
        assert_eq!(RnnMode::Gru.op_kind(), "gru");
    }

    #[test]
    fn dropout_defaults_to_zero_and_is_settable() {
        let p = RnnParams::new(RnnMode::Gru, RnnDirection::Unidirectional, 1, 8, 8, 1, 1);
        assert_eq!(p.dropout, 0.0);
        let p = p.with_dropout(0.25);
        assert_eq!(p.dropout, 0.25);
        // Other fields are untouched by `with_dropout`.
        assert_eq!(p.hidden_size, 8);
        assert_eq!(p.mode, RnnMode::Gru);
    }
}