1#![allow(dead_code)]
9
10use std::marker::PhantomData;
11
12use tokio::sync::oneshot;
13
14use crate::dtype::CudnnSupported;
15use crate::error::GpuError;
16use crate::gpu_ref::GpuRef;
17use crate::kernel::cudnn::conv::dtype_tag;
18use crate::kernel::cudnn::graph::{DtypeTag, OpSpec, OperationGraphSpec, TensorLayout, TensorSpec};
19use crate::kernel::dispatch::{CudnnDispatch, CudnnDispatchCtx};
20
/// Masking strategy applied to the attention scores.
///
/// NOTE(review): the graph builders in this module do not read the mask yet; the variant
/// descriptions below follow the variant names and should be confirmed against the eventual
/// cuDNN lowering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AttentionMask {
    /// No masking.
    None,
    /// Causal mask (presumably query position `i` attends only to key positions `<= i`).
    Causal,
    /// Sliding-window mask; the payload is the window size (units/sidedness — TODO confirm).
    SlidingWindow(u32),
    /// Causal mask combined with a sliding window of the given size (TODO confirm convention).
    CausalSlidingWindow(u32),
}

/// Shape and numerical configuration for one multi-head attention call.
///
/// Construct with [`AttentionParams::new`] and refine with the `with_*` builders. All fields
/// are public, so values such as `scale` can also be overridden directly.
#[derive(Debug, Clone, PartialEq)]
pub struct AttentionParams {
    /// Batch size.
    pub batch: i64,
    /// Query sequence length.
    pub seq_q: i64,
    /// Key/value sequence length.
    pub seq_kv: i64,
    /// Number of query heads.
    pub heads_q: i64,
    /// Number of key/value heads; differs from `heads_q` for grouped-query attention.
    pub heads_kv: i64,
    /// Per-head feature dimension.
    pub head_dim: i64,
    /// Score mask; defaults to [`AttentionMask::None`].
    pub mask: AttentionMask,
    /// Multiplier applied to the Q·K scores; defaults to `1 / sqrt(head_dim)`.
    pub scale: f64,
    /// Dropout probability (presumably expected in `[0, 1)`; not validated); defaults to `0.0`.
    pub dropout: f32,
    /// Seed paired with `dropout` (presumably for its RNG); defaults to `0`.
    pub dropout_seed: u64,
}

impl AttentionParams {
    /// Creates parameters with no mask, no dropout (seed `0`) and the conventional
    /// `1 / sqrt(head_dim)` score scale.
    ///
    /// No validation is performed: a `head_dim` of `0` yields an infinite scale and a negative
    /// `head_dim` yields `NaN`.
    #[must_use]
    pub fn new(
        batch: i64,
        seq_q: i64,
        seq_kv: i64,
        heads_q: i64,
        heads_kv: i64,
        head_dim: i64,
    ) -> Self {
        Self {
            batch,
            seq_q,
            seq_kv,
            heads_q,
            heads_kv,
            head_dim,
            mask: AttentionMask::None,
            scale: 1.0 / (head_dim as f64).sqrt(),
            dropout: 0.0,
            dropout_seed: 0,
        }
    }

    /// Returns `self` with `mask` replaced by `m`.
    ///
    /// The method consumes `self`, so ignoring the result discards the change; `#[must_use]`
    /// turns that mistake into a compiler warning.
    #[must_use]
    pub fn with_mask(mut self, m: AttentionMask) -> Self {
        self.mask = m;
        self
    }

    /// Returns `self` with dropout probability `p` and its `seed` set.
    ///
    /// `p` is stored as-is (not range-checked).
    #[must_use]
    pub fn with_dropout(mut self, p: f32, seed: u64) -> Self {
        self.dropout = p;
        self.dropout_seed = seed;
        self
    }

