atomr_accel_cuda/kernel/collective/
reduce.rs1use std::sync::Arc;
4
5use cudarc::nccl::ReduceOp;
6use tokio::sync::oneshot;
7
8use super::{NcclReduceSupported, LIB};
9use crate::error::GpuError;
10use crate::gpu_ref::GpuRef;
11use crate::kernel::dispatch::{CollectiveDispatch, CollectiveDispatchCtx, DispatchDType};
12
13pub struct ReduceRequest<T: NcclReduceSupported> {
16 pub send: GpuRef<T>,
17 pub recv: Option<GpuRef<T>>,
18 pub op: ReduceOp,
19 pub root: i32,
20 pub reply: oneshot::Sender<Result<(), GpuError>>,
21}
22
23impl<T: NcclReduceSupported> CollectiveDispatch for ReduceRequest<T> {
24 fn dtype_kind(&self) -> DispatchDType {
25 T::dispatch_dtype()
26 }
27
28 fn device_id(&self) -> Option<u32> {
29 self.send
30 .device_id()
31 .or_else(|| self.recv.as_ref().and_then(|r| r.device_id()))
32 }
33
34 fn dispatch(self: Box<Self>, ctx: &CollectiveDispatchCtx<'_>) {
35 let ReduceRequest {
36 send,
37 recv,
38 op,
39 root,
40 reply,
41 } = *self;
42 let send_slice = match send.access() {
43 Ok(s) => s.clone(),
44 Err(e) => {
45 let _ = reply.send(Err(e));
46 return;
47 }
48 };
49 let recv_owned: Option<cudarc::driver::CudaSlice<T>> = match recv {
52 Some(r) => match r.access() {
53 Ok(s) => match Arc::try_unwrap(s.clone()) {
54 Ok(o) => Some(o),
55 Err(_) => {
56 let _ = reply.send(Err(GpuError::Unrecoverable(
57 "Reduce recv buffer has multiple live references".into(),
58 )));
59 return;
60 }
61 },
62 Err(e) => {
63 let _ = reply.send(Err(e));
64 return;
65 }
66 },
67 None => None,
68 };
69
70 let res = match recv_owned {
71 Some(mut owned) => ctx
72 .comm
73 .reduce(&*send_slice, Some(&mut owned), &op, root)
74 .map_err(|e| GpuError::LibraryError {
75 lib: LIB,
76 msg: format!("reduce: {e:?}"),
77 })
78 .map(|_| {
79 drop(owned);
80 }),
81 None => ctx
82 .comm
83 .reduce::<_, cudarc::driver::CudaSlice<T>, T>(&*send_slice, None, &op, root)
84 .map_err(|e| GpuError::LibraryError {
85 lib: LIB,
86 msg: format!("reduce: {e:?}"),
87 })
88 .map(|_| ()),
89 };
90 let _ = reply.send(res);
91 drop(send_slice);
92 }
93}