Skip to main content

atomr_accel_cuda/multi_device/
world_actor.rs

//! `NcclWorldActor` — supervises a multi-GPU NCCL world.
//!
//! Spawns N `DeviceActor`s (one per device id in the world config),
//! waits for each to report `ContextReady`, snapshots their CUDA
//! contexts to mint per-rank streams, calls
//! `Comm::from_devices(streams)` to build the NCCL group, and spawns
//! one `CollectiveActor` per rank.
//!
//! Routing model: when the world receives an `AllReduceF32` with
//! `Vec<GpuRef<f32>>`, it cross-validates each `GpuRef`'s device-id
//! and dispatches to the matching `CollectiveActor`. Replies arrive
//! via per-rank oneshots; the world joins them and reports a single
//! result.
14
15use std::sync::Arc;
16
17use async_trait::async_trait;
18use atomr_core::actor::{Actor, ActorRef, Context, Props};
19use cudarc::nccl::Comm;
20use tokio::sync::oneshot;
21use tracing::{info, warn};
22
23use crate::completion::{CompletionStrategy, HostFnCompletion};
24use crate::device::{DeviceActor, DeviceConfig, DeviceMsg};
25use crate::error::GpuError;
26use crate::gpu_ref::GpuRef;
27use crate::kernel::{AllReduceRequest, CollectiveActor, CollectiveMsg, ReduceOp};
28
/// Static description of the NCCL world an [`NcclWorldActor`] manages.
#[derive(Debug, Clone)]
pub struct NcclWorldConfig {
    /// Device ids taking part in the world. The position of an id in this
    /// list is the index used for all per-device state kept by the actor.
    pub device_ids: Vec<u32>,
    /// Root index. [`NcclWorldConfig::new`] initialises it to 0.
    pub root: usize,
}

