atomr_accel_cuda/kernel/collective/
mod.rs1use std::sync::Arc;
17
18use async_trait::async_trait;
19use atomr_core::actor::{Actor, Context, Props};
20pub use cudarc::nccl::ReduceOp;
21use cudarc::nccl::{group_end, group_start, Comm};
22use tokio::sync::oneshot;
23
24use crate::completion::CompletionStrategy;
25use crate::device::DeviceState;
26use crate::error::GpuError;
27use crate::gpu_ref::GpuRef;
28use crate::kernel::dispatch::{CollectiveDispatch, CollectiveDispatchCtx};
29
30pub mod all_to_all;
31pub mod allgather;
32pub mod allreduce;
33pub mod broadcast;
34pub mod capabilities;
35pub mod custom_op;
36pub mod group;
37pub mod p2p;
38pub mod reduce;
39pub mod reduce_scatter;
40
41pub use all_to_all::{AllToAllRequest, AllToAllvRequest};
42pub use allgather::AllGatherRequest;
43pub use allreduce::AllReduceRequest;
44pub use broadcast::BroadcastRequest;
45pub use capabilities::{probe_capabilities, NcclCapabilities};
46pub use custom_op::PreMulSumOp;
47pub use group::GroupGuard;
48pub use p2p::{RecvRequest, SendRequest};
49pub use reduce::ReduceRequest;
50pub use reduce_scatter::ReduceScatterRequest;
51
/// Library tag placed in `GpuError::LibraryError { lib, .. }` for errors raised
/// by NCCL calls (e.g. the `group_start` / `group_end` failures reported by the
/// collective actor).
pub(crate) const LIB: &str = "nccl";
53
54pub trait NcclReduceSupported: cudarc::nccl::NcclType + Copy + Send + Sync + 'static {
61 fn dispatch_dtype() -> crate::kernel::dispatch::DispatchDType;
63}
64
65macro_rules! impl_nccl_reduce_supported {
66 ($t:ty, $kind:ident) => {
67 impl NcclReduceSupported for $t {
68 fn dispatch_dtype() -> crate::kernel::dispatch::DispatchDType {
69 crate::kernel::dispatch::DispatchDType::$kind
70 }
71 }
72 };
73}
74
75impl_nccl_reduce_supported!(f32, F32);
76impl_nccl_reduce_supported!(f64, F64);
77impl_nccl_reduce_supported!(i8, I8);
78impl_nccl_reduce_supported!(u8, U8);
79impl_nccl_reduce_supported!(i32, I32);
80impl_nccl_reduce_supported!(u32, U32);
81impl_nccl_reduce_supported!(i64, I64);
82impl_nccl_reduce_supported!(u64, U64);
83
84#[cfg(feature = "f16")]
85impl_nccl_reduce_supported!(half::f16, F16);
86#[cfg(feature = "f16")]
87impl_nccl_reduce_supported!(half::bf16, Bf16);
88
/// Messages accepted by [`CollectiveActor`].
///
/// `Op` is the generic, type-erased path for any collective request. The
/// remaining variants are handled directly by the actor and each carries a
/// `reply` channel that the actor answers.
pub enum CollectiveMsg {
    /// A boxed collective request, dispatched against the actor's
    /// communicator. Any reply channel presumably lives inside the request
    /// itself; in mock mode the request is dropped without a reply.
    Op(Box<dyn CollectiveDispatch>),

