atomr_accel_cuda/kernel/collective/
all_to_all.rs1use std::sync::Arc;
12
13use cudarc::nccl::{group_end, group_start};
14use tokio::sync::oneshot;
15
16use super::{NcclReduceSupported, LIB};
17use crate::error::GpuError;
18use crate::gpu_ref::GpuRef;
19use crate::kernel::dispatch::{CollectiveDispatch, CollectiveDispatchCtx, DispatchDType};
20
21pub struct AllToAllRequest<T: NcclReduceSupported> {
25 pub send: GpuRef<T>,
26 pub recv: GpuRef<T>,
27 pub count: usize,
28 pub reply: oneshot::Sender<Result<(), GpuError>>,
29}
30
31impl<T: NcclReduceSupported> CollectiveDispatch for AllToAllRequest<T> {
32 fn dtype_kind(&self) -> DispatchDType {
33 T::dispatch_dtype()
34 }
35
36 fn device_id(&self) -> Option<u32> {
37 self.send.device_id().or_else(|| self.recv.device_id())
38 }
39
40 fn dispatch(self: Box<Self>, ctx: &CollectiveDispatchCtx<'_>) {
41 let AllToAllRequest {
42 send,
43 recv,
44 count,
45 reply,
46 } = *self;
47
48 let send_slice = match send.access() {
49 Ok(s) => s.clone(),
50 Err(e) => {
51 let _ = reply.send(Err(e));
52 return;
53 }
54 };
55 let recv_slice = match recv.access() {
56 Ok(s) => s.clone(),
57 Err(e) => {
58 let _ = reply.send(Err(e));
59 return;
60 }
61 };
62 let mut recv_owned = match Arc::try_unwrap(recv_slice) {
63 Ok(s) => s,
64 Err(_) => {
65 let _ = reply.send(Err(GpuError::Unrecoverable(
66 "AllToAll recv buffer has multiple live references".into(),
67 )));
68 return;
69 }
70 };
71
72 let world_size = ctx.comm.world_size();
73 if send_slice.len() < world_size * count || recv_owned.len() < world_size * count {
74 let _ = reply.send(Err(GpuError::Unrecoverable(format!(
75 "AllToAll: buffer length < world_size ({world_size}) * count ({count})"
76 ))));
77 return;
78 }
79
80 let res = (|| -> Result<(), GpuError> {
82 group_start().map_err(|e| GpuError::LibraryError {
83 lib: LIB,
84 msg: format!("group_start: {e:?}"),
85 })?;
86
87 for peer in 0..world_size {
88 let peer_i32 = peer as i32;
89 let send_slab = send_slice.slice(peer * count..(peer + 1) * count);
90 let mut recv_slab = recv_owned.slice_mut(peer * count..(peer + 1) * count);
91 ctx.comm
92 .send(&send_slab, peer_i32)
93 .map_err(|e| GpuError::LibraryError {
94 lib: LIB,
95 msg: format!("a2a send to {peer}: {e:?}"),
96 })?;
97 ctx.comm
98 .recv(&mut recv_slab, peer_i32)
99 .map_err(|e| GpuError::LibraryError {
100 lib: LIB,
101 msg: format!("a2a recv from {peer}: {e:?}"),
102 })?;
103 }
104
105 group_end().map_err(|e| GpuError::LibraryError {
106 lib: LIB,
107 msg: format!("group_end: {e:?}"),
108 })?;
109 Ok(())
110 })();
111 let _ = reply.send(res);
112 drop(recv_owned);
113 drop(send_slice);
114 }
115}
116
117pub struct AllToAllvRequest<T: NcclReduceSupported> {
119 pub send: GpuRef<T>,
120 pub recv: GpuRef<T>,
121 pub send_counts: Vec<usize>,
122 pub send_offsets: Vec<usize>,
123 pub recv_counts: Vec<usize>,
124 pub recv_offsets: Vec<usize>,
125 pub reply: oneshot::Sender<Result<(), GpuError>>,
126}
127
128impl<T: NcclReduceSupported> CollectiveDispatch for AllToAllvRequest<T> {
129 fn dtype_kind(&self) -> DispatchDType {
130 T::dispatch_dtype()
131 }
132
133 fn device_id(&self) -> Option<u32> {
134 self.send.device_id().or_else(|| self.recv.device_id())
135 }
136
137 fn dispatch(self: Box<Self>, ctx: &CollectiveDispatchCtx<'_>) {
138 let AllToAllvRequest {
139 send,
140 recv,
141 send_counts,
142 send_offsets,
143 recv_counts,
144 recv_offsets,
145 reply,
146 } = *self;
147
148 let world_size = ctx.comm.world_size();
149 if send_counts.len() != world_size
150 || send_offsets.len() != world_size
151 || recv_counts.len() != world_size
152 || recv_offsets.len() != world_size
153 {
154 let _ = reply.send(Err(GpuError::Unrecoverable(format!(
155 "AllToAllv: counts/offsets must each have length world_size ({world_size})"
156 ))));
157 return;
158 }
159
160 let send_slice = match send.access() {
161 Ok(s) => s.clone(),
162 Err(e) => {
163 let _ = reply.send(Err(e));
164 return;
165 }
166 };
167 let recv_slice = match recv.access() {
168 Ok(s) => s.clone(),
169 Err(e) => {
170 let _ = reply.send(Err(e));
171 return;
172 }
173 };
174 let mut recv_owned = match Arc::try_unwrap(recv_slice) {
175 Ok(s) => s,
176 Err(_) => {
177 let _ = reply.send(Err(GpuError::Unrecoverable(
178 "AllToAllv recv buffer has multiple live references".into(),
179 )));
180 return;
181 }
182 };
183
184 let res = (|| -> Result<(), GpuError> {
185 group_start().map_err(|e| GpuError::LibraryError {
186 lib: LIB,
187 msg: format!("group_start: {e:?}"),
188 })?;
189
190 for peer in 0..world_size {
191 let peer_i32 = peer as i32;
192 let s_off = send_offsets[peer];
193 let s_cnt = send_counts[peer];
194 let r_off = recv_offsets[peer];
195 let r_cnt = recv_counts[peer];
196
197 if s_cnt > 0 {
198 if s_off + s_cnt > send_slice.len() {
199 return Err(GpuError::Unrecoverable(format!(
200 "AllToAllv: send shard for peer {peer} overruns buffer"
201 )));
202 }
203 let send_slab = send_slice.slice(s_off..s_off + s_cnt);
204 ctx.comm
205 .send(&send_slab, peer_i32)
206 .map_err(|e| GpuError::LibraryError {
207 lib: LIB,
208 msg: format!("a2av send to {peer}: {e:?}"),
209 })?;
210 }
211 if r_cnt > 0 {
212 if r_off + r_cnt > recv_owned.len() {
213 return Err(GpuError::Unrecoverable(format!(
214 "AllToAllv: recv shard from peer {peer} overruns buffer"
215 )));
216 }
217 let mut recv_slab = recv_owned.slice_mut(r_off..r_off + r_cnt);
218 ctx.comm.recv(&mut recv_slab, peer_i32).map_err(|e| {
219 GpuError::LibraryError {
220 lib: LIB,
221 msg: format!("a2av recv from {peer}: {e:?}"),
222 }
223 })?;
224 }
225 }
226
227 group_end().map_err(|e| GpuError::LibraryError {
228 lib: LIB,
229 msg: format!("group_end: {e:?}"),
230 })?;
231 Ok(())
232 })();
233 let _ = reply.send(res);
234 drop(recv_owned);
235 drop(send_slice);
236 }
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that each listed element type implements `NcclReduceSupported`
    /// and can report its dispatch dtype. The `half` types are included only
    /// with the `f16` feature.
    #[test]
    fn all_to_all_request_round_trip() {
        fn probe<T: NcclReduceSupported>() {
            let _ = <T as NcclReduceSupported>::dispatch_dtype();
        }

        // Calls `probe` once per listed type, in order.
        macro_rules! probe_all {
            ($($ty:ty),+ $(,)?) => {
                $(probe::<$ty>();)+
            };
        }

        probe_all!(f32, f64, i8, u8, i32, u32, i64, u64);

        #[cfg(feature = "f16")]
        {
            probe_all!(half::f16, half::bf16);
        }
    }
}