    /// Returns `true` when the query and key/value head counts differ.
    ///
    /// This includes the multi-query case where `heads_kv == 1` and `heads_q > 1`.
    pub fn is_gqa(&self) -> bool {
        self.heads_q != self.heads_kv
    }
}
87
88pub struct MultiHeadAttnFwdRequest<T: CudnnSupported> {
90 pub q: GpuRef<T>,
91 pub k: GpuRef<T>,
92 pub v: GpuRef<T>,
93 pub o: GpuRef<T>,
94 pub stats: Option<GpuRef<T>>,
96 pub bias: Option<GpuRef<T>>,
98 pub layout: TensorLayout,
99 pub params: AttentionParams,
100 pub reply: oneshot::Sender<Result<(), GpuError>>,
101 pub _ty: PhantomData<T>,
102}
103
104impl<T: CudnnSupported> MultiHeadAttnFwdRequest<T> {
105 pub fn graph_spec(&self) -> OperationGraphSpec {
106 build_mha_fwd_graph(dtype_tag::<T>(), &self.params, self.layout)
107 }
108}
109
110impl<T: CudnnSupported> CudnnDispatch for MultiHeadAttnFwdRequest<T> {
111 fn dtype_name(&self) -> &'static str {
112 T::NAME
113 }
114 fn op_kind(&self) -> &'static str {
115 "mha_fwd"
116 }
117 fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
118 let _ = self.reply.send(Err(GpuError::LibraryError {
119 lib: "cudnn",
120 msg: "MultiHeadAttnFwdRequest dispatch requires the v9 fused-attention path; \
121 skeleton entry point only"
122 .to_string(),
123 }));
124 }
125}
126
127pub struct MultiHeadAttnBwdRequest<T: CudnnSupported> {
129 pub q: GpuRef<T>,
130 pub k: GpuRef<T>,
131 pub v: GpuRef<T>,
132 pub o: GpuRef<T>,
133 pub do_: GpuRef<T>,
134 pub dq: GpuRef<T>,
135 pub dk: GpuRef<T>,
136 pub dv: GpuRef<T>,
137 pub stats: GpuRef<T>,
138 pub layout: TensorLayout,
139 pub params: AttentionParams,
140 pub reply: oneshot::Sender<Result<(), GpuError>>,
141 pub _ty: PhantomData<T>,
142}
143
144impl<T: CudnnSupported> MultiHeadAttnBwdRequest<T> {
145 pub fn graph_spec(&self) -> OperationGraphSpec {
146 build_mha_bwd_graph(dtype_tag::<T>(), &self.params, self.layout)
147 }
148}
149
150impl<T: CudnnSupported> CudnnDispatch for MultiHeadAttnBwdRequest<T> {
151 fn dtype_name(&self) -> &'static str {
152 T::NAME
153 }
154 fn op_kind(&self) -> &'static str {
155 "mha_bwd"
156 }
157 fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
158 let _ = self.reply.send(Err(GpuError::LibraryError {
159 lib: "cudnn",
160 msg: "MultiHeadAttnBwdRequest dispatch requires the v9 fused-attention path; \
161 skeleton entry point only"
162 .to_string(),
163 }));
164 }
165}
166
167pub fn build_mha_fwd_graph(
168 dtype: DtypeTag,
169 p: &AttentionParams,
170 layout: TensorLayout,
171) -> OperationGraphSpec {
172 let mut g = OperationGraphSpec::new("mha_fwd");
173 let q_dims = vec![p.batch, p.heads_q, p.seq_q, p.head_dim];
174 let k_dims = vec![p.batch, p.heads_kv, p.seq_kv, p.head_dim];
175 let v_dims = vec![p.batch, p.heads_kv, p.seq_kv, p.head_dim];
176 let o_dims = vec![p.batch, p.heads_q, p.seq_q, p.head_dim];
177 let qk_dims = vec![p.batch, p.heads_q, p.seq_q, p.seq_kv];
178
179 let q_uid = g.add_tensor(TensorSpec::new(1, dtype, q_dims, layout));
180 let k_uid = g.add_tensor(TensorSpec::new(2, dtype, k_dims, layout));
181 let v_uid = g.add_tensor(TensorSpec::new(3, dtype, v_dims, layout));
182 let qk_uid = g.add_tensor(TensorSpec::new(4, dtype, qk_dims.clone(), layout).virtualized());
183 let qk_softmax_uid = g.add_tensor(TensorSpec::new(5, dtype, qk_dims, layout).virtualized());
184 let o_uid = g.add_tensor(TensorSpec::new(6, dtype, o_dims, layout));
185
186 g.add_op(OpSpec::Matmul {
188 a: q_uid,
189 b: k_uid,
190 c: qk_uid,
191 compute_dtype: dtype,
192 });
193 g.add_op(OpSpec::Pointwise {
197 mode: super::graph::PointwiseMode::Identity,
198 x: qk_uid,
199 b: None,
200 y: qk_softmax_uid,
201 compute_dtype: dtype,
202 alpha1: p.scale,
203 alpha2: 0.0,
204 });
205 g.add_op(OpSpec::Matmul {
207 a: qk_softmax_uid,
208 b: v_uid,
209 c: o_uid,
210 compute_dtype: dtype,
211 });
212 g
213}
214
215pub fn build_mha_bwd_graph(
216 dtype: DtypeTag,
217 p: &AttentionParams,
218 layout: TensorLayout,
219) -> OperationGraphSpec {
220 let mut g = OperationGraphSpec::new("mha_bwd");
221 let q_dims = vec![p.batch, p.heads_q, p.seq_q, p.head_dim];
224 let k_dims = vec![p.batch, p.heads_kv, p.seq_kv, p.head_dim];
225 let v_dims = vec![p.batch, p.heads_kv, p.seq_kv, p.head_dim];
226
227 g.add_tensor(TensorSpec::new(1, dtype, q_dims.clone(), layout));
228 g.add_tensor(TensorSpec::new(2, dtype, k_dims.clone(), layout));
229 g.add_tensor(TensorSpec::new(3, dtype, v_dims.clone(), layout));
230 g.add_tensor(TensorSpec::new(4, dtype, q_dims.clone(), layout));
231 g.add_tensor(TensorSpec::new(5, dtype, k_dims.clone(), layout));
232 g.add_tensor(TensorSpec::new(6, dtype, v_dims.clone(), layout));
233
234 g.add_op(OpSpec::Matmul {
235 a: 4,
236 b: 2,
237 c: 7,
238 compute_dtype: dtype,
239 });
240 g.add_op(OpSpec::Matmul {
241 a: 4,
242 b: 3,
243 c: 8,
244 compute_dtype: dtype,
245 });
246 g.add_op(OpSpec::Matmul {
247 a: 1,
248 b: 5,
249 c: 9,
250 compute_dtype: dtype,
251 });
252 g
253}
254
#[cfg(test)]
mod tests {
    use super::*;

