atomr_accel_cuda/multi_device/
world_actor.rs1use std::sync::Arc;
16
17use async_trait::async_trait;
18use atomr_core::actor::{Actor, ActorRef, Context, Props};
19use cudarc::nccl::Comm;
20use tokio::sync::oneshot;
21use tracing::{info, warn};
22
23use crate::completion::{CompletionStrategy, HostFnCompletion};
24use crate::device::{DeviceActor, DeviceConfig, DeviceMsg};
25use crate::error::GpuError;
26use crate::gpu_ref::GpuRef;
27use crate::kernel::{AllReduceRequest, CollectiveActor, CollectiveMsg, ReduceOp};
28
/// Static description of an NCCL world: which devices take part and which
/// rank is the root.
#[derive(Debug, Clone)]
pub struct NcclWorldConfig {
    /// Device ordinals, one per rank. The position in this vector is the
    /// index used for the per-device slots kept by the world actor.
    pub device_ids: Vec<u32>,
    /// Index of the root rank. `new` sets it to `0`.
    pub root: usize,
}

impl NcclWorldConfig {
    /// Creates a configuration over `device_ids` with rank 0 as the root.
    pub fn new(device_ids: Vec<u32>) -> Self {
        Self { device_ids, root: 0 }
    }
}
43
44pub enum NcclWorldMsg {
45 AllReduceF32 {
46 tensors: Vec<GpuRef<f32>>,
47 op: ReduceOp,
48 reply: oneshot::Sender<Result<(), GpuError>>,
49 },
50 ChildReady {
52 device_idx: usize,
53 device_ref: ActorRef<DeviceMsg>,
54 },
55 DeviceContextChanged {
60 device_idx: usize,
61 new_generation: u64,
62 },
63}
64
65pub struct NcclWorldActor {
66 config: NcclWorldConfig,
67 devices: Vec<Option<ActorRef<DeviceMsg>>>,
68 collectives: Vec<Option<ActorRef<CollectiveMsg>>>,
69 built: bool,
71 last_generation: Vec<u64>,
74 #[allow(dead_code)]
75 completion: Arc<dyn CompletionStrategy>,
76}
77
78impl NcclWorldActor {
79 pub fn props(config: NcclWorldConfig) -> Props<Self> {
80 Props::create(move || {
81 let n = config.device_ids.len();
82 NcclWorldActor {
83 config: config.clone(),
84 devices: (0..n).map(|_| None).collect(),
85 collectives: (0..n).map(|_| None).collect(),
86 built: false,
87 last_generation: vec![0; n],
88 completion: Arc::new(HostFnCompletion::new()),
89 }
90 })
91 }
92
93 async fn try_build_world(&mut self, ctx: &mut Context<Self>) {
94 if self.built {
95 return;
96 }
97 if self.devices.iter().any(|d| d.is_none()) {
98 return;
99 }
100
101 let mut snaps = Vec::with_capacity(self.devices.len());
103 for d in &self.devices {
104 let dref = d.as_ref().unwrap();
105 let (tx, rx) = oneshot::channel();
106 dref.tell(DeviceMsg::SnapshotContext { reply: tx });
107 match rx.await {
108 Ok(Some(c)) => snaps.push(c),
109 _ => {
110 warn!("NcclWorldActor: a device reported no context; aborting world-build");
111 return;
112 }
113 }
114 }
115
116 let mut streams = Vec::with_capacity(snaps.len());
118 for c in &snaps {
119 match c.new_stream() {
120 Ok(s) => streams.push(s),
121 Err(e) => {
122 warn!(error = %e, "NcclWorldActor: new_stream failed");
123 return;
124 }
125 }
126 }
127
128 let comms_res = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
131 Comm::from_devices(streams.clone())
132 }));
133 let comms = match comms_res {
134 Ok(Ok(cs)) => cs,
135 Ok(Err(e)) => {
136 warn!(error = ?e, "NcclWorldActor: Comm::from_devices failed");
137 return;
138 }
139 Err(_) => {
140 warn!("NcclWorldActor: NCCL not loadable on this host");
141 return;
142 }
143 };
144
145 for (i, comm) in comms.into_iter().enumerate() {
148 let state = Arc::new(crate::device::DeviceState::new(self.config.device_ids[i]));
155 let comp: Arc<dyn CompletionStrategy> = Arc::new(HostFnCompletion::new());
156 let props = CollectiveActor::props_for_rank(comm, state, comp);
157 match ctx.spawn::<CollectiveActor>(props, &format!("nccl-{i}")) {
158 Ok(r) => self.collectives[i] = Some(r),
159 Err(e) => {
160 warn!(error = %e, "spawn CollectiveActor[{i}] failed");
161 return;
162 }
163 }
164 }
165 self.built = true;
166 info!(devices = self.devices.len(), "NcclWorldActor: world built");
167 }
168
169 fn dispatch_all_reduce_f32(
170 &self,
171 tensors: Vec<GpuRef<f32>>,
172 op: ReduceOp,
173 reply: oneshot::Sender<Result<(), GpuError>>,
174 ) {
175 if !self.built {
176 let _ = reply.send(Err(GpuError::Unrecoverable(
177 "NcclWorldActor: world not built yet".into(),
178 )));
179 return;
180 }
181 if tensors.len() != self.config.device_ids.len() {
182 let _ = reply.send(Err(GpuError::Unrecoverable(format!(
183 "AllReduce: expected {} tensors, got {}",
184 self.config.device_ids.len(),
185 tensors.len()
186 ))));
187 return;
188 }
189 for (i, t) in tensors.iter().enumerate() {
190 if let Some(d) = t.device_id() {
191 if d != self.config.device_ids[i] {
192 let _ = reply.send(Err(GpuError::Unrecoverable(format!(
193 "AllReduce: tensor[{i}] on device {d}, expected {}",
194 self.config.device_ids[i]
195 ))));
196 return;
197 }
198 }
199 }
200 let collectives: Vec<_> = self
207 .collectives
208 .iter()
209 .map(|c| c.as_ref().unwrap().clone())
210 .collect();
211 tokio::spawn(async move {
212 let mut rxs = Vec::with_capacity(tensors.len());
213 for (c, t) in collectives.into_iter().zip(tensors) {
214 let (tx, rx) = oneshot::channel();
215 let op_clone = match op {
216 ReduceOp::Sum => ReduceOp::Sum,
217 ReduceOp::Prod => ReduceOp::Prod,
218 ReduceOp::Max => ReduceOp::Max,
219 ReduceOp::Min => ReduceOp::Min,
220 ReduceOp::Avg => ReduceOp::Avg,
221 };
222 c.tell(CollectiveMsg::Op(Box::new(AllReduceRequest::<f32> {
223 tensor: t,
224 op: op_clone,
225 reply: tx,
226 })));
227 rxs.push(rx);
228 }
229 let mut combined = Ok(());
230 for rx in rxs {
231 match rx.await {
232 Ok(Ok(())) => {}
233 Ok(Err(e)) => {
234 combined = Err(e);
235 break;
236 }
237 Err(_) => {
238 combined = Err(GpuError::Unrecoverable(
239 "AllReduce: a collective actor dropped its reply".into(),
240 ));
241 break;
242 }
243 }
244 }
245 let _ = reply.send(combined);
246 });
247 }
248}
249
250#[async_trait]
251impl Actor for NcclWorldActor {
252 type Msg = NcclWorldMsg;
253
254 async fn pre_start(&mut self, ctx: &mut Context<Self>) {
255 let world_ref = ctx.self_ref().clone();
256 for (i, &ord) in self.config.device_ids.iter().enumerate() {
257 let cfg = DeviceConfig::new(ord);
258 match ctx.spawn::<DeviceActor>(DeviceActor::props(cfg), &format!("dev-{i}")) {
259 Ok(r) => {
260 self.devices[i] = Some(r.clone());
261 let world = world_ref.clone();
262 let dr = r.clone();
263 tokio::spawn(async move {
264 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
265 world.tell(NcclWorldMsg::ChildReady {
266 device_idx: i,
267 device_ref: dr,
268 });
269 });
270 }
271 Err(e) => panic!("Unrecoverable: spawn DeviceActor[{i}]: {e}"),
272 }
273 }
274 }
275
276 async fn handle(&mut self, ctx: &mut Context<Self>, msg: NcclWorldMsg) {
277 match msg {
278 NcclWorldMsg::ChildReady {
279 device_idx,
280 device_ref,
281 } => {
282 self.devices[device_idx] = Some(device_ref.clone());
283
284 let world_ref = ctx.self_ref().clone();
288 let dr = device_ref.clone();
289 tokio::spawn(async move {
290 let watch_rx_res = dr
291 .ask_with(
292 move |tx| DeviceMsg::WatchGeneration { reply: tx },
293 std::time::Duration::from_secs(5),
294 )
295 .await;
296 let mut rx = match watch_rx_res {
297 Ok(rx) => rx,
298 Err(_) => return,
299 };
300 let mut last = *rx.borrow();
301 while rx.changed().await.is_ok() {
302 let gen = *rx.borrow();
303 if gen != last {
304 last = gen;
305 world_ref.tell(NcclWorldMsg::DeviceContextChanged {
306 device_idx,
307 new_generation: gen,
308 });
309 }
310 }
311 });
312
313 self.try_build_world(ctx).await;
314 }
315 NcclWorldMsg::DeviceContextChanged {
316 device_idx,
317 new_generation,
318 } => {
319 let prev = self.last_generation.get(device_idx).copied().unwrap_or(0);
320 if new_generation <= prev {
321 return;
322 }
323 self.last_generation[device_idx] = new_generation;
324 if !self.built {
325 return;
326 }
327 tracing::warn!(
328 device_idx,
329 new_generation,
330 "NcclWorldActor: device context rebuilt — tearing down NCCL world"
331 );
332 for c in self.collectives.iter_mut() {
335 if let Some(c) = c.take() {
336 c.stop();
337 }
338 }
339 self.built = false;
340 self.try_build_world(ctx).await;
342 }
343 NcclWorldMsg::AllReduceF32 { tensors, op, reply } => {
344 self.dispatch_all_reduce_f32(tensors, op, reply);
345 }
346 }
347 }
348}