atomr_accel_cuda/hopper/cp_async.rs
1//! `cp.async` pipeline macro shim.
2//!
3//! `cp.async` (sm_80+) and the Hopper-introduced `cp.async.bulk`
4//! variants live entirely on the device side. This module provides
5//! Rust constants for the macro names defined in `atomr_hopper.cuh`
6//! plus host-side helpers that compute the right `mbarrier` arrival
7//! count for a given pipeline stage count.
8
/// Number of mbarrier arrival slots needed to fence a pipelined
/// producer/consumer with the given `stages`.
///
/// Matches the formula `2 * stages` used by stage-balanced
/// double-buffer kernels.
///
/// # Panics
///
/// Panics if `2 * stages` does not fit in a `u32`. The overflow is checked
/// explicitly so that release builds do not silently wrap to a bogus (small)
/// arrival count. When evaluated in a `const` context, the panic surfaces as
/// a compile-time error instead.
#[must_use]
pub const fn mbarrier_arrival_count(stages: u32) -> u32 {
    match stages.checked_mul(2) {
        Some(count) => count,
        None => panic!("mbarrier_arrival_count: 2 * stages overflows u32"),
    }
}
15
/// Pipeline stage policy for `cp.async`-driven shared-memory
/// double/triple/quad buffering.
///
/// Derives `Hash` alongside `Eq` so the policy can be used as a key in
/// hash-based collections (e.g. a kernel cache keyed by stage policy).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PipelineStages {
    /// Two stages (double buffering).
    Double,
    /// Three stages (triple buffering).
    Triple,
    /// Four stages (quad buffering).
    Quad,
}
24
25impl PipelineStages {
26 pub fn count(self) -> u32 {
27 match self {
28 PipelineStages::Double => 2,
29 PipelineStages::Triple => 3,
30 PipelineStages::Quad => 4,
31 }
32 }
33}
34
/// Preprocessor macro identifiers defined in `atomr_hopper.cuh`.
///
/// Callers splice these names into their NVRTC sources rather than
/// spelling the identifiers out by hand.
pub mod macro_names {
    // Per-thread `cp.async` copies.

    /// `cp.async.cg.shared.global` issue (16B-aligned). The macro takes
    /// four arguments: `(dst_smem_addr, src_global_addr, bytes, predicate)`.
    pub const CP_ASYNC_CG_16: &str = "ATOMR_CP_ASYNC_CG_16";

    /// `cp.async.ca.shared.global` issue (4B-aligned, cache-all).
    pub const CP_ASYNC_CA_4: &str = "ATOMR_CP_ASYNC_CA_4";

    // Group fencing for the per-thread copies above.

    /// Commit-group barrier, i.e. `cp.async.commit_group`.
    pub const CP_ASYNC_COMMIT_GROUP: &str = "ATOMR_CP_ASYNC_COMMIT_GROUP";

    /// Wait-group barrier, i.e. `cp.async.wait_group <N>`.
    pub const CP_ASYNC_WAIT_GROUP: &str = "ATOMR_CP_ASYNC_WAIT_GROUP";

    // Bulk copies.

    /// Bulk async copy, i.e. `cp.async.bulk.shared::cluster.global`,
    /// used for TMA-driven loads.
    pub const CP_ASYNC_BULK: &str = "ATOMR_CP_ASYNC_BULK";
}
51
#[cfg(test)]
mod tests {
    use super::*;

    /// Arrival counts double the stage count, and each stage policy
    /// reports its buffering depth.
    #[test]
    fn arrival_counts_balanced() {
        // (stages, expected arrival slots)
        let arrival_cases: [(u32, u32); 3] = [(2, 4), (3, 6), (4, 8)];
        for &(stages, expected) in arrival_cases.iter() {
            assert_eq!(mbarrier_arrival_count(stages), expected);
        }

        // (policy, expected stage count)
        let depth_cases = [
            (PipelineStages::Double, 2),
            (PipelineStages::Triple, 3),
            (PipelineStages::Quad, 4),
        ];
        for &(policy, depth) in depth_cases.iter() {
            assert_eq!(policy.count(), depth);
        }
    }
}