Skip to main content

atomr_accel_cuda/kernel/collective/
broadcast.rs

1//! Typed Broadcast request — generic over `T: NcclReduceSupported`.
2
3use std::sync::Arc;
4
5use tokio::sync::oneshot;
6
7use super::{NcclReduceSupported, LIB};
8use crate::error::GpuError;
9use crate::gpu_ref::GpuRef;
10use crate::kernel::dispatch::{CollectiveDispatch, CollectiveDispatchCtx, DispatchDType};
11
12/// In-place broadcast: every rank's `data` buffer is overwritten with
13/// the contents of `data` on the rank whose `comm.rank() == root`.
14pub struct BroadcastRequest<T: NcclReduceSupported> {
15    pub data: GpuRef<T>,
16    pub root: usize,
17    pub reply: oneshot::Sender<Result<(), GpuError>>,
18}
19
20impl<T: NcclReduceSupported> CollectiveDispatch for BroadcastRequest<T> {
21    fn dtype_kind(&self) -> DispatchDType {
22        T::dispatch_dtype()
23    }
24
25    fn device_id(&self) -> Option<u32> {
26        self.data.device_id()
27    }
28
29    fn dispatch(self: Box<Self>, ctx: &CollectiveDispatchCtx<'_>) {
30        let BroadcastRequest { data, root, reply } = *self;
31        let slice = match data.access() {
32            Ok(s) => s.clone(),
33            Err(e) => {
34                let _ = reply.send(Err(e));
35                return;
36            }
37        };
38        let mut owned = match Arc::try_unwrap(slice) {
39            Ok(s) => s,
40            Err(_) => {
41                let _ = reply.send(Err(GpuError::Unrecoverable(
42                    "Broadcast data has multiple live references".into(),
43                )));
44                return;
45            }
46        };
47        let root_i32 = match i32::try_from(root) {
48            Ok(r) => r,
49            Err(_) => {
50                let _ = reply.send(Err(GpuError::Unrecoverable(format!(
51                    "Broadcast: root {root} does not fit in i32"
52                ))));
53                return;
54            }
55        };
56        let res = ctx
57            .comm
58            .broadcast_in_place(&mut owned, root_i32)
59            .map_err(|e| GpuError::LibraryError {
60                lib: LIB,
61                msg: format!("broadcast_in_place: {e:?}"),
62            });
63        let _ = reply.send(res.map(|_| ()));
64        drop(owned);
65    }
66}