atomr_accel_cuda/kernel/cudnn/
rnn.rs1#![allow(dead_code)]
7
8use std::marker::PhantomData;
9
10use tokio::sync::oneshot;
11
12use crate::dtype::CudnnSupported;
13use crate::error::GpuError;
14use crate::gpu_ref::GpuRef;
15use crate::kernel::cudnn::conv::dtype_tag;
16use crate::kernel::cudnn::graph::{DtypeTag, OperationGraphSpec, TensorLayout, TensorSpec};
17use crate::kernel::dispatch::{CudnnDispatch, CudnnDispatchCtx};
18
/// Recurrent cell variants this module can describe.
///
/// The variant selects the label returned by [`RnnMode::op_kind`], the gate
/// multiplier returned by [`RnnMode::num_gates`], and whether a cell-state
/// tensor accompanies the hidden state ([`RnnMode::has_cell_state`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RnnMode {
    /// Single-gate recurrent cell, labelled `"rnn_relu"`.
    Rnn,
    /// Single-gate recurrent cell, labelled `"rnn_tanh"`.
    RnnTanh,
    /// LSTM cell: four gates, and the only mode that carries a cell state.
    Lstm,
    /// GRU cell: three gates, no cell state.
    Gru,
}

impl RnnMode {
    /// Returns the static label identifying this cell variant.
    pub fn op_kind(self) -> &'static str {
        match self {
            Self::Gru => "gru",
            Self::Lstm => "lstm",
            Self::RnnTanh => "rnn_tanh",
            Self::Rnn => "rnn_relu",
        }
    }

    /// Returns the per-cell gate count: 1 for both plain RNN variants,
    /// 3 for GRU and 4 for LSTM.
    ///
    /// [`build_rnn_fwd_spec`] uses this as a multiplier when sizing the
    /// flat weight tensor.
    pub fn num_gates(self) -> u32 {
        match self {
            Self::Lstm => 4,
            Self::Gru => 3,
            Self::Rnn | Self::RnnTanh => 1,
        }
    }

    /// Returns `true` only for [`RnnMode::Lstm`].
    pub fn has_cell_state(self) -> bool {
        // Spelled out exhaustively so that adding a variant forces a decision here.
        match self {
            Self::Lstm => true,
            Self::Rnn | Self::RnnTanh | Self::Gru => false,
        }
    }
}
52
/// Whether the recurrence runs in one direction or in both.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RnnDirection {
    /// Single pass; [`RnnDirection::factor`] is 1.
    Unidirectional,
    /// Two passes; [`RnnDirection::factor`] is 2.
    Bidirectional,
}

