atomr_accel_cuda/kernel/collective/
group.rs1use std::sync::atomic::{AtomicUsize, Ordering};
19use std::sync::Arc;
20
21use cudarc::nccl::{group_end, group_start};
22
23use super::LIB;
24use crate::error::GpuError;
25
/// Shared pair of counters recording how many group scopes were opened and closed.
///
/// Both counters start at zero (via `Default`) and are plain atomics, so a
/// tracker can be shared behind an `Arc` and inspected from tests.
#[derive(Default, Debug)]
pub struct GroupTracker {
    /// Incremented once each time a `GroupGuard` is created with this tracker.
    pub begins: AtomicUsize,
    /// Incremented once each time such a guard is committed or dropped.
    pub ends: AtomicUsize,
}
33
34pub struct GroupGuard {
38 tracker: Option<Arc<GroupTracker>>,
39 committed: bool,
40 inert: bool,
42}
43
44impl GroupGuard {
45 pub fn begin(tracker: Option<Arc<GroupTracker>>) -> Result<Self, GpuError> {
49 match group_start() {
50 Ok(_) => {
51 if let Some(t) = &tracker {
52 t.begins.fetch_add(1, Ordering::SeqCst);
53 }
54 Ok(Self {
55 tracker,
56 committed: false,
57 inert: false,
58 })
59 }
60 Err(e) => Err(GpuError::LibraryError {
61 lib: LIB,
62 msg: format!("group_start: {e:?}"),
63 }),
64 }
65 }
66
67 pub fn begin_inert(tracker: Option<Arc<GroupTracker>>) -> Self {
71 if let Some(t) = &tracker {
72 t.begins.fetch_add(1, Ordering::SeqCst);
73 }
74 Self {
75 tracker,
76 committed: false,
77 inert: true,
78 }
79 }
80
81 pub fn commit(mut self) -> Result<(), GpuError> {
84 self.committed = true;
85 if let Some(t) = &self.tracker {
86 t.ends.fetch_add(1, Ordering::SeqCst);
87 }
88 if self.inert {
89 return Ok(());
90 }
91 group_end().map(|_| ()).map_err(|e| GpuError::LibraryError {
92 lib: LIB,
93 msg: format!("group_end: {e:?}"),
94 })
95 }
96}
97
98impl Drop for GroupGuard {
99 fn drop(&mut self) {
100 if self.committed {
101 return;
102 }
103 if let Some(t) = &self.tracker {
104 t.ends.fetch_add(1, Ordering::SeqCst);
105 }
106 if self.inert {
107 return;
108 }
109 if let Err(e) = group_end() {
110 tracing::warn!(error = ?e, "GroupGuard::drop: group_end failed");
111 }
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads the `(begins, ends)` pair currently recorded by `tracker`.
    fn counts(tracker: &GroupTracker) -> (usize, usize) {
        (
            tracker.begins.load(Ordering::SeqCst),
            tracker.ends.load(Ordering::SeqCst),
        )
    }

    /// Dropping an uncommitted inert guard records one begin and one end.
    #[test]
    fn group_scope_guard_emits_begin_end_pair() {
        let tracker = Arc::new(GroupTracker::default());

        let guard = GroupGuard::begin_inert(Some(Arc::clone(&tracker)));
        drop(guard);

        let (begins, ends) = counts(&tracker);
        assert_eq!(begins, 1);
        assert_eq!(ends, 1);
    }

    /// Committing counts the end once; the drop that follows must not count it again.
    #[test]
    fn commit_then_drop_does_not_double_count() {
        let tracker = Arc::new(GroupTracker::default());

        let guard = GroupGuard::begin_inert(Some(Arc::clone(&tracker)));
        guard.commit().unwrap();

        let (begins, ends) = counts(&tracker);
        assert_eq!(begins, 1);
        assert_eq!(ends, 1);
    }
}