    /// Dense (non-GQA) causal configuration shared by the graph-shape tests.
    fn causal_params() -> AttentionParams {
        AttentionParams::new(2, 128, 128, 8, 8, 64).with_mask(AttentionMask::Causal)
    }

    #[test]
    fn mha_fwd_bwd_request_round_trip() {
        let params = causal_params();

        let fwd = build_mha_fwd_graph(DtypeTag::Bf16, &params, TensorLayout::NchwPacked);
        assert_eq!(fwd.tensors.len(), 6);
        assert_eq!(fwd.ops.len(), 3);

        let bwd = build_mha_bwd_graph(DtypeTag::Bf16, &params, TensorLayout::NchwPacked);
        assert!(bwd.ops.len() >= 3);
    }

    #[test]
    fn gqa_shape_changes_fwd_signature() {
        let dense = build_mha_fwd_graph(DtypeTag::Bf16, &causal_params(), TensorLayout::NchwPacked);

        let grouped = AttentionParams::new(1, 128, 128, 16, 4, 64);
        assert!(grouped.is_gqa());
        let grouped_graph = build_mha_fwd_graph(DtypeTag::Bf16, &grouped, TensorLayout::NchwPacked);

        assert_ne!(dense.signature(), grouped_graph.signature());
    }

    #[test]
    fn sliding_window_mask_is_stored() {
        let windowed =
            AttentionParams::new(2, 128, 128, 8, 8, 64).with_mask(AttentionMask::SlidingWindow(64));
        assert!(matches!(windowed.mask, AttentionMask::SlidingWindow(64)));
    }
}