Skip to main content

atomr_accel_cuda/kernel/collective/
all_to_all.rs

1//! Typed AllToAll / AllToAllv requests.
2//!
3//! cudarc 0.19.4 does not safely expose `ncclAllToAll` / `ncclSend` /
4//! `ncclRecv` against a raw `comm: ncclComm_t` (the field is private).
5//! Implementations therefore decompose AllToAll into a `group_start`
6//! / paired `send` + `recv` / `group_end` sequence using cudarc's safe
7//! `Comm::send` and `Comm::recv` — which is the standard NCCL idiom
8//! for AllToAll regardless. AllToAllv adds per-peer (count, offset)
9//! pairs.
10
11use std::sync::Arc;
12
13use cudarc::nccl::{group_end, group_start};
14use tokio::sync::oneshot;
15
16use super::{NcclReduceSupported, LIB};
17use crate::error::GpuError;
18use crate::gpu_ref::GpuRef;
19use crate::kernel::dispatch::{CollectiveDispatch, CollectiveDispatchCtx, DispatchDType};
20
/// Symmetric AllToAll: each rank sends an equal-sized shard of length
/// `count` to every peer.
///
/// Shard `p` of `send` (elements `p * count..(p + 1) * count`) is sent to
/// rank `p`. The shard received from rank `p` is written to the same element
/// range of `recv`. Here `world_size` is the communicator's world size.
/// Both buffers must hold at least `world_size * count` elements; dispatch
/// rejects shorter buffers.
pub struct AllToAllRequest<T: NcclReduceSupported> {
    /// Source buffer, read as `world_size` contiguous shards of `count` elements.
    pub send: GpuRef<T>,
    /// Destination buffer, written as `world_size` contiguous shards of
    /// `count` elements. Dispatch fails with `GpuError::Unrecoverable` if its
    /// underlying allocation has other live references when the request runs.
    pub recv: GpuRef<T>,
    /// Number of elements exchanged with each peer.
    pub count: usize,
    /// Receives the dispatch outcome: `Ok(())` once the NCCL group has been
    /// closed successfully, otherwise the error that stopped dispatch.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
}
30
31impl<T: NcclReduceSupported> CollectiveDispatch for AllToAllRequest<T> {
32    fn dtype_kind(&self) -> DispatchDType {
33        T::dispatch_dtype()
34    }
35
36    fn device_id(&self) -> Option<u32> {
37        self.send.device_id().or_else(|| self.recv.device_id())
38    }
39
40    fn dispatch(self: Box<Self>, ctx: &CollectiveDispatchCtx<'_>) {
41        let AllToAllRequest {
42            send,
43            recv,
44            count,
45            reply,
46        } = *self;
47
48        let send_slice = match send.access() {
49            Ok(s) => s.clone(),
50            Err(e) => {
51                let _ = reply.send(Err(e));
52                return;
53            }
54        };
55        let recv_slice = match recv.access() {
56            Ok(s) => s.clone(),
57            Err(e) => {
58                let _ = reply.send(Err(e));
59                return;
60            }
61        };
62        let mut recv_owned = match Arc::try_unwrap(recv_slice) {
63            Ok(s) => s,
64            Err(_) => {
65                let _ = reply.send(Err(GpuError::Unrecoverable(
66                    "AllToAll recv buffer has multiple live references".into(),
67                )));
68                return;
69            }
70        };
71
72        let world_size = ctx.comm.world_size();
73        if send_slice.len() < world_size * count || recv_owned.len() < world_size * count {
74            let _ = reply.send(Err(GpuError::Unrecoverable(format!(
75                "AllToAll: buffer length < world_size ({world_size}) * count ({count})"
76            ))));
77            return;
78        }
79
80        // group_start + 2*world_size paired send/recv + group_end.
81        let res = (|| -> Result<(), GpuError> {
82            group_start().map_err(|e| GpuError::LibraryError {
83                lib: LIB,
84                msg: format!("group_start: {e:?}"),
85            })?;
86
87            for peer in 0..world_size {
88                let peer_i32 = peer as i32;
89                let send_slab = send_slice.slice(peer * count..(peer + 1) * count);
90                let mut recv_slab = recv_owned.slice_mut(peer * count..(peer + 1) * count);
91                ctx.comm
92                    .send(&send_slab, peer_i32)
93                    .map_err(|e| GpuError::LibraryError {
94                        lib: LIB,
95                        msg: format!("a2a send to {peer}: {e:?}"),
96                    })?;
97                ctx.comm
98                    .recv(&mut recv_slab, peer_i32)
99                    .map_err(|e| GpuError::LibraryError {
100                        lib: LIB,
101                        msg: format!("a2a recv from {peer}: {e:?}"),
102                    })?;
103            }
104
105            group_end().map_err(|e| GpuError::LibraryError {
106                lib: LIB,
107                msg: format!("group_end: {e:?}"),
108            })?;
109            Ok(())
110        })();
111        let _ = reply.send(res);
112        drop(recv_owned);
113        drop(send_slice);
114    }
115}
116
/// AllToAllv: per-peer (count, offset) shards in send and recv.
///
/// For each peer `p`:
/// - `send_counts[p]` elements starting at `send_offsets[p]` in `send` are
///   sent to rank `p`;
/// - `recv_counts[p]` elements received from rank `p` are written starting at
///   `recv_offsets[p]` in `recv`.
///
/// Counts and offsets are in elements of `T`. A zero count skips that
/// direction for that peer. Each of the four vectors must have exactly
/// `world_size` entries (the communicator's world size), and every non-empty
/// shard must lie within its buffer; dispatch rejects the request otherwise.
pub struct AllToAllvRequest<T: NcclReduceSupported> {
    /// Source buffer holding all outgoing shards.
    pub send: GpuRef<T>,
    /// Destination buffer for all incoming shards. Dispatch fails with
    /// `GpuError::Unrecoverable` if its underlying allocation has other live
    /// references when the request runs.
    pub recv: GpuRef<T>,
    /// Elements sent to each peer, indexed by peer rank.
    pub send_counts: Vec<usize>,
    /// Start element in `send` of each peer's outgoing shard.
    pub send_offsets: Vec<usize>,
    /// Elements received from each peer, indexed by peer rank.
    pub recv_counts: Vec<usize>,
    /// Start element in `recv` of each peer's incoming shard.
    pub recv_offsets: Vec<usize>,
    /// Receives the dispatch outcome: `Ok(())` once the NCCL group has been
    /// closed successfully, otherwise the error that stopped dispatch.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
}
127
128impl<T: NcclReduceSupported> CollectiveDispatch for AllToAllvRequest<T> {
129    fn dtype_kind(&self) -> DispatchDType {
130        T::dispatch_dtype()
131    }
132
133    fn device_id(&self) -> Option<u32> {
134        self.send.device_id().or_else(|| self.recv.device_id())
135    }
136
137    fn dispatch(self: Box<Self>, ctx: &CollectiveDispatchCtx<'_>) {
138        let AllToAllvRequest {
139            send,
140            recv,
141            send_counts,
142            send_offsets,
143            recv_counts,
144            recv_offsets,
145            reply,
146        } = *self;
147
148        let world_size = ctx.comm.world_size();
149        if send_counts.len() != world_size
150            || send_offsets.len() != world_size
151            || recv_counts.len() != world_size
152            || recv_offsets.len() != world_size
153        {
154            let _ = reply.send(Err(GpuError::Unrecoverable(format!(
155                "AllToAllv: counts/offsets must each have length world_size ({world_size})"
156            ))));
157            return;
158        }
159
160        let send_slice = match send.access() {
161            Ok(s) => s.clone(),
162            Err(e) => {
163                let _ = reply.send(Err(e));
164                return;
165            }
166        };
167        let recv_slice = match recv.access() {
168            Ok(s) => s.clone(),
169            Err(e) => {
170                let _ = reply.send(Err(e));
171                return;
172            }
173        };
174        let mut recv_owned = match Arc::try_unwrap(recv_slice) {
175            Ok(s) => s,
176            Err(_) => {
177                let _ = reply.send(Err(GpuError::Unrecoverable(
178                    "AllToAllv recv buffer has multiple live references".into(),
179                )));
180                return;
181            }
182        };
183
184        let res = (|| -> Result<(), GpuError> {
185            group_start().map_err(|e| GpuError::LibraryError {
186                lib: LIB,
187                msg: format!("group_start: {e:?}"),
188            })?;
189
190            for peer in 0..world_size {
191                let peer_i32 = peer as i32;
192                let s_off = send_offsets[peer];
193                let s_cnt = send_counts[peer];
194                let r_off = recv_offsets[peer];
195                let r_cnt = recv_counts[peer];
196
197                if s_cnt > 0 {
198                    if s_off + s_cnt > send_slice.len() {
199                        return Err(GpuError::Unrecoverable(format!(
200                            "AllToAllv: send shard for peer {peer} overruns buffer"
201                        )));
202                    }
203                    let send_slab = send_slice.slice(s_off..s_off + s_cnt);
204                    ctx.comm
205                        .send(&send_slab, peer_i32)
206                        .map_err(|e| GpuError::LibraryError {
207                            lib: LIB,
208                            msg: format!("a2av send to {peer}: {e:?}"),
209                        })?;
210                }
211                if r_cnt > 0 {
212                    if r_off + r_cnt > recv_owned.len() {
213                        return Err(GpuError::Unrecoverable(format!(
214                            "AllToAllv: recv shard from peer {peer} overruns buffer"
215                        )));
216                    }
217                    let mut recv_slab = recv_owned.slice_mut(r_off..r_off + r_cnt);
218                    ctx.comm.recv(&mut recv_slab, peer_i32).map_err(|e| {
219                        GpuError::LibraryError {
220                            lib: LIB,
221                            msg: format!("a2av recv from {peer}: {e:?}"),
222                        }
223                    })?;
224                }
225            }
226
227            group_end().map_err(|e| GpuError::LibraryError {
228                lib: LIB,
229                msg: format!("group_end: {e:?}"),
230            })?;
231            Ok(())
232        })();
233        let _ = reply.send(res);
234        drop(recv_owned);
235        drop(send_slice);
236    }
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    /// Calls `NcclReduceSupported::dispatch_dtype` once for each listed
    /// element type. `half::f16` and `half::bf16` are included only under the
    /// `f16` feature.
    ///
    /// Despite the name, no NCCL round trip happens here. The test only
    /// guards that the trait impls exist and that `dispatch_dtype` runs
    /// without panicking.
    #[test]
    fn all_to_all_request_round_trip() {
        fn probe<U: NcclReduceSupported>() {
            let _ = <U as NcclReduceSupported>::dispatch_dtype();
        }

        let core_types: [fn(); 8] = [
            probe::<f32>,
            probe::<f64>,
            probe::<i8>,
            probe::<u8>,
            probe::<i32>,
            probe::<u32>,
            probe::<i64>,
            probe::<u64>,
        ];
        for check in core_types.iter() {
            check();
        }

        #[cfg(feature = "f16")]
        {
            let half_types: [fn(); 2] = [probe::<half::f16>, probe::<half::bf16>];
            for check in half_types.iter() {
                check();
            }
        }
    }
}