atomr_accel_cuda/kernel/collective/
reduce_scatter.rs1use std::sync::Arc;
4
5use cudarc::nccl::ReduceOp;
6use tokio::sync::oneshot;
7
8use super::{NcclReduceSupported, LIB};
9use crate::error::GpuError;
10use crate::gpu_ref::GpuRef;
11use crate::kernel::dispatch::{CollectiveDispatch, CollectiveDispatchCtx, DispatchDType};
12
13pub struct ReduceScatterRequest<T: NcclReduceSupported> {
17 pub send: GpuRef<T>,
18 pub recv: GpuRef<T>,
19 pub op: ReduceOp,
20 pub reply: oneshot::Sender<Result<(), GpuError>>,
21}
22
23impl<T: NcclReduceSupported> CollectiveDispatch for ReduceScatterRequest<T> {
24 fn dtype_kind(&self) -> DispatchDType {
25 T::dispatch_dtype()
26 }
27
28 fn device_id(&self) -> Option<u32> {
29 self.send.device_id().or_else(|| self.recv.device_id())
30 }
31
32 fn dispatch(self: Box<Self>, ctx: &CollectiveDispatchCtx<'_>) {
33 let ReduceScatterRequest {
34 send,
35 recv,
36 op,
37 reply,
38 } = *self;
39 let send_slice = match send.access() {
40 Ok(s) => s.clone(),
41 Err(e) => {
42 let _ = reply.send(Err(e));
43 return;
44 }
45 };
46 let recv_slice = match recv.access() {
47 Ok(s) => s.clone(),
48 Err(e) => {
49 let _ = reply.send(Err(e));
50 return;
51 }
52 };
53 let mut recv_owned = match Arc::try_unwrap(recv_slice) {
54 Ok(s) => s,
55 Err(_) => {
56 let _ = reply.send(Err(GpuError::Unrecoverable(
57 "ReduceScatter recv buffer has multiple live references".into(),
58 )));
59 return;
60 }
61 };
62 let res = ctx
63 .comm
64 .reduce_scatter(&*send_slice, &mut recv_owned, &op)
65 .map_err(|e| GpuError::LibraryError {
66 lib: LIB,
67 msg: format!("reduce_scatter: {e:?}"),
68 });
69 let _ = reply.send(res.map(|_| ()));
70 drop(recv_owned);
71 drop(send_slice);
72 }
73}
74
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles only when `T` implements `NcclReduceSupported`; also calls
    /// `dispatch_dtype` once and discards the result.
    fn touch_dtype<T: NcclReduceSupported>() {
        let _ = T::dispatch_dtype();
    }

    /// Expands to one `touch_dtype::<Ty>()` call per listed type.
    macro_rules! touch_dtypes {
        ($($ty:ty),+ $(,)?) => {
            $( touch_dtype::<$ty>(); )+
        };
    }

    /// Checks that every element type expected to work with
    /// `ReduceScatterRequest` satisfies the `NcclReduceSupported` bound.
    #[test]
    fn reduce_scatter_request_round_trip() {
        touch_dtypes!(f32, f64, i8, u8, i32, u32, i64, u64);
        #[cfg(feature = "f16")]
        {
            touch_dtypes!(half::f16, half::bf16);
        }
    }
}