atomr_accel_cuda/kernel/collective/
p2p.rs1use std::sync::Arc;
4
5use tokio::sync::oneshot;
6
7use super::{NcclReduceSupported, LIB};
8use crate::error::GpuError;
9use crate::gpu_ref::GpuRef;
10use crate::kernel::dispatch::{CollectiveDispatch, CollectiveDispatchCtx, DispatchDType};
11
12pub struct SendRequest<T: NcclReduceSupported> {
15 pub data: GpuRef<T>,
16 pub peer: i32,
17 pub reply: oneshot::Sender<Result<(), GpuError>>,
18}
19
20impl<T: NcclReduceSupported> CollectiveDispatch for SendRequest<T> {
21 fn dtype_kind(&self) -> DispatchDType {
22 T::dispatch_dtype()
23 }
24
25 fn device_id(&self) -> Option<u32> {
26 self.data.device_id()
27 }
28
29 fn dispatch(self: Box<Self>, ctx: &CollectiveDispatchCtx<'_>) {
30 let SendRequest { data, peer, reply } = *self;
31 let slice = match data.access() {
32 Ok(s) => s.clone(),
33 Err(e) => {
34 let _ = reply.send(Err(e));
35 return;
36 }
37 };
38 let res = ctx
39 .comm
40 .send(&*slice, peer)
41 .map_err(|e| GpuError::LibraryError {
42 lib: LIB,
43 msg: format!("send: {e:?}"),
44 });
45 let _ = reply.send(res);
46 drop(slice);
47 }
48}
49
50pub struct RecvRequest<T: NcclReduceSupported> {
52 pub data: GpuRef<T>,
53 pub peer: i32,
54 pub reply: oneshot::Sender<Result<(), GpuError>>,
55}
56
57impl<T: NcclReduceSupported> CollectiveDispatch for RecvRequest<T> {
58 fn dtype_kind(&self) -> DispatchDType {
59 T::dispatch_dtype()
60 }
61
62 fn device_id(&self) -> Option<u32> {
63 self.data.device_id()
64 }
65
66 fn dispatch(self: Box<Self>, ctx: &CollectiveDispatchCtx<'_>) {
67 let RecvRequest { data, peer, reply } = *self;
68 let slice = match data.access() {
69 Ok(s) => s.clone(),
70 Err(e) => {
71 let _ = reply.send(Err(e));
72 return;
73 }
74 };
75 let mut owned = match Arc::try_unwrap(slice) {
76 Ok(s) => s,
77 Err(_) => {
78 let _ = reply.send(Err(GpuError::Unrecoverable(
79 "Recv buffer has multiple live references".into(),
80 )));
81 return;
82 }
83 };
84 let res = ctx
85 .comm
86 .recv(&mut owned, peer)
87 .map_err(|e| GpuError::LibraryError {
88 lib: LIB,
89 msg: format!("recv: {e:?}"),
90 });
91 let _ = reply.send(res.map(|_| ()));
92 drop(owned);
93 }
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    /// Evaluates `dispatch_dtype` for `T` and discards the value; the call only
    /// compiles when `T` implements `NcclReduceSupported`.
    fn probe_dtype<T: NcclReduceSupported>() {
        let _ = T::dispatch_dtype();
    }

    /// Runs `probe_dtype` once for every listed element type.
    macro_rules! probe_all {
        ($($elem:ty),+ $(,)?) => {
            $( probe_dtype::<$elem>(); )+
        };
    }

    #[test]
    fn send_recv_request_round_trip() {
        probe_all!(f32, f64, i8, u8, i32, u32, i64, u64);

        // Half-precision element types are only checked when the feature is on.
        #[cfg(feature = "f16")]
        {
            probe_all!(half::f16, half::bf16);
        }
    }
}