Skip to main content

atomr_accel_cuda/kernel/collective/
group.rs

//! Typed scope guard around `ncclGroupStart` / `ncclGroupEnd`.
//!
//! Usage at the world level:
//!
//! ```ignore
//! NcclWorld::group(|w| {
//!     w.send(buf, peer)?;
//!     w.recv(buf2, peer)?;
//!     Ok(())
//! })?;
//! ```
//!
//! At the actor level (one rank), `GroupGuard` issues `group_start`
//! on construction and `group_end` on `Drop` (or on `commit()`),
//! while emitting `BeginGroup` / `EndGroup` markers via a tracker so
//! tests can assert the begin/end pair fires exactly once.
17
18use std::sync::atomic::{AtomicUsize, Ordering};
19use std::sync::Arc;
20
21use cudarc::nccl::{group_end, group_start};
22
23use super::LIB;
24use crate::error::GpuError;
25
/// Pair of event counters shared with a `GroupGuard`.
///
/// A guard that was handed a tracker bumps `begins` once when the group
/// is opened and `ends` once when it is closed, which lets tests check
/// that every begin has exactly one matching end.
#[derive(Default, Debug)]
pub struct GroupTracker {
    /// Number of groups that have been opened.
    pub begins: AtomicUsize,
    /// Number of groups that have been closed (by `commit()` or by `Drop`).
    pub ends: AtomicUsize,
}
33
34/// RAII scope guard for a group call on a single rank. Issues
35/// `ncclGroupStart` on construction; `ncclGroupEnd` on `Drop` if
36/// `commit()` was not called (errors logged but swallowed in Drop).
37pub struct GroupGuard {
38    tracker: Option<Arc<GroupTracker>>,
39    committed: bool,
40    /// Set to true on construction error so Drop is a no-op.
41    inert: bool,
42}
43
44impl GroupGuard {
45    /// Begin a group. If `tracker` is `Some`, increments
46    /// `tracker.begins` on construction and `tracker.ends` on the
47    /// matching end.
48    pub fn begin(tracker: Option<Arc<GroupTracker>>) -> Result<Self, GpuError> {
49        match group_start() {
50            Ok(_) => {
51                if let Some(t) = &tracker {
52                    t.begins.fetch_add(1, Ordering::SeqCst);
53                }
54                Ok(Self {
55                    tracker,
56                    committed: false,
57                    inert: false,
58                })
59            }
60            Err(e) => Err(GpuError::LibraryError {
61                lib: LIB,
62                msg: format!("group_start: {e:?}"),
63            }),
64        }
65    }
66
67    /// Begin a group without invoking NCCL — for tests on hosts
68    /// without a working NCCL install. Bumps the tracker but issues
69    /// no FFI call.
70    pub fn begin_inert(tracker: Option<Arc<GroupTracker>>) -> Self {
71        if let Some(t) = &tracker {
72            t.begins.fetch_add(1, Ordering::SeqCst);
73        }
74        Self {
75            tracker,
76            committed: false,
77            inert: true,
78        }
79    }
80
81    /// End the group. Returns the FFI result. Subsequent `Drop` is
82    /// a no-op.
83    pub fn commit(mut self) -> Result<(), GpuError> {
84        self.committed = true;
85        if let Some(t) = &self.tracker {
86            t.ends.fetch_add(1, Ordering::SeqCst);
87        }
88        if self.inert {
89            return Ok(());
90        }
91        group_end().map(|_| ()).map_err(|e| GpuError::LibraryError {
92            lib: LIB,
93            msg: format!("group_end: {e:?}"),
94        })
95    }
96}
97
98impl Drop for GroupGuard {
99    fn drop(&mut self) {
100        if self.committed {
101            return;
102        }
103        if let Some(t) = &self.tracker {
104            t.ends.fetch_add(1, Ordering::SeqCst);
105        }
106        if self.inert {
107            return;
108        }
109        if let Err(e) = group_end() {
110            tracing::warn!(error = ?e, "GroupGuard::drop: group_end failed");
111        }
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    /// Snapshot of the `(begins, ends)` counters held by `tracker`.
    fn counts(tracker: &GroupTracker) -> (usize, usize) {
        let begins = tracker.begins.load(Ordering::SeqCst);
        let ends = tracker.ends.load(Ordering::SeqCst);
        (begins, ends)
    }

    /// An inert guard that is simply dropped records one begin and one
    /// end event.
    #[test]
    fn group_scope_guard_emits_begin_end_pair() {
        let counters = Arc::new(GroupTracker::default());
        let guard = GroupGuard::begin_inert(Some(Arc::clone(&counters)));
        // Dropping without `commit()` is the path under test.
        drop(guard);
        assert_eq!(counts(&counters), (1, 1));
    }

    /// `commit()` followed by the implicit drop records the end event
    /// once, not twice.
    #[test]
    fn commit_then_drop_does_not_double_count() {
        let counters = Arc::new(GroupTracker::default());
        let guard = GroupGuard::begin_inert(Some(Arc::clone(&counters)));
        guard.commit().unwrap();
        assert_eq!(counts(&counters), (1, 1));
    }
}
141}