    /// Calls NCCL `group_start` and replies with its result
    /// (always `Err` in mock mode).
    BeginGroup {
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
    /// Calls NCCL `group_end` and replies with its result
    /// (always `Err` in mock mode).
    EndGroup {
        reply: oneshot::Sender<Result<(), GpuError>>,
    },

    /// Replies with the result of `probe_capabilities()`, or
    /// `NcclCapabilities::zeroed()` in mock mode.
    QueryCapabilities {
        reply: oneshot::Sender<NcclCapabilities>,
    },

    /// Legacy f32 all-reduce of `tensor` with reduction `op`; forwarded as an
    /// `AllReduceRequest::<f32>`.
    #[deprecated(
        note = "use CollectiveMsg::Op(Box::new(AllReduceRequest::<f32> { ... })) instead"
    )]
    AllReduceF32 {
        tensor: GpuRef<f32>,
        op: ReduceOp,
        reply: oneshot::Sender<Result<(), GpuError>>,
    },

    /// Legacy f32 broadcast of `data` with `root` passed through as the
    /// broadcast root; forwarded as a `BroadcastRequest::<f32>`.
    #[deprecated(
        note = "use CollectiveMsg::Op(Box::new(BroadcastRequest::<f32> { ... })) instead"
    )]
    BroadcastF32 {
        data: GpuRef<f32>,
        root: usize,
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
}
141
/// Actor that owns one NCCL communicator (or nothing, in mock mode) and
/// executes [`CollectiveMsg`]s against it.
///
/// Construct its props with [`CollectiveActor::props_for_rank`] or
/// [`CollectiveActor::mock_props`].
pub struct CollectiveActor {
    inner: CollectiveInner,
}
145
/// Wrapper that lets a cudarc [`Comm`] be moved into, and stored inside, an
/// actor.
pub(crate) struct SendComm(pub(crate) Comm);
// SAFETY: NOTE(review) — asserts the communicator may be moved to another
// thread. In this module a `SendComm` is moved into the actor factory closure
// and then owned by a single `CollectiveActor`; confirm cudarc's `Comm` holds
// no thread-affine state.
unsafe impl Send for SendComm {}
// SAFETY: NOTE(review) — presumably required because `handle` borrows the
// comm through `&self.inner` inside an `#[async_trait]` future, which must be
// `Send`. The NCCL docs say threads should use separate communicators, so this
// is only sound if no two threads ever use the same `SendComm` at once; here
// that relies on the actor handling one message at a time — confirm.
unsafe impl Sync for SendComm {}
149
/// Backing mode of a [`CollectiveActor`].
// NOTE(review): both variants are constructed (`props_for_rank`, `mock_props`)
// and every `Real` field is read in `handle`, so this allow may be stale or
// only needed under some cfg — confirm before removing.
#[allow(dead_code)]
enum CollectiveInner {
    /// A live communicator plus the device state and completion strategy
    /// supplied to `props_for_rank`.
    Real {
        comm: SendComm,
        state: Arc<DeviceState>,
        completion: Arc<dyn CompletionStrategy>,
    },
    /// No communicator; messages are short-circuited by `mock_reply`.
    Mock,
}
159
160impl CollectiveActor {
161 pub fn props_for_rank(
168 comm: Comm,
169 state: Arc<DeviceState>,
170 completion: Arc<dyn CompletionStrategy>,
171 ) -> Props<Self> {
172 use parking_lot::Mutex;
173 let comm_slot = Arc::new(Mutex::new(Some(SendComm(comm))));
174 Props::create(move || {
175 let comm = comm_slot
176 .lock()
177 .take()
178 .expect("Unrecoverable: CollectiveActor restart with consumed Comm — NcclWorldActor must rebuild the world");
179 CollectiveActor {
180 inner: CollectiveInner::Real {
181 comm,
182 state: state.clone(),
183 completion: completion.clone(),
184 },
185 }
186 })
187 }
188
189 pub fn mock_props() -> Props<Self> {
190 Props::create(|| CollectiveActor {
191 inner: CollectiveInner::Mock,
192 })
193 }
194}
195
196#[async_trait]
197impl Actor for CollectiveActor {
198 type Msg = CollectiveMsg;
199
200 async fn handle(&mut self, _ctx: &mut Context<Self>, msg: CollectiveMsg) {
201 match (&self.inner, msg) {
202 (CollectiveInner::Mock, msg) => mock_reply(msg),
203 (
204 CollectiveInner::Real {
205 comm,
206 state,
207 completion,
208 },
209 CollectiveMsg::Op(boxed),
210 ) => {
211 if let Some(dev) = boxed.device_id() {
212 if dev != state.device_id() {
213 tracing::warn!(
225 expected = state.device_id(),
226 got = dev,
227 "collective op on wrong device"
228 );
229 }
230 }
231 let ctx = CollectiveDispatchCtx {
232 comm: &comm.0,
233 state,
234 completion,
235 };
236 boxed.dispatch(&ctx);
237 }
238 (CollectiveInner::Real { comm, .. }, msg) => {
239 handle_legacy(comm, msg);
240 }
241 }
242 }
243}
244
245fn mock_reply(msg: CollectiveMsg) {
246 let err = || GpuError::Unrecoverable("CollectiveActor in mock mode".into());
247 match msg {
248 CollectiveMsg::Op(boxed) => {
249 tracing::warn!(
254 dtype = ?boxed.dtype_kind(),
255 "CollectiveActor mock: dropping boxed op without reply"
256 );
257 drop(boxed);
258 }
259 CollectiveMsg::BeginGroup { reply } => {
260 let _ = reply.send(Err(err()));
261 }
262 CollectiveMsg::EndGroup { reply } => {
263 let _ = reply.send(Err(err()));
264 }
265 CollectiveMsg::QueryCapabilities { reply } => {
266 let _ = reply.send(NcclCapabilities::zeroed());
267 }
268 #[allow(deprecated)]
269 CollectiveMsg::AllReduceF32 { reply, .. } => {
270 let _ = reply.send(Err(err()));
271 }
272 #[allow(deprecated)]
273 CollectiveMsg::BroadcastF32 { reply, .. } => {
274 let _ = reply.send(Err(err()));
275 }
276 }
277}
278
279#[allow(deprecated)]
280fn handle_legacy(comm: &SendComm, msg: CollectiveMsg) {
281 match msg {
282 CollectiveMsg::Op(_) => unreachable!("Op handled in handle()"),
283 CollectiveMsg::BeginGroup { reply } => {
284 let res = group_start()
285 .map_err(|e| GpuError::LibraryError {
286 lib: LIB,
287 msg: format!("group_start: {e:?}"),
288 })
289 .map(|_| ());
290 let _ = reply.send(res);
291 }
292 CollectiveMsg::EndGroup { reply } => {
293 let res = group_end()
294 .map_err(|e| GpuError::LibraryError {
295 lib: LIB,
296 msg: format!("group_end: {e:?}"),
297 })
298 .map(|_| ());
299 let _ = reply.send(res);
300 }
301 CollectiveMsg::QueryCapabilities { reply } => {
302 let _ = reply.send(probe_capabilities());
303 }
304 CollectiveMsg::AllReduceF32 { tensor, op, reply } => {
305 let req = AllReduceRequest::<f32> { tensor, op, reply };
308 let dummy_state = Arc::new(crate::device::DeviceState::new(0));
309 let dummy_comp: Arc<dyn CompletionStrategy> =
310 Arc::new(crate::completion::HostFnCompletion::new());
311 let ctx = CollectiveDispatchCtx {
312 comm: &comm.0,
313 state: &dummy_state,
314 completion: &dummy_comp,
315 };
316 Box::new(req).dispatch(&ctx);
317 }
318 CollectiveMsg::BroadcastF32 { data, root, reply } => {
319 let req = BroadcastRequest::<f32> { data, root, reply };
320 let dummy_state = Arc::new(crate::device::DeviceState::new(0));
321 let dummy_comp: Arc<dyn CompletionStrategy> =
322 Arc::new(crate::completion::HostFnCompletion::new());
323 let ctx = CollectiveDispatchCtx {
324 comm: &comm.0,
325 state: &dummy_state,
326 completion: &dummy_comp,
327 };
328 Box::new(req).dispatch(&ctx);
329 }
330 }
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    /// The deprecated `AllReduceF32` / `BroadcastF32` aliases must keep their
    /// field names and types. A `GpuRef<f32>` needs a real device, so this is
    /// checked at compile time through constructor functions that are never
    /// called.
    #[test]
    #[allow(deprecated)]
    fn deprecated_allreduce_f32_alias_still_constructs() {
        fn build_all_reduce(
            tensor: GpuRef<f32>,
            op: ReduceOp,
            reply: oneshot::Sender<Result<(), GpuError>>,
        ) -> CollectiveMsg {
            CollectiveMsg::AllReduceF32 { tensor, op, reply }
        }
        fn build_broadcast(
            data: GpuRef<f32>,
            root: usize,
            reply: oneshot::Sender<Result<(), GpuError>>,
        ) -> CollectiveMsg {
            CollectiveMsg::BroadcastF32 { data, root, reply }
        }
        let _ = build_all_reduce as fn(_, _, _) -> CollectiveMsg;
        let _ = build_broadcast as fn(_, _, _) -> CollectiveMsg;
    }

    /// In mock mode, group calls must be answered with an error rather than
    /// leaving the caller waiting.
    #[test]
    fn mock_group_calls_reply_with_error() {
        let (tx, mut rx) = oneshot::channel();
        mock_reply(CollectiveMsg::BeginGroup { reply: tx });
        assert!(matches!(
            rx.try_recv(),
            Ok(Err(GpuError::Unrecoverable(_)))
        ));

        let (tx, mut rx) = oneshot::channel();
        mock_reply(CollectiveMsg::EndGroup { reply: tx });
        assert!(matches!(
            rx.try_recv(),
            Ok(Err(GpuError::Unrecoverable(_)))
        ));
    }

    /// In mock mode, capability queries still receive a reply.
    #[test]
    fn mock_query_capabilities_replies() {
        let (tx, mut rx) = oneshot::channel();
        mock_reply(CollectiveMsg::QueryCapabilities { reply: tx });
        assert!(rx.try_recv().is_ok());
    }
}