atomr_accel_cuda/kernel/collective/
broadcast.rs1use std::sync::Arc;
4
5use tokio::sync::oneshot;
6
7use super::{NcclReduceSupported, LIB};
8use crate::error::GpuError;
9use crate::gpu_ref::GpuRef;
10use crate::kernel::dispatch::{CollectiveDispatch, CollectiveDispatchCtx, DispatchDType};
11
12pub struct BroadcastRequest<T: NcclReduceSupported> {
15 pub data: GpuRef<T>,
16 pub root: usize,
17 pub reply: oneshot::Sender<Result<(), GpuError>>,
18}
19
20impl<T: NcclReduceSupported> CollectiveDispatch for BroadcastRequest<T> {
21 fn dtype_kind(&self) -> DispatchDType {
22 T::dispatch_dtype()
23 }
24
25 fn device_id(&self) -> Option<u32> {
26 self.data.device_id()
27 }
28
29 fn dispatch(self: Box<Self>, ctx: &CollectiveDispatchCtx<'_>) {
30 let BroadcastRequest { data, root, reply } = *self;
31 let slice = match data.access() {
32 Ok(s) => s.clone(),
33 Err(e) => {
34 let _ = reply.send(Err(e));
35 return;
36 }
37 };
38 let mut owned = match Arc::try_unwrap(slice) {
39 Ok(s) => s,
40 Err(_) => {
41 let _ = reply.send(Err(GpuError::Unrecoverable(
42 "Broadcast data has multiple live references".into(),
43 )));
44 return;
45 }
46 };
47 let root_i32 = match i32::try_from(root) {
48 Ok(r) => r,
49 Err(_) => {
50 let _ = reply.send(Err(GpuError::Unrecoverable(format!(
51 "Broadcast: root {root} does not fit in i32"
52 ))));
53 return;
54 }
55 };
56 let res = ctx
57 .comm
58 .broadcast_in_place(&mut owned, root_i32)
59 .map_err(|e| GpuError::LibraryError {
60 lib: LIB,
61 msg: format!("broadcast_in_place: {e:?}"),
62 });
63 let _ = reply.send(res.map(|_| ()));
64 drop(owned);
65 }
66}