atomr_accel_cuda/kernel/collective/
allreduce.rs1use std::marker::PhantomData;
5use std::sync::Arc;
6
7use cudarc::nccl::ReduceOp;
8use tokio::sync::oneshot;
9
10use super::{NcclReduceSupported, LIB};
11use crate::error::GpuError;
12use crate::gpu_ref::GpuRef;
13use crate::kernel::dispatch::{CollectiveDispatch, CollectiveDispatchCtx, DispatchDType};
14
15pub struct AllReduceRequest<T: NcclReduceSupported> {
20 pub tensor: GpuRef<T>,
21 pub op: ReduceOp,
22 pub reply: oneshot::Sender<Result<(), GpuError>>,
23}
24
25impl<T: NcclReduceSupported> CollectiveDispatch for AllReduceRequest<T> {
26 fn dtype_kind(&self) -> DispatchDType {
27 T::dispatch_dtype()
28 }
29
30 fn device_id(&self) -> Option<u32> {
31 self.tensor.device_id()
32 }
33
34 fn dispatch(self: Box<Self>, ctx: &CollectiveDispatchCtx<'_>) {
35 let AllReduceRequest { tensor, op, reply } = *self;
36 let slice = match tensor.access() {
37 Ok(s) => s.clone(),
38 Err(e) => {
39 let _ = reply.send(Err(e));
40 return;
41 }
42 };
43 let mut owned = match Arc::try_unwrap(slice) {
45 Ok(s) => s,
46 Err(_) => {
47 let _ = reply.send(Err(GpuError::Unrecoverable(
48 "AllReduce tensor has multiple live references".into(),
49 )));
50 return;
51 }
52 };
53 let res =
54 ctx.comm
55 .all_reduce_in_place(&mut owned, &op)
56 .map_err(|e| GpuError::LibraryError {
57 lib: LIB,
58 msg: format!("all_reduce: {e:?}"),
59 });
60 let _ = reply.send(res.map(|_| ()));
61 drop(owned);
62 }
63}
64
65#[allow(dead_code)]
66fn _phantom_use<T: NcclReduceSupported>() -> PhantomData<T> {
67 PhantomData
68}
69
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time coverage for every NCCL-reducible dtype: checks that
    /// `AllReduceRequest<T>` implements `CollectiveDispatch`, and that
    /// `_phantom_use` and `dispatch_dtype()` can be called for `T`.
    #[test]
    fn request_round_trip_for_every_dtype() {
        fn assert_boxable<T: NcclReduceSupported>() {
            // Only compiles while the `CollectiveDispatch` impl holds for this
            // `T`. It would catch, e.g., extra bounds being added to that impl.
            fn implements_dispatch<D: CollectiveDispatch>() {}
            implements_dispatch::<AllReduceRequest<T>>();

            let _ = _phantom_use::<T>();
            let _ = <T as NcclReduceSupported>::dispatch_dtype();
        }

        assert_boxable::<f32>();
        assert_boxable::<f64>();
        assert_boxable::<i8>();
        assert_boxable::<u8>();
        assert_boxable::<i32>();
        assert_boxable::<u32>();
        assert_boxable::<i64>();
        assert_boxable::<u64>();
        #[cfg(feature = "f16")]
        {
            assert_boxable::<half::f16>();
            assert_boxable::<half::bf16>();
        }
    }
}