impl RnnDirection {
    /// Returns the multiplier this direction contributes to size
    /// computations: 1 for unidirectional, 2 for bidirectional.
    pub fn factor(self) -> u32 {
        match self {
            Self::Bidirectional => 2,
            Self::Unidirectional => 1,
        }
    }
}
68
69#[derive(Debug, Clone, PartialEq)]
71pub struct RnnParams {
72 pub mode: RnnMode,
73 pub direction: RnnDirection,
74 pub num_layers: u32,
75 pub input_size: i64,
76 pub hidden_size: i64,
77 pub seq_length: i64,
78 pub batch_size: i64,
79 pub dropout: f32,
80}
81
82impl RnnParams {
83 pub fn new(
84 mode: RnnMode,
85 direction: RnnDirection,
86 num_layers: u32,
87 input_size: i64,
88 hidden_size: i64,
89 seq_length: i64,
90 batch_size: i64,
91 ) -> Self {
92 Self {
93 mode,
94 direction,
95 num_layers,
96 input_size,
97 hidden_size,
98 seq_length,
99 batch_size,
100 dropout: 0.0,
101 }
102 }
103
104 pub fn with_dropout(mut self, d: f32) -> Self {
105 self.dropout = d;
106 self
107 }
108
109 pub fn output_size(&self) -> i64 {
110 self.hidden_size * self.direction.factor() as i64
111 }
112}
113
114pub struct RnnFwdRequest<T: CudnnSupported> {
116 pub x: GpuRef<T>,
117 pub h_in: GpuRef<T>,
118 pub c_in: Option<GpuRef<T>>,
119 pub weights: GpuRef<T>,
120 pub y: GpuRef<T>,
121 pub h_out: GpuRef<T>,
122 pub c_out: Option<GpuRef<T>>,
123 pub layout: TensorLayout,
124 pub params: RnnParams,
125 pub reply: oneshot::Sender<Result<(), GpuError>>,
126 pub _ty: PhantomData<T>,
127}
128
129impl<T: CudnnSupported> RnnFwdRequest<T> {
130 pub fn graph_spec(&self) -> OperationGraphSpec {
131 build_rnn_fwd_spec(dtype_tag::<T>(), &self.params, self.layout)
132 }
133}
134
135impl<T: CudnnSupported> CudnnDispatch for RnnFwdRequest<T> {
136 fn dtype_name(&self) -> &'static str {
137 T::NAME
138 }
139 fn op_kind(&self) -> &'static str {
140 match self.params.mode {
141 RnnMode::Rnn | RnnMode::RnnTanh => "rnn_fwd",
142 RnnMode::Lstm => "lstm_fwd",
143 RnnMode::Gru => "gru_fwd",
144 }
145 }
146 fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
147 let _ = self.reply.send(Err(GpuError::LibraryError {
148 lib: "cudnn",
149 msg: "RnnFwdRequest dispatch requires the v8 RNN API path; \
150 skeleton entry point only"
151 .to_string(),
152 }));
153 }
154}
155
156pub struct RnnBwdRequest<T: CudnnSupported> {
158 pub x: GpuRef<T>,
159 pub y: GpuRef<T>,
160 pub dy: GpuRef<T>,
161 pub h_in: GpuRef<T>,
162 pub c_in: Option<GpuRef<T>>,
163 pub h_out: GpuRef<T>,
164 pub c_out: Option<GpuRef<T>>,
165 pub dh_out: GpuRef<T>,
166 pub dc_out: Option<GpuRef<T>>,
167 pub weights: GpuRef<T>,
168 pub dx: GpuRef<T>,
169 pub dh_in: GpuRef<T>,
170 pub dc_in: Option<GpuRef<T>>,
171 pub dweights: GpuRef<T>,
172 pub layout: TensorLayout,
173 pub params: RnnParams,
174 pub reply: oneshot::Sender<Result<(), GpuError>>,
175 pub _ty: PhantomData<T>,
176}
177
178impl<T: CudnnSupported> CudnnDispatch for RnnBwdRequest<T> {
179 fn dtype_name(&self) -> &'static str {
180 T::NAME
181 }
182 fn op_kind(&self) -> &'static str {
183 match self.params.mode {
184 RnnMode::Rnn | RnnMode::RnnTanh => "rnn_bwd",
185 RnnMode::Lstm => "lstm_bwd",
186 RnnMode::Gru => "gru_bwd",
187 }
188 }
189 fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
190 let _ = self.reply.send(Err(GpuError::LibraryError {
191 lib: "cudnn",
192 msg: "RnnBwdRequest dispatch requires the v8 RNN API path; \
193 skeleton entry point only"
194 .to_string(),
195 }));
196 }
197}
198
199pub fn build_rnn_fwd_spec(
204 dtype: DtypeTag,
205 p: &RnnParams,
206 layout: TensorLayout,
207) -> OperationGraphSpec {
208 let mut g = OperationGraphSpec::new("rnn_fwd");
209 let x_dims = vec![p.seq_length, p.batch_size, p.input_size];
210 let h_dims = vec![
211 p.num_layers as i64 * p.direction.factor() as i64,
212 p.batch_size,
213 p.hidden_size,
214 ];
215 let y_dims = vec![p.seq_length, p.batch_size, p.output_size()];
216 g.add_tensor(TensorSpec::new(1, dtype, x_dims, layout));
217 g.add_tensor(TensorSpec::new(2, dtype, h_dims.clone(), layout));
218 g.add_tensor(TensorSpec::new(3, dtype, y_dims, layout));
219 g.add_tensor(TensorSpec::new(4, dtype, h_dims.clone(), layout));
220 if p.mode.has_cell_state() {
221 g.add_tensor(TensorSpec::new(5, dtype, h_dims.clone(), layout));
222 g.add_tensor(TensorSpec::new(6, dtype, h_dims, layout));
223 }
224 let weight_dim = p.mode.num_gates() as i64
228 * p.hidden_size
229 * (p.input_size + p.hidden_size + 2)
230 * p.num_layers as i64
231 * p.direction.factor() as i64;
232 g.add_tensor(TensorSpec::new(99, dtype, vec![weight_dim], layout));
233 g
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a bidirectional two-layer spec for `mode` and checks the tensor
    /// count and the derived output width.
    fn round_trip(mode: RnnMode) {
        let p = RnnParams::new(mode, RnnDirection::Bidirectional, 2, 128, 256, 32, 8);
        let g = build_rnn_fwd_spec(DtypeTag::F32, &p, TensorLayout::NchwPacked);
        // Four sequence/state tensors plus the weight tensor are always
        // present; modes with a cell state add two more.
        let expected = if mode.has_cell_state() { 7 } else { 5 };
        assert_eq!(g.tensors.len(), expected);
        // hidden_size (256) doubled by the bidirectional factor.
        assert_eq!(p.output_size(), 512);
    }

    #[test]
    fn rnn_lstm_gru_request_round_trip() {
        round_trip(RnnMode::Rnn);
        round_trip(RnnMode::RnnTanh);
        round_trip(RnnMode::Lstm);
        round_trip(RnnMode::Gru);
    }

    #[test]
    fn cell_state_only_for_lstm() {
        assert!(!RnnMode::Rnn.has_cell_state());
        // RnnTanh was previously not covered by this test.
        assert!(!RnnMode::RnnTanh.has_cell_state());
        assert!(!RnnMode::Gru.has_cell_state());
        assert!(RnnMode::Lstm.has_cell_state());
    }

    #[test]
    fn gate_counts() {
        assert_eq!(RnnMode::Rnn.num_gates(), 1);
        // RnnTanh was previously not covered by this test.
        assert_eq!(RnnMode::RnnTanh.num_gates(), 1);
        assert_eq!(RnnMode::Lstm.num_gates(), 4);
        assert_eq!(RnnMode::Gru.num_gates(), 3);
    }
}