Skip to main content

atomr_accel_cuda/kernel/collective/
p2p.rs

1//! Typed point-to-point Send / Recv requests.
2
3use std::sync::Arc;
4
5use tokio::sync::oneshot;
6
7use super::{NcclReduceSupported, LIB};
8use crate::error::GpuError;
9use crate::gpu_ref::GpuRef;
10use crate::kernel::dispatch::{CollectiveDispatch, CollectiveDispatchCtx, DispatchDType};
11
12/// Send `data` to rank `peer`. Per NCCL, must be paired with a
13/// matching `RecvRequest` on the peer inside the same group call.
14pub struct SendRequest<T: NcclReduceSupported> {
15    pub data: GpuRef<T>,
16    pub peer: i32,
17    pub reply: oneshot::Sender<Result<(), GpuError>>,
18}
19
20impl<T: NcclReduceSupported> CollectiveDispatch for SendRequest<T> {
21    fn dtype_kind(&self) -> DispatchDType {
22        T::dispatch_dtype()
23    }
24
25    fn device_id(&self) -> Option<u32> {
26        self.data.device_id()
27    }
28
29    fn dispatch(self: Box<Self>, ctx: &CollectiveDispatchCtx<'_>) {
30        let SendRequest { data, peer, reply } = *self;
31        let slice = match data.access() {
32            Ok(s) => s.clone(),
33            Err(e) => {
34                let _ = reply.send(Err(e));
35                return;
36            }
37        };
38        let res = ctx
39            .comm
40            .send(&*slice, peer)
41            .map_err(|e| GpuError::LibraryError {
42                lib: LIB,
43                msg: format!("send: {e:?}"),
44            });
45        let _ = reply.send(res);
46        drop(slice);
47    }
48}
49
50/// Receive a buffer from rank `peer` into `data`.
51pub struct RecvRequest<T: NcclReduceSupported> {
52    pub data: GpuRef<T>,
53    pub peer: i32,
54    pub reply: oneshot::Sender<Result<(), GpuError>>,
55}
56
57impl<T: NcclReduceSupported> CollectiveDispatch for RecvRequest<T> {
58    fn dtype_kind(&self) -> DispatchDType {
59        T::dispatch_dtype()
60    }
61
62    fn device_id(&self) -> Option<u32> {
63        self.data.device_id()
64    }
65
66    fn dispatch(self: Box<Self>, ctx: &CollectiveDispatchCtx<'_>) {
67        let RecvRequest { data, peer, reply } = *self;
68        let slice = match data.access() {
69            Ok(s) => s.clone(),
70            Err(e) => {
71                let _ = reply.send(Err(e));
72                return;
73            }
74        };
75        let mut owned = match Arc::try_unwrap(slice) {
76            Ok(s) => s,
77            Err(_) => {
78                let _ = reply.send(Err(GpuError::Unrecoverable(
79                    "Recv buffer has multiple live references".into(),
80                )));
81                return;
82            }
83        };
84        let res = ctx
85            .comm
86            .recv(&mut owned, peer)
87            .map_err(|e| GpuError::LibraryError {
88                lib: LIB,
89                msg: format!("recv: {e:?}"),
90            });
91        let _ = reply.send(res.map(|_| ()));
92        drop(owned);
93    }
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that both `SendRequest<T>` and `RecvRequest<T>`
    /// implement `CollectiveDispatch` for the full reduce-supported dtype set.
    /// It also checks that each dtype exposes a dispatch tag.
    #[test]
    fn send_recv_request_round_trip() {
        // Only compiles when `D: CollectiveDispatch`; calling it with a
        // concrete type turns the trait-impl claim into a build-time check.
        fn assert_collective_dispatch<D: CollectiveDispatch>() {}

        fn assert_send_recv<T: NcclReduceSupported>() {
            assert_collective_dispatch::<SendRequest<T>>();
            assert_collective_dispatch::<RecvRequest<T>>();
            let _ = <T as NcclReduceSupported>::dispatch_dtype();
        }

        assert_send_recv::<f32>();
        assert_send_recv::<f64>();
        assert_send_recv::<i8>();
        assert_send_recv::<u8>();
        assert_send_recv::<i32>();
        assert_send_recv::<u32>();
        assert_send_recv::<i64>();
        assert_send_recv::<u64>();
        #[cfg(feature = "f16")]
        {
            assert_send_recv::<half::f16>();
            assert_send_recv::<half::bf16>();
        }
    }
}