Skip to main content

atomr_accel_cuda/kernel/collective/
reduce_scatter.rs

//! Typed ReduceScatter request, generic over the element type `T: NcclReduceSupported`.
2
3use std::sync::Arc;
4
5use cudarc::nccl::ReduceOp;
6use tokio::sync::oneshot;
7
8use super::{NcclReduceSupported, LIB};
9use crate::error::GpuError;
10use crate::gpu_ref::GpuRef;
11use crate::kernel::dispatch::{CollectiveDispatch, CollectiveDispatchCtx, DispatchDType};
12
13/// ReduceScatter: each rank contributes `send` (length `N *
14/// world_size`) and receives a reduced shard of length `N` into
15/// `recv`.
16pub struct ReduceScatterRequest<T: NcclReduceSupported> {
17    pub send: GpuRef<T>,
18    pub recv: GpuRef<T>,
19    pub op: ReduceOp,
20    pub reply: oneshot::Sender<Result<(), GpuError>>,
21}
22
23impl<T: NcclReduceSupported> CollectiveDispatch for ReduceScatterRequest<T> {
24    fn dtype_kind(&self) -> DispatchDType {
25        T::dispatch_dtype()
26    }
27
28    fn device_id(&self) -> Option<u32> {
29        self.send.device_id().or_else(|| self.recv.device_id())
30    }
31
32    fn dispatch(self: Box<Self>, ctx: &CollectiveDispatchCtx<'_>) {
33        let ReduceScatterRequest {
34            send,
35            recv,
36            op,
37            reply,
38        } = *self;
39        let send_slice = match send.access() {
40            Ok(s) => s.clone(),
41            Err(e) => {
42                let _ = reply.send(Err(e));
43                return;
44            }
45        };
46        let recv_slice = match recv.access() {
47            Ok(s) => s.clone(),
48            Err(e) => {
49                let _ = reply.send(Err(e));
50                return;
51            }
52        };
53        let mut recv_owned = match Arc::try_unwrap(recv_slice) {
54            Ok(s) => s,
55            Err(_) => {
56                let _ = reply.send(Err(GpuError::Unrecoverable(
57                    "ReduceScatter recv buffer has multiple live references".into(),
58                )));
59                return;
60            }
61        };
62        let res = ctx
63            .comm
64            .reduce_scatter(&*send_slice, &mut recv_owned, &op)
65            .map_err(|e| GpuError::LibraryError {
66                lib: LIB,
67                msg: format!("reduce_scatter: {e:?}"),
68            });
69        let _ = reply.send(res.map(|_| ()));
70        drop(recv_owned);
71        drop(send_slice);
72    }
73}
74
#[cfg(test)]
mod tests {
    use super::*;

    /// Instantiating this for `T` only compiles when `T: NcclReduceSupported`;
    /// calling it also exercises that impl's `dispatch_dtype`.
    fn probe_dtype<T: NcclReduceSupported>() {
        let _ = <T as NcclReduceSupported>::dispatch_dtype();
    }

    /// Checks that every element type ReduceScatter is expected to accept
    /// implements `NcclReduceSupported`. Despite the name, no request is
    /// built or dispatched here.
    #[test]
    fn reduce_scatter_request_round_trip() {
        let core_types: [fn(); 8] = [
            probe_dtype::<f32>,
            probe_dtype::<f64>,
            probe_dtype::<i8>,
            probe_dtype::<u8>,
            probe_dtype::<i32>,
            probe_dtype::<u32>,
            probe_dtype::<i64>,
            probe_dtype::<u64>,
        ];
        for probe in core_types {
            probe();
        }

        #[cfg(feature = "f16")]
        {
            let half_types: [fn(); 2] = [probe_dtype::<half::f16>, probe_dtype::<half::bf16>];
            for probe in half_types {
                probe();
            }
        }
    }
}