Skip to main content

atomr_accel_cuda/kernel/collective/
allgather.rs

1//! Typed AllGather request — generic over `T: NcclReduceSupported`.
2
3use std::sync::Arc;
4
5use tokio::sync::oneshot;
6
7use super::{NcclReduceSupported, LIB};
8use crate::error::GpuError;
9use crate::gpu_ref::GpuRef;
10use crate::kernel::dispatch::{CollectiveDispatch, CollectiveDispatchCtx, DispatchDType};
11
12/// AllGather: each rank contributes `send` (length N) and writes a
13/// concatenated buffer of length `N * world_size` into `recv`.
14pub struct AllGatherRequest<T: NcclReduceSupported> {
15    pub send: GpuRef<T>,
16    pub recv: GpuRef<T>,
17    pub reply: oneshot::Sender<Result<(), GpuError>>,
18}
19
20impl<T: NcclReduceSupported> CollectiveDispatch for AllGatherRequest<T> {
21    fn dtype_kind(&self) -> DispatchDType {
22        T::dispatch_dtype()
23    }
24
25    fn device_id(&self) -> Option<u32> {
26        self.send.device_id().or_else(|| self.recv.device_id())
27    }
28
29    fn dispatch(self: Box<Self>, ctx: &CollectiveDispatchCtx<'_>) {
30        let AllGatherRequest { send, recv, reply } = *self;
31        let send_slice = match send.access() {
32            Ok(s) => s.clone(),
33            Err(e) => {
34                let _ = reply.send(Err(e));
35                return;
36            }
37        };
38        let recv_slice = match recv.access() {
39            Ok(s) => s.clone(),
40            Err(e) => {
41                let _ = reply.send(Err(e));
42                return;
43            }
44        };
45        let mut recv_owned = match Arc::try_unwrap(recv_slice) {
46            Ok(s) => s,
47            Err(_) => {
48                let _ = reply.send(Err(GpuError::Unrecoverable(
49                    "AllGather recv buffer has multiple live references".into(),
50                )));
51                return;
52            }
53        };
54        let res = ctx
55            .comm
56            .all_gather(&*send_slice, &mut recv_owned)
57            .map_err(|e| GpuError::LibraryError {
58                lib: LIB,
59                msg: format!("all_gather: {e:?}"),
60            });
61        let _ = reply.send(res.map(|_| ()));
62        drop(recv_owned);
63        drop(send_slice);
64    }
65}
66
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles only when `D` implements `CollectiveDispatch`; it has no
    /// runtime behaviour.
    fn assert_collective_dispatch<D: CollectiveDispatch>() {}

    /// Every `NcclReduceSupported` dtype builds an
    /// `AllGatherRequest<T>` that satisfies `CollectiveDispatch`.
    ///
    /// The trait-bound check is enforced at compile time, once per dtype
    /// listed below. The `dispatch_dtype()` call exercises the dtype mapping
    /// used by `dtype_kind`.
    #[test]
    fn allgather_request_round_trip() {
        fn assert_supported<T: NcclReduceSupported>() {
            // Previously this test never mentioned `AllGatherRequest` or
            // `CollectiveDispatch`; this line makes the documented claim real.
            assert_collective_dispatch::<AllGatherRequest<T>>();
            let _ = <T as NcclReduceSupported>::dispatch_dtype();
        }
        assert_supported::<f32>();
        assert_supported::<f64>();
        assert_supported::<i8>();
        assert_supported::<u8>();
        assert_supported::<i32>();
        assert_supported::<u32>();
        assert_supported::<i64>();
        assert_supported::<u64>();
        #[cfg(feature = "f16")]
        {
            assert_supported::<half::f16>();
            assert_supported::<half::bf16>();
        }
    }
}