atomr_accel_cuda/kernel/collective/
allgather.rs1use std::sync::Arc;
4
5use tokio::sync::oneshot;
6
7use super::{NcclReduceSupported, LIB};
8use crate::error::GpuError;
9use crate::gpu_ref::GpuRef;
10use crate::kernel::dispatch::{CollectiveDispatch, CollectiveDispatchCtx, DispatchDType};
11
12pub struct AllGatherRequest<T: NcclReduceSupported> {
15 pub send: GpuRef<T>,
16 pub recv: GpuRef<T>,
17 pub reply: oneshot::Sender<Result<(), GpuError>>,
18}
19
20impl<T: NcclReduceSupported> CollectiveDispatch for AllGatherRequest<T> {
21 fn dtype_kind(&self) -> DispatchDType {
22 T::dispatch_dtype()
23 }
24
25 fn device_id(&self) -> Option<u32> {
26 self.send.device_id().or_else(|| self.recv.device_id())
27 }
28
29 fn dispatch(self: Box<Self>, ctx: &CollectiveDispatchCtx<'_>) {
30 let AllGatherRequest { send, recv, reply } = *self;
31 let send_slice = match send.access() {
32 Ok(s) => s.clone(),
33 Err(e) => {
34 let _ = reply.send(Err(e));
35 return;
36 }
37 };
38 let recv_slice = match recv.access() {
39 Ok(s) => s.clone(),
40 Err(e) => {
41 let _ = reply.send(Err(e));
42 return;
43 }
44 };
45 let mut recv_owned = match Arc::try_unwrap(recv_slice) {
46 Ok(s) => s,
47 Err(_) => {
48 let _ = reply.send(Err(GpuError::Unrecoverable(
49 "AllGather recv buffer has multiple live references".into(),
50 )));
51 return;
52 }
53 };
54 let res = ctx
55 .comm
56 .all_gather(&*send_slice, &mut recv_owned)
57 .map_err(|e| GpuError::LibraryError {
58 lib: LIB,
59 msg: format!("all_gather: {e:?}"),
60 });
61 let _ = reply.send(res.map(|_| ()));
62 drop(recv_owned);
63 drop(send_slice);
64 }
65}
66
#[cfg(test)]
mod tests {
    use super::*;

    /// Evaluates `dispatch_dtype` for `T`.
    ///
    /// The returned value is discarded: the check is that `T` satisfies
    /// `NcclReduceSupported` (enforced at compile time) and that the call
    /// completes.
    fn probe_dtype<T>()
    where
        T: NcclReduceSupported,
    {
        let _ = T::dispatch_dtype();
    }

    /// Every element type expected to be usable with `AllGatherRequest`
    /// implements `NcclReduceSupported`.
    #[test]
    fn allgather_request_round_trip() {
        // Floating point.
        probe_dtype::<f32>();
        probe_dtype::<f64>();

        // 8-bit integers.
        probe_dtype::<i8>();
        probe_dtype::<u8>();

        // 32-bit integers.
        probe_dtype::<i32>();
        probe_dtype::<u32>();

        // 64-bit integers.
        probe_dtype::<i64>();
        probe_dtype::<u64>();

        // Half-precision types are only checked when the `f16` feature is on.
        #[cfg(feature = "f16")]
        {
            probe_dtype::<half::f16>();
            probe_dtype::<half::bf16>();
        }
    }
}