Skip to main content

atomr_accel_cuda/kernel/cudnn/
attention.rs

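//! Multi-head attention (`cudnnFusedAttnFwd`/`cudnnFusedAttnBwd`)
//! request types.
//!
//! Targets the cuDNN v9 frontend `OPERATION_MATMUL_DESCRIPTOR` +
//! softmax + dropout fusion path. It is intended to support causal masking,
//! sliding-window masking, paged KV (skeleton), and MQA / GQA via a
//! head-count split. The request `dispatch` implementations are currently
//! skeletons that reply with an error. The graph builders model only the
//! op-shape signature used for plan caching.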
7
8#![allow(dead_code)]
9
10use std::marker::PhantomData;
11
12use tokio::sync::oneshot;
13
14use crate::dtype::CudnnSupported;
15use crate::error::GpuError;
16use crate::gpu_ref::GpuRef;
17use crate::kernel::cudnn::conv::dtype_tag;
18use crate::kernel::cudnn::graph::{DtypeTag, OpSpec, OperationGraphSpec, TensorLayout, TensorSpec};
19use crate::kernel::dispatch::{CudnnDispatch, CudnnDispatchCtx};
20
/// Mask kind applied to the attention scores.
///
/// The default is [`AttentionMask::None`], which is also the mask that
/// [`AttentionParams::new`] starts from.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum AttentionMask {
    /// No mask.
    #[default]
    None,
    /// Causal mask.
    Causal,
    /// Bidirectional sliding window; the payload is the window size in tokens.
    SlidingWindow(u32),
    /// Causal + sliding window; the payload is the window size in tokens.
    CausalSlidingWindow(u32),
}

/// Attention parameters.
///
/// Construct with [`AttentionParams::new`] and refine with the `with_*`
/// builder methods.
#[derive(Debug, Clone, PartialEq)]
pub struct AttentionParams {
    /// Batch size.
    pub batch: i64,
    /// Query sequence length.
    pub seq_q: i64,
    /// Key/value sequence length.
    pub seq_kv: i64,
    /// Number of query heads.
    pub heads_q: i64,
    /// Number of key/value heads; differs from `heads_q` for MQA / GQA.
    pub heads_kv: i64,
    /// Per-head feature dimension.
    pub head_dim: i64,
    /// Mask applied to the attention scores.
    pub mask: AttentionMask,
    /// Scale on the QK^T product. Typically `1/sqrt(head_dim)`.
    pub scale: f64,
    /// Dropout probability on attention scores. `0.0` disables.
    pub dropout: f32,
    /// Seed paired with `dropout`; both are set together by
    /// [`AttentionParams::with_dropout`].
    pub dropout_seed: u64,
}

impl AttentionParams {
    /// Creates parameters with no mask, no dropout, and the conventional
    /// `1/sqrt(head_dim)` score scale.
    ///
    /// `head_dim` is expected to be positive. A `head_dim` of `0` yields an
    /// infinite `scale`, and a negative one yields NaN. Use
    /// [`with_scale`](Self::with_scale) to set the scale explicitly.
    #[must_use]
    pub fn new(
        batch: i64,
        seq_q: i64,
        seq_kv: i64,
        heads_q: i64,
        heads_kv: i64,
        head_dim: i64,
    ) -> Self {
        Self {
            batch,
            seq_q,
            seq_kv,
            heads_q,
            heads_kv,
            head_dim,
            mask: AttentionMask::None,
            scale: 1.0 / (head_dim as f64).sqrt(),
            dropout: 0.0,
            dropout_seed: 0,
        }
    }

    /// Replaces the attention mask.
    #[must_use]
    pub fn with_mask(mut self, m: AttentionMask) -> Self {
        self.mask = m;
        self
    }

    /// Sets the dropout probability `p` and its RNG seed.
    #[must_use]
    pub fn with_dropout(mut self, p: f32, seed: u64) -> Self {
        self.dropout = p;
        self.dropout_seed = seed;
        self
    }

    /// Overrides the QK^T scale, which `new` sets to `1/sqrt(head_dim)`.
    #[must_use]
    pub fn with_scale(mut self, scale: f64) -> Self {
        self.scale = scale;
        self
    }

    /// Returns `true` when the query and key/value head counts differ
    /// (MQA / GQA).
    ///
    /// Only the counts are compared; this does not check that `heads_q` is a
    /// multiple of `heads_kv`.
    pub fn is_gqa(&self) -> bool {
        self.heads_q != self.heads_kv
    }
}
87
88/// MHA forward request.
89pub struct MultiHeadAttnFwdRequest<T: CudnnSupported> {
90    pub q: GpuRef<T>,
91    pub k: GpuRef<T>,
92    pub v: GpuRef<T>,
93    pub o: GpuRef<T>,
94    /// Optional saved softmax-stats for backward.
95    pub stats: Option<GpuRef<T>>,
96    /// Optional bias added to attention scores.
97    pub bias: Option<GpuRef<T>>,
98    pub layout: TensorLayout,
99    pub params: AttentionParams,
100    pub reply: oneshot::Sender<Result<(), GpuError>>,
101    pub _ty: PhantomData<T>,
102}
103
104impl<T: CudnnSupported> MultiHeadAttnFwdRequest<T> {
105    pub fn graph_spec(&self) -> OperationGraphSpec {
106        build_mha_fwd_graph(dtype_tag::<T>(), &self.params, self.layout)
107    }
108}
109
110impl<T: CudnnSupported> CudnnDispatch for MultiHeadAttnFwdRequest<T> {
111    fn dtype_name(&self) -> &'static str {
112        T::NAME
113    }
114    fn op_kind(&self) -> &'static str {
115        "mha_fwd"
116    }
117    fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
118        let _ = self.reply.send(Err(GpuError::LibraryError {
119            lib: "cudnn",
120            msg: "MultiHeadAttnFwdRequest dispatch requires the v9 fused-attention path; \
121                  skeleton entry point only"
122                .to_string(),
123        }));
124    }
125}
126
127/// MHA backward request.
128pub struct MultiHeadAttnBwdRequest<T: CudnnSupported> {
129    pub q: GpuRef<T>,
130    pub k: GpuRef<T>,
131    pub v: GpuRef<T>,
132    pub o: GpuRef<T>,
133    pub do_: GpuRef<T>,
134    pub dq: GpuRef<T>,
135    pub dk: GpuRef<T>,
136    pub dv: GpuRef<T>,
137    pub stats: GpuRef<T>,
138    pub layout: TensorLayout,
139    pub params: AttentionParams,
140    pub reply: oneshot::Sender<Result<(), GpuError>>,
141    pub _ty: PhantomData<T>,
142}
143
144impl<T: CudnnSupported> MultiHeadAttnBwdRequest<T> {
145    pub fn graph_spec(&self) -> OperationGraphSpec {
146        build_mha_bwd_graph(dtype_tag::<T>(), &self.params, self.layout)
147    }
148}
149
150impl<T: CudnnSupported> CudnnDispatch for MultiHeadAttnBwdRequest<T> {
151    fn dtype_name(&self) -> &'static str {
152        T::NAME
153    }
154    fn op_kind(&self) -> &'static str {
155        "mha_bwd"
156    }
157    fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
158        let _ = self.reply.send(Err(GpuError::LibraryError {
159            lib: "cudnn",
160            msg: "MultiHeadAttnBwdRequest dispatch requires the v9 fused-attention path; \
161                  skeleton entry point only"
162                .to_string(),
163        }));
164    }
165}
166
167pub fn build_mha_fwd_graph(
168    dtype: DtypeTag,
169    p: &AttentionParams,
170    layout: TensorLayout,
171) -> OperationGraphSpec {
172    let mut g = OperationGraphSpec::new("mha_fwd");
173    let q_dims = vec![p.batch, p.heads_q, p.seq_q, p.head_dim];
174    let k_dims = vec![p.batch, p.heads_kv, p.seq_kv, p.head_dim];
175    let v_dims = vec![p.batch, p.heads_kv, p.seq_kv, p.head_dim];
176    let o_dims = vec![p.batch, p.heads_q, p.seq_q, p.head_dim];
177    let qk_dims = vec![p.batch, p.heads_q, p.seq_q, p.seq_kv];
178
179    let q_uid = g.add_tensor(TensorSpec::new(1, dtype, q_dims, layout));
180    let k_uid = g.add_tensor(TensorSpec::new(2, dtype, k_dims, layout));
181    let v_uid = g.add_tensor(TensorSpec::new(3, dtype, v_dims, layout));
182    let qk_uid = g.add_tensor(TensorSpec::new(4, dtype, qk_dims.clone(), layout).virtualized());
183    let qk_softmax_uid = g.add_tensor(TensorSpec::new(5, dtype, qk_dims, layout).virtualized());
184    let o_uid = g.add_tensor(TensorSpec::new(6, dtype, o_dims, layout));
185
186    // QK^T
187    g.add_op(OpSpec::Matmul {
188        a: q_uid,
189        b: k_uid,
190        c: qk_uid,
191        compute_dtype: dtype,
192    });
193    // softmax (modelled as a Pointwise tag — the real graph chains
194    // exp / reduce / divide, but the spec layer's plan-cache only
195    // needs the op-shape signature).
196    g.add_op(OpSpec::Pointwise {
197        mode: super::graph::PointwiseMode::Identity,
198        x: qk_uid,
199        b: None,
200        y: qk_softmax_uid,
201        compute_dtype: dtype,
202        alpha1: p.scale,
203        alpha2: 0.0,
204    });
205    // S * V
206    g.add_op(OpSpec::Matmul {
207        a: qk_softmax_uid,
208        b: v_uid,
209        c: o_uid,
210        compute_dtype: dtype,
211    });
212    g
213}
214
215pub fn build_mha_bwd_graph(
216    dtype: DtypeTag,
217    p: &AttentionParams,
218    layout: TensorLayout,
219) -> OperationGraphSpec {
220    let mut g = OperationGraphSpec::new("mha_bwd");
221    // We model the backward DAG at op-count granularity (sufficient
222    // for plan caching). Real launch path adds ~7 more nodes.
223    let q_dims = vec![p.batch, p.heads_q, p.seq_q, p.head_dim];
224    let k_dims = vec![p.batch, p.heads_kv, p.seq_kv, p.head_dim];
225    let v_dims = vec![p.batch, p.heads_kv, p.seq_kv, p.head_dim];
226
227    g.add_tensor(TensorSpec::new(1, dtype, q_dims.clone(), layout));
228    g.add_tensor(TensorSpec::new(2, dtype, k_dims.clone(), layout));
229    g.add_tensor(TensorSpec::new(3, dtype, v_dims.clone(), layout));
230    g.add_tensor(TensorSpec::new(4, dtype, q_dims.clone(), layout));
231    g.add_tensor(TensorSpec::new(5, dtype, k_dims.clone(), layout));
232    g.add_tensor(TensorSpec::new(6, dtype, v_dims.clone(), layout));
233
234    g.add_op(OpSpec::Matmul {
235        a: 4,
236        b: 2,
237        c: 7,
238        compute_dtype: dtype,
239    });
240    g.add_op(OpSpec::Matmul {
241        a: 4,
242        b: 3,
243        c: 8,
244        compute_dtype: dtype,
245    });
246    g.add_op(OpSpec::Matmul {
247        a: 1,
248        b: 5,
249        c: 9,
250        compute_dtype: dtype,
251    });
252    g
253}
254
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn mha_fwd_bwd_request_round_trip() {
        let dtype = DtypeTag::Bf16;
        let layout = TensorLayout::NchwPacked;

        let causal = AttentionParams::new(2, 128, 128, 8, 8, 64).with_mask(AttentionMask::Causal);

        let fwd = build_mha_fwd_graph(dtype, &causal, layout);
        // Tensors: Q, K, V, QK, QK_softmax, O.
        assert_eq!(fwd.tensors.len(), 6);
        // Ops: QK matmul, softmax, SV matmul.
        assert_eq!(fwd.ops.len(), 3);

        let bwd = build_mha_bwd_graph(dtype, &causal, layout);
        assert!(bwd.ops.len() >= 3);

        // Grouped-query attention: fewer KV heads than query heads.
        let grouped = AttentionParams::new(1, 128, 128, 16, 4, 64);
        assert!(grouped.is_gqa());
        let grouped_fwd = build_mha_fwd_graph(dtype, &grouped, layout);
        assert_ne!(fwd.signature(), grouped_fwd.signature());

        // The mask is not part of the spec-layer signature (it is wired into
        // the variant pack at execute time, not the descriptor digest), so
        // only check that the params struct records it.
        let windowed = AttentionParams::new(2, 128, 128, 8, 8, 64)
            .with_mask(AttentionMask::SlidingWindow(64));
        assert!(matches!(windowed.mask, AttentionMask::SlidingWindow(64)));
    }
}