impl NcclWorldConfig {
    /// Creates a configuration for `device_ids` with `root` set to 0.
    pub fn new(device_ids: Vec<u32>) -> Self {
        let root = 0;
        Self { device_ids, root }
    }
}
43
/// Messages handled by `NcclWorldActor`.
///
/// `AllReduceF32` is the request callers send; `ChildReady` and
/// `DeviceContextChanged` are posted by the actor's own background tasks.
pub enum NcclWorldMsg {
    /// Run an all-reduce over `tensors`.
    ///
    /// The dispatcher expects exactly one tensor per configured device id,
    /// in the same order as the config's `device_ids`; a tensor that
    /// reports a device id must match the id at its position. A single
    /// combined result is sent on `reply`.
    AllReduceF32 {
        tensors: Vec<GpuRef<f32>>,
        op: ReduceOp,
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
    /// Internal: child reports it's ready.
    ChildReady {
        // Index into the config's `device_ids` (not the CUDA ordinal itself).
        device_idx: usize,
        device_ref: ActorRef<DeviceMsg>,
    },
    /// Internal: a per-device generation watch fired, meaning a
    /// `ContextActor` rebuilt its CUDA context and the existing NCCL
    /// communicators are now invalid. The world tears the
    /// collectives down and rebuilds.
    DeviceContextChanged {
        // Index into the config's `device_ids` (not the CUDA ordinal itself).
        device_idx: usize,
        // Generation value read from the device's watch channel.
        new_generation: u64,
    },
}
64
/// Actor state for one NCCL world.
///
/// `devices`, `collectives` and `last_generation` are all indexed by the
/// position of a device id in `config.device_ids`.
pub struct NcclWorldActor {
    /// World configuration this actor was created with.
    config: NcclWorldConfig,
    /// One slot per configured device; `Some` once the `DeviceActor` for
    /// that slot has been spawned.
    devices: Vec<Option<ActorRef<DeviceMsg>>>,
    /// One slot per rank; filled by `try_build_world` and emptied again
    /// when the world is torn down.
    collectives: Vec<Option<ActorRef<CollectiveMsg>>>,
    /// Set true once `try_build_world` has run successfully.
    built: bool,
    /// Per-device generation last seen. When a device reports a
    /// generation change above this value, the world is rebuilt.
    last_generation: Vec<u64>,
    // Stored at construction but not read anywhere in the visible part of
    // this file, hence the lint allowance.
    #[allow(dead_code)]
    completion: Arc<dyn CompletionStrategy>,
}
77
78impl NcclWorldActor {
79    pub fn props(config: NcclWorldConfig) -> Props<Self> {
80        Props::create(move || {
81            let n = config.device_ids.len();
82            NcclWorldActor {
83                config: config.clone(),
84                devices: (0..n).map(|_| None).collect(),
85                collectives: (0..n).map(|_| None).collect(),
86                built: false,
87                last_generation: vec![0; n],
88                completion: Arc::new(HostFnCompletion::new()),
89            }
90        })
91    }
92
93    async fn try_build_world(&mut self, ctx: &mut Context<Self>) {
94        if self.built {
95            return;
96        }
97        if self.devices.iter().any(|d| d.is_none()) {
98            return;
99        }
100
101        // Snapshot each device's CudaContext.
102        let mut snaps = Vec::with_capacity(self.devices.len());
103        for d in &self.devices {
104            let dref = d.as_ref().unwrap();
105            let (tx, rx) = oneshot::channel();
106            dref.tell(DeviceMsg::SnapshotContext { reply: tx });
107            match rx.await {
108                Ok(Some(c)) => snaps.push(c),
109                _ => {
110                    warn!("NcclWorldActor: a device reported no context; aborting world-build");
111                    return;
112                }
113            }
114        }
115
116        // Mint a fresh stream per device for the comm.
117        let mut streams = Vec::with_capacity(snaps.len());
118        for c in &snaps {
119            match c.new_stream() {
120                Ok(s) => streams.push(s),
121                Err(e) => {
122                    warn!(error = %e, "NcclWorldActor: new_stream failed");
123                    return;
124                }
125            }
126        }
127
128        // Build the NCCL world. This can panic on no-driver hosts;
129        // catch_unwind preserves the actor.
130        let comms_res = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
131            Comm::from_devices(streams.clone())
132        }));
133        let comms = match comms_res {
134            Ok(Ok(cs)) => cs,
135            Ok(Err(e)) => {
136                warn!(error = ?e, "NcclWorldActor: Comm::from_devices failed");
137                return;
138            }
139            Err(_) => {
140                warn!("NcclWorldActor: NCCL not loadable on this host");
141                return;
142            }
143        };
144
145        // Spawn one CollectiveActor per rank as a child of this
146        // world. Each takes its `Comm` by move.
147        for (i, comm) in comms.into_iter().enumerate() {
148            // We need the per-device DeviceState to drive the
149            // CollectiveActor's cross-device validation. We don't
150            // have a public `state` accessor on DeviceActor; for
151            // F4.x we pass a fresh DeviceState per rank — the
152            // device-id matches the request's device-id which is
153            // what cross-validation needs.
154            let state = Arc::new(crate::device::DeviceState::new(self.config.device_ids[i]));
155            let comp: Arc<dyn CompletionStrategy> = Arc::new(HostFnCompletion::new());
156            let props = CollectiveActor::props_for_rank(comm, state, comp);
157            match ctx.spawn::<CollectiveActor>(props, &format!("nccl-{i}")) {
158                Ok(r) => self.collectives[i] = Some(r),
159                Err(e) => {
160                    warn!(error = %e, "spawn CollectiveActor[{i}] failed");
161                    return;
162                }
163            }
164        }
165        self.built = true;
166        info!(devices = self.devices.len(), "NcclWorldActor: world built");
167    }
168
169    fn dispatch_all_reduce_f32(
170        &self,
171        tensors: Vec<GpuRef<f32>>,
172        op: ReduceOp,
173        reply: oneshot::Sender<Result<(), GpuError>>,
174    ) {
175        if !self.built {
176            let _ = reply.send(Err(GpuError::Unrecoverable(
177                "NcclWorldActor: world not built yet".into(),
178            )));
179            return;
180        }
181        if tensors.len() != self.config.device_ids.len() {
182            let _ = reply.send(Err(GpuError::Unrecoverable(format!(
183                "AllReduce: expected {} tensors, got {}",
184                self.config.device_ids.len(),
185                tensors.len()
186            ))));
187            return;
188        }
189        for (i, t) in tensors.iter().enumerate() {
190            if let Some(d) = t.device_id() {
191                if d != self.config.device_ids[i] {
192                    let _ = reply.send(Err(GpuError::Unrecoverable(format!(
193                        "AllReduce: tensor[{i}] on device {d}, expected {}",
194                        self.config.device_ids[i]
195                    ))));
196                    return;
197                }
198            }
199        }
200        // Per-rank dispatch: send AllReduceF32 to each collective and
201        // await all replies, then post a single combined reply on
202        // the user's channel. NCCL requires
203        // group_start/group_end framing only when issuing multiple
204        // ops in a row; a single AllReduce per rank is fine
205        // standalone.
206        let collectives: Vec<_> = self
207            .collectives
208            .iter()
209            .map(|c| c.as_ref().unwrap().clone())
210            .collect();
211        tokio::spawn(async move {
212            let mut rxs = Vec::with_capacity(tensors.len());
213            for (c, t) in collectives.into_iter().zip(tensors) {
214                let (tx, rx) = oneshot::channel();
215                let op_clone = match op {
216                    ReduceOp::Sum => ReduceOp::Sum,
217                    ReduceOp::Prod => ReduceOp::Prod,
218                    ReduceOp::Max => ReduceOp::Max,
219                    ReduceOp::Min => ReduceOp::Min,
220                    ReduceOp::Avg => ReduceOp::Avg,
221                };
222                c.tell(CollectiveMsg::Op(Box::new(AllReduceRequest::<f32> {
223                    tensor: t,
224                    op: op_clone,
225                    reply: tx,
226                })));
227                rxs.push(rx);
228            }
229            let mut combined = Ok(());
230            for rx in rxs {
231                match rx.await {
232                    Ok(Ok(())) => {}
233                    Ok(Err(e)) => {
234                        combined = Err(e);
235                        break;
236                    }
237                    Err(_) => {
238                        combined = Err(GpuError::Unrecoverable(
239                            "AllReduce: a collective actor dropped its reply".into(),
240                        ));
241                        break;
242                    }
243                }
244            }
245            let _ = reply.send(combined);
246        });
247    }
248}
249
250#[async_trait]
251impl Actor for NcclWorldActor {
252    type Msg = NcclWorldMsg;
253
254    async fn pre_start(&mut self, ctx: &mut Context<Self>) {
255        let world_ref = ctx.self_ref().clone();
256        for (i, &ord) in self.config.device_ids.iter().enumerate() {
257            let cfg = DeviceConfig::new(ord);
258            match ctx.spawn::<DeviceActor>(DeviceActor::props(cfg), &format!("dev-{i}")) {
259                Ok(r) => {
260                    self.devices[i] = Some(r.clone());
261                    let world = world_ref.clone();
262                    let dr = r.clone();
263                    tokio::spawn(async move {
264                        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
265                        world.tell(NcclWorldMsg::ChildReady {
266                            device_idx: i,
267                            device_ref: dr,
268                        });
269                    });
270                }
271                Err(e) => panic!("Unrecoverable: spawn DeviceActor[{i}]: {e}"),
272            }
273        }
274    }
275
276    async fn handle(&mut self, ctx: &mut Context<Self>, msg: NcclWorldMsg) {
277        match msg {
278            NcclWorldMsg::ChildReady {
279                device_idx,
280                device_ref,
281            } => {
282                self.devices[device_idx] = Some(device_ref.clone());
283
284                // Subscribe to this device's generation watch and
285                // bridge changes into `DeviceContextChanged` events
286                // on our own mailbox.
287                let world_ref = ctx.self_ref().clone();
288                let dr = device_ref.clone();
289                tokio::spawn(async move {
290                    let watch_rx_res = dr
291                        .ask_with(
292                            move |tx| DeviceMsg::WatchGeneration { reply: tx },
293                            std::time::Duration::from_secs(5),
294                        )
295                        .await;
296                    let mut rx = match watch_rx_res {
297                        Ok(rx) => rx,
298                        Err(_) => return,
299                    };
300                    let mut last = *rx.borrow();
301                    while rx.changed().await.is_ok() {
302                        let gen = *rx.borrow();
303                        if gen != last {
304                            last = gen;
305                            world_ref.tell(NcclWorldMsg::DeviceContextChanged {
306                                device_idx,
307                                new_generation: gen,
308                            });
309                        }
310                    }
311                });
312
313                self.try_build_world(ctx).await;
314            }
315            NcclWorldMsg::DeviceContextChanged {
316                device_idx,
317                new_generation,
318            } => {
319                let prev = self.last_generation.get(device_idx).copied().unwrap_or(0);
320                if new_generation <= prev {
321                    return;
322                }
323                self.last_generation[device_idx] = new_generation;
324                if !self.built {
325                    return;
326                }
327                tracing::warn!(
328                    device_idx,
329                    new_generation,
330                    "NcclWorldActor: device context rebuilt — tearing down NCCL world"
331                );
332                // Tear down all collective actors. They were spawned
333                // as children so stop them via their refs.
334                for c in self.collectives.iter_mut() {
335                    if let Some(c) = c.take() {
336                        c.stop();
337                    }
338                }
339                self.built = false;
340                // Try to rebuild now that the device fleet is back.
341                self.try_build_world(ctx).await;
342            }
343            NcclWorldMsg::AllReduceF32 { tensors, op, reply } => {
344                self.dispatch_all_reduce_f32(tensors, op, reply);
345            }
346        }
347    }
348}