Skip to main content

atomr_accel_cuda/kernel/collective/
reduce.rs

1//! Typed Reduce request: reduce-to-root variant of AllReduce.
2
3use std::sync::Arc;
4
5use cudarc::nccl::ReduceOp;
6use tokio::sync::oneshot;
7
8use super::{NcclReduceSupported, LIB};
9use crate::error::GpuError;
10use crate::gpu_ref::GpuRef;
11use crate::kernel::dispatch::{CollectiveDispatch, CollectiveDispatchCtx, DispatchDType};
12
13/// Reduce: each rank contributes `send`; the result lands in `recv`
14/// on `root`. On non-root ranks, `recv` may be `None`.
15pub struct ReduceRequest<T: NcclReduceSupported> {
16    pub send: GpuRef<T>,
17    pub recv: Option<GpuRef<T>>,
18    pub op: ReduceOp,
19    pub root: i32,
20    pub reply: oneshot::Sender<Result<(), GpuError>>,
21}
22
23impl<T: NcclReduceSupported> CollectiveDispatch for ReduceRequest<T> {
24    fn dtype_kind(&self) -> DispatchDType {
25        T::dispatch_dtype()
26    }
27
28    fn device_id(&self) -> Option<u32> {
29        self.send
30            .device_id()
31            .or_else(|| self.recv.as_ref().and_then(|r| r.device_id()))
32    }
33
34    fn dispatch(self: Box<Self>, ctx: &CollectiveDispatchCtx<'_>) {
35        let ReduceRequest {
36            send,
37            recv,
38            op,
39            root,
40            reply,
41        } = *self;
42        let send_slice = match send.access() {
43            Ok(s) => s.clone(),
44            Err(e) => {
45                let _ = reply.send(Err(e));
46                return;
47            }
48        };
49        // On the root rank we must have a recv buffer; on other
50        // ranks recv is allowed to be None.
51        let recv_owned: Option<cudarc::driver::CudaSlice<T>> = match recv {
52            Some(r) => match r.access() {
53                Ok(s) => match Arc::try_unwrap(s.clone()) {
54                    Ok(o) => Some(o),
55                    Err(_) => {
56                        let _ = reply.send(Err(GpuError::Unrecoverable(
57                            "Reduce recv buffer has multiple live references".into(),
58                        )));
59                        return;
60                    }
61                },
62                Err(e) => {
63                    let _ = reply.send(Err(e));
64                    return;
65                }
66            },
67            None => None,
68        };
69
70        let res = match recv_owned {
71            Some(mut owned) => ctx
72                .comm
73                .reduce(&*send_slice, Some(&mut owned), &op, root)
74                .map_err(|e| GpuError::LibraryError {
75                    lib: LIB,
76                    msg: format!("reduce: {e:?}"),
77                })
78                .map(|_| {
79                    drop(owned);
80                }),
81            None => ctx
82                .comm
83                .reduce::<_, cudarc::driver::CudaSlice<T>, T>(&*send_slice, None, &op, root)
84                .map_err(|e| GpuError::LibraryError {
85                    lib: LIB,
86                    msg: format!("reduce: {e:?}"),
87                })
88                .map(|_| ()),
89        };
90        let _ = reply.send(res);
91        drop(send_slice);
92    }
93}