Skip to main content

atomr_accel_cuda/kernel/collective/
allreduce.rs

//! Typed AllReduce request. Generic over `T: NcclReduceSupported`
//! (any of f32/f64/i8/u8/i32/u32/i64/u64; f16/bf16 when the `f16`
//! cargo feature is enabled).
3
4use std::marker::PhantomData;
5use std::sync::Arc;
6
7use cudarc::nccl::ReduceOp;
8use tokio::sync::oneshot;
9
10use super::{NcclReduceSupported, LIB};
11use crate::error::GpuError;
12use crate::gpu_ref::GpuRef;
13use crate::kernel::dispatch::{CollectiveDispatch, CollectiveDispatchCtx, DispatchDType};
14
15/// In-place all-reduce on a single tensor on this actor's device.
16///
17/// The world actor coordinates `group_start`/`group_end` across
18/// ranks via separate `BeginGroup` / `EndGroup` messages.
19pub struct AllReduceRequest<T: NcclReduceSupported> {
20    pub tensor: GpuRef<T>,
21    pub op: ReduceOp,
22    pub reply: oneshot::Sender<Result<(), GpuError>>,
23}
24
25impl<T: NcclReduceSupported> CollectiveDispatch for AllReduceRequest<T> {
26    fn dtype_kind(&self) -> DispatchDType {
27        T::dispatch_dtype()
28    }
29
30    fn device_id(&self) -> Option<u32> {
31        self.tensor.device_id()
32    }
33
34    fn dispatch(self: Box<Self>, ctx: &CollectiveDispatchCtx<'_>) {
35        let AllReduceRequest { tensor, op, reply } = *self;
36        let slice = match tensor.access() {
37            Ok(s) => s.clone(),
38            Err(e) => {
39                let _ = reply.send(Err(e));
40                return;
41            }
42        };
43        // In-place all-reduce requires a unique-owner Arc unwrap.
44        let mut owned = match Arc::try_unwrap(slice) {
45            Ok(s) => s,
46            Err(_) => {
47                let _ = reply.send(Err(GpuError::Unrecoverable(
48                    "AllReduce tensor has multiple live references".into(),
49                )));
50                return;
51            }
52        };
53        let res =
54            ctx.comm
55                .all_reduce_in_place(&mut owned, &op)
56                .map_err(|e| GpuError::LibraryError {
57                    lib: LIB,
58                    msg: format!("all_reduce: {e:?}"),
59                });
60        let _ = reply.send(res.map(|_| ()));
61        drop(owned);
62    }
63}
64
65#[allow(dead_code)]
66fn _phantom_use<T: NcclReduceSupported>() -> PhantomData<T> {
67    PhantomData
68}
69
#[cfg(test)]
mod tests {
    use super::*;

    /// Type-erases a request into a `Box<dyn CollectiveDispatch>`.
    ///
    /// This only compiles if `AllReduceRequest<T>` implements
    /// `CollectiveDispatch` and the trait is usable as a trait object.
    /// `'static` is required because a bare `Box<dyn Trait>` is implicitly
    /// `Box<dyn Trait + 'static>`.
    fn erase<T: NcclReduceSupported + 'static>(
        req: AllReduceRequest<T>,
    ) -> Box<dyn CollectiveDispatch> {
        Box::new(req)
    }

    /// Compile-time coverage of the dtype matrix for `AllReduceRequest<T>`.
    ///
    /// A real `GpuRef<T>` cannot be built without a `CudaSlice`, so no
    /// request is constructed. Instead, each `T` is pushed through the
    /// generic helpers below, which forces the compiler to check, for that
    /// dtype, that the request coerces to `Box<dyn CollectiveDispatch>` and
    /// that the dtype tag lookup exists.
    #[test]
    fn request_round_trip_for_every_dtype() {
        fn assert_boxable<T: NcclReduceSupported + 'static>() {
            // Instantiating `erase::<T>` proves the boxed request can be
            // coerced to `Box<dyn CollectiveDispatch>` for this `T`.
            let _ = erase::<T>;
            // Trait-bound check via the module's PhantomData helper.
            let _ = _phantom_use::<T>();
            // Verify the dtype tag lookup is present.
            let _ = <T as NcclReduceSupported>::dispatch_dtype();
        }
        assert_boxable::<f32>();
        assert_boxable::<f64>();
        assert_boxable::<i8>();
        assert_boxable::<u8>();
        assert_boxable::<i32>();
        assert_boxable::<u32>();
        assert_boxable::<i64>();
        assert_boxable::<u64>();
        #[cfg(feature = "f16")]
        {
            assert_boxable::<half::f16>();
            assert_boxable::<half::bf16>();
        }
    }
}