Skip to main content

atomr_accel_cuda/hopper/
cp_async.rs

1//! `cp.async` pipeline macro shim.
2//!
3//! `cp.async` (sm_80+) and the Hopper-introduced `cp.async.bulk`
4//! variants live entirely on the device side. This module provides
5//! Rust constants for the macro names defined in `atomr_hopper.cuh`
6//! plus host-side helpers that compute the right `mbarrier` arrival
7//! count for a given pipeline stage count.
8
/// Host-side helper: how many `mbarrier` arrival slots are needed to
/// fence a pipelined producer/consumer running `stages` stages.
///
/// The result is `2 * stages`, the formula used by stage-balanced
/// double-buffer kernels. Being a `const fn`, it can be evaluated at
/// compile time.
pub const fn mbarrier_arrival_count(stages: u32) -> u32 {
    // The formula budgets two arrival slots for every pipeline stage.
    const ARRIVALS_PER_STAGE: u32 = 2;
    ARRIVALS_PER_STAGE * stages
}
15
/// Pipeline stage policy for `cp.async`-driven shared-memory
/// double/triple/quad buffering.
///
/// Each variant names a buffering depth; [`PipelineStages::count`]
/// converts it to the corresponding number of stages.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PipelineStages {
    /// Two pipeline stages.
    Double,
    /// Three pipeline stages.
    Triple,
    /// Four pipeline stages.
    Quad,
}

impl PipelineStages {
    /// Returns the number of pipeline stages this policy stands for:
    /// `Double` is 2, `Triple` is 3 and `Quad` is 4.
    ///
    /// This is a `const fn` so the stage count can be computed in
    /// constant contexts (for example when initialising a `const` or
    /// sizing an array) rather than only at run time.
    pub const fn count(self) -> u32 {
        match self {
            Self::Double => 2,
            Self::Triple => 3,
            Self::Quad => 4,
        }
    }
}
34
/// String constants holding the names of the macros that
/// `atomr_hopper.cuh` defines, so callers can refer to them when
/// assembling NVRTC sources.
pub mod macro_names {
    /// Name of the macro that issues a `cp.async.cg.shared.global`
    /// copy (16B-aligned).
    ///
    /// The macro takes four arguments:
    /// `(dst_smem_addr, src_global_addr, bytes, predicate)`.
    pub const CP_ASYNC_CG_16: &str = "ATOMR_CP_ASYNC_CG_16";

    /// Name of the macro that issues a `cp.async.ca.shared.global`
    /// copy (4B-aligned, cache-all).
    pub const CP_ASYNC_CA_4: &str = "ATOMR_CP_ASYNC_CA_4";

    /// Name of the commit-group macro (`cp.async.commit_group`).
    pub const CP_ASYNC_COMMIT_GROUP: &str = "ATOMR_CP_ASYNC_COMMIT_GROUP";

    /// Name of the wait-group macro (`cp.async.wait_group <N>`).
    pub const CP_ASYNC_WAIT_GROUP: &str = "ATOMR_CP_ASYNC_WAIT_GROUP";

    /// Name of the bulk async copy macro
    /// (`cp.async.bulk.shared::cluster.global`), used for TMA-driven
    /// loads.
    pub const CP_ASYNC_BULK: &str = "ATOMR_CP_ASYNC_BULK";
}
51
#[cfg(test)]
mod tests {
    use super::{mbarrier_arrival_count, PipelineStages};

    /// Verifies, for every supported pipeline depth, that the arrival
    /// count is twice the stage count and that each policy variant
    /// reports the expected number of stages.
    #[test]
    fn arrival_counts_balanced() {
        // (policy, expected stage count, expected arrival count)
        let cases = [
            (PipelineStages::Double, 2_u32, 4_u32),
            (PipelineStages::Triple, 3, 6),
            (PipelineStages::Quad, 4, 8),
        ];

        for &(_, stages, arrivals) in &cases {
            assert_eq!(mbarrier_arrival_count(stages), arrivals);
        }
        for &(policy, stages, _) in &cases {
            assert_eq!(policy.count(), stages);
        }
    }
}
65}