1#![allow(clippy::needless_range_loop)]
19
20use std::collections::HashSet;
21use std::sync::Arc;
22
23use async_trait::async_trait;
24use atomr_core::actor::{Actor, ActorRef, Context, Props};
25use cudarc::driver::sys as driver_sys;
26use cudarc::driver::CudaContext;
27use cudarc::driver::DevicePtr;
28use cudarc::driver::DevicePtrMut;
29use parking_lot::Mutex;
30use tokio::sync::oneshot;
31use tracing::info;
32
33use crate::device::DeviceMsg;
34use crate::error::GpuError;
35use crate::gpu_ref::GpuRef;
36
/// Directed peer-access adjacency matrix over `device_count` devices.
///
/// `edges[a][b] == true` records that device `a` may access memory that lives
/// on device `b`.
#[derive(Debug, Clone)]
pub struct P2pGraph {
    /// Adjacency matrix; `edges[a][b]` means an `a -> b` edge is recorded.
    pub edges: Vec<Vec<bool>>,
    /// Number of devices the graph was sized for.
    pub device_count: u32,
}

impl P2pGraph {
    /// Creates an `n` x `n` graph with no edges (not even self-edges).
    pub fn new(n: u32) -> Self {
        Self {
            edges: (0..n).map(|_| vec![false; n as usize]).collect(),
            device_count: n,
        }
    }

    /// Returns whether an `a -> b` edge is recorded.
    ///
    /// Out-of-range indices yield `false` instead of panicking, so device ids
    /// taken straight from incoming messages can be passed without validating
    /// them first.
    pub fn can_pair(&self, a: u32, b: u32) -> bool {
        self.edges
            .get(a as usize)
            .and_then(|row| row.get(b as usize))
            .copied()
            .unwrap_or(false)
    }

    /// Partitions the devices `0..device_count` into connected components.
    ///
    /// Edges are treated as undirected: `a` and `b` are connected if either
    /// `a -> b` or `b -> a` is recorded. Components are returned in order of
    /// their lowest device index. Devices with no edges form singleton
    /// components. Missing rows or columns in `edges` count as "no edge".
    pub fn islands(&self) -> Vec<HashSet<u32>> {
        let n = self.device_count as usize;
        let mut visited = vec![false; n];
        let mut out = Vec::new();
        for start in 0..n {
            if visited[start] {
                continue;
            }
            // Iterative DFS; a node may be pushed more than once, and the
            // `visited` check on pop de-duplicates.
            let mut stack = vec![start];
            let mut island = HashSet::new();
            while let Some(j) = stack.pop() {
                if visited[j] {
                    continue;
                }
                visited[j] = true;
                island.insert(j as u32);
                for k in 0..n {
                    let linked = self.can_pair(j as u32, k as u32)
                        || self.can_pair(k as u32, j as u32);
                    if !visited[k] && linked {
                        stack.push(k);
                    }
                }
            }
            out.push(island);
        }
        out
    }
}
83
/// Messages handled by the [`P2pTopology`] actor.
///
/// Every variant except `RefreshDevice` carries a `reply` sender. The actor
/// ignores failures to send on it, which happens when the requester has gone
/// away.
pub enum P2pMsg {
    /// Snapshot every device's CUDA context and probe pairwise peer-access
    /// support. Peer access is then enabled wherever possible, and the
    /// resulting graph is sent back.
    ///
    /// On success the graph is also cached, and `CopyF32` becomes usable.
    EnableAll {
        reply: oneshot::Sender<Result<P2pGraph, GpuError>>,
    },
    /// Ask whether the cached graph records `from -> to`.
    ///
    /// `from == to` always answers `true`. Out-of-range ids answer `false`.
    CanAccess {
        from: u32,
        to: u32,
        reply: oneshot::Sender<bool>,
    },
    /// Copy `f32` elements from `src` (on `src_device`) into `dst` (on
    /// `dst_device`) using `cuMemcpyPeerAsync`.
    ///
    /// The number of elements copied is `min(src.len, dst.len)`. The handler
    /// waits for the copy to finish before replying. This is rejected unless
    /// `EnableAll` has completed and the pair has an edge in the cached graph.
    CopyF32 {
        src: GpuRef<f32>,
        src_device: u32,
        dst: GpuRef<f32>,
        dst_device: u32,
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
    /// Reply with a clone of the cached peer-access graph.
    Topology {
        reply: oneshot::Sender<P2pGraph>,
    },
    /// Internal: sent by the per-device watcher tasks spawned in `pre_start`
    /// when a device's context generation changes.
    ///
    /// The handler re-snapshots that device's context and clears the enabled
    /// flag, so `CopyF32` is refused until the next `EnableAll`.
    RefreshDevice {
        device_idx: u32,
        new_generation: u64,
    },
}
119
/// Newtype around a shared CUDA context handle, used as the element type of
/// the actor's `Mutex`-protected per-device context cache.
struct SendCtx(Arc<CudaContext>);
// SAFETY: NOTE(review) these impls assert that the wrapped `Arc<CudaContext>`
// may be moved to and shared across threads. The actor only clones the inner
// `Arc` out of the cache under the mutex. Confirm against the cudarc version
// in use that `CudaContext` is safe to use from multiple threads. If it is
// already `Send + Sync`, these impls are redundant and can be removed.
unsafe impl Send for SendCtx {}
unsafe impl Sync for SendCtx {}
123
124pub struct P2pTopology {
125 devices: Vec<ActorRef<DeviceMsg>>,
126 contexts: Mutex<Vec<Option<SendCtx>>>,
127 graph: P2pGraph,
128 enabled: bool,
129}
130
131impl P2pTopology {
132 pub fn props(devices: Vec<ActorRef<DeviceMsg>>) -> Props<Self> {
133 let n = devices.len() as u32;
134 Props::create(move || P2pTopology {
135 devices: devices.clone(),
136 contexts: Mutex::new((0..n).map(|_| None).collect()),
137 graph: P2pGraph::new(n),
138 enabled: false,
139 })
140 }
141}
142
143#[async_trait]
144impl Actor for P2pTopology {
145 type Msg = P2pMsg;
146
147 async fn pre_start(&mut self, ctx: &mut Context<Self>) {
148 let self_ref = ctx.self_ref().clone();
154 for (idx, dev) in self.devices.iter().enumerate() {
155 let topo_ref = self_ref.clone();
156 let dev_ref = dev.clone();
157 tokio::spawn(async move {
158 let watch_rx_res = dev_ref
159 .ask_with(
160 move |tx| DeviceMsg::WatchGeneration { reply: tx },
161 std::time::Duration::from_secs(5),
162 )
163 .await;
164 let mut rx = match watch_rx_res {
165 Ok(rx) => rx,
166 Err(_) => return,
167 };
168 let mut last = *rx.borrow();
169 while rx.changed().await.is_ok() {
170 let gen = *rx.borrow();
171 if gen != last {
172 last = gen;
173 topo_ref.tell(P2pMsg::RefreshDevice {
174 device_idx: idx as u32,
175 new_generation: gen,
176 });
177 }
178 }
179 });
180 }
181 }
182
183 async fn handle(&mut self, _ctx: &mut Context<Self>, msg: P2pMsg) {
184 match msg {
185 P2pMsg::EnableAll { reply } => {
186 let n = self.devices.len();
187 let mut snaps: Vec<Option<Arc<CudaContext>>> = Vec::with_capacity(n);
190 for d in &self.devices {
191 let (tx, rx) = oneshot::channel();
192 d.tell(DeviceMsg::SnapshotContext { reply: tx });
193 match rx.await {
194 Ok(c) => snaps.push(c),
195 Err(_) => snaps.push(None),
196 }
197 }
198 {
199 let mut g = self.contexts.lock();
200 for (i, s) in snaps.iter().enumerate() {
201 g[i] = s.clone().map(SendCtx);
202 }
203 }
204
205 let mut graph = P2pGraph::new(n as u32);
206 let any_unloadable = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
207 for i in 0..n {
208 let Some(ctx_a) = snaps[i].as_ref() else {
209 continue;
210 };
211 for j in 0..n {
212 if i == j {
213 graph.edges[i][j] = true;
214 continue;
215 }
216 let Some(_) = snaps[j].as_ref() else { continue };
217 let mut can = 0i32;
218 let s = unsafe {
220 driver_sys::cuDeviceCanAccessPeer(
221 &mut can as *mut _,
222 ctx_a.cu_device(),
223 snaps[j].as_ref().unwrap().cu_device(),
224 )
225 };
226 if s == driver_sys::cudaError_enum::CUDA_SUCCESS && can == 1 {
227 graph.edges[i][j] = true;
228 }
229 }
230 }
231 Ok::<(), GpuError>(())
232 }));
233 if any_unloadable.is_err() {
234 let _ = reply.send(Err(GpuError::Unrecoverable(
235 "P2pTopology::EnableAll: CUDA driver not loadable".into(),
236 )));
237 return;
238 }
239
240 let enable_res = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
244 for i in 0..n {
245 let Some(ctx_a) = snaps[i].as_ref() else {
246 continue;
247 };
248 let _ = ctx_a.bind_to_thread();
249 for j in 0..n {
250 if i == j || !graph.edges[i][j] {
251 continue;
252 }
253 let peer = snaps[j].as_ref().unwrap();
254 let s = unsafe { driver_sys::cuCtxEnablePeerAccess(peer.cu_ctx(), 0) };
255 if s != driver_sys::cudaError_enum::CUDA_SUCCESS
257 && s
258 != driver_sys::cudaError_enum::CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED
259 {
260 graph.edges[i][j] = false;
261 }
262 }
263 }
264 }));
265 let _ = enable_res; self.graph = graph.clone();
268 self.enabled = true;
269 info!(devices = n, "P2pTopology::EnableAll done");
270 let _ = reply.send(Ok(graph));
271 }
272 P2pMsg::CanAccess { from, to, reply } => {
273 let v = if from == to {
274 true
275 } else {
276 self.graph
277 .edges
278 .get(from as usize)
279 .and_then(|row| row.get(to as usize).copied())
280 .unwrap_or(false)
281 };
282 let _ = reply.send(v);
283 }
284 P2pMsg::CopyF32 {
285 src,
286 src_device,
287 dst,
288 dst_device,
289 reply,
290 } => {
291 if !self.enabled {
292 let _ = reply.send(Err(GpuError::Unrecoverable(
293 "P2pTopology: call EnableAll first".into(),
294 )));
295 return;
296 }
297 if !self.graph.can_pair(src_device, dst_device) {
298 let _ = reply.send(Err(GpuError::Unrecoverable(format!(
299 "P2pTopology: device {src_device} cannot peer-access {dst_device}"
300 ))));
301 return;
302 }
303 let ctxs = self.contexts.lock();
304 let src_ctx = match ctxs.get(src_device as usize).and_then(|c| c.as_ref()) {
305 Some(c) => c.0.clone(),
306 None => {
307 let _ = reply.send(Err(GpuError::Unrecoverable(format!(
308 "P2pTopology: src device {src_device} context not available"
309 ))));
310 return;
311 }
312 };
313 let dst_ctx = match ctxs.get(dst_device as usize).and_then(|c| c.as_ref()) {
314 Some(c) => c.0.clone(),
315 None => {
316 let _ = reply.send(Err(GpuError::Unrecoverable(format!(
317 "P2pTopology: dst device {dst_device} context not available"
318 ))));
319 return;
320 }
321 };
322 drop(ctxs);
323
324 let src_slice = match src.access() {
325 Ok(s) => s.clone(),
326 Err(e) => {
327 let _ = reply.send(Err(e));
328 return;
329 }
330 };
331 let dst_slice = match dst.access() {
332 Ok(s) => s.clone(),
333 Err(e) => {
334 let _ = reply.send(Err(e));
335 return;
336 }
337 };
338 let mut dst_owned = match Arc::try_unwrap(dst_slice) {
339 Ok(s) => s,
340 Err(_) => {
341 let _ = reply.send(Err(GpuError::Unrecoverable(
342 "P2pCopy: dst has multiple live references".into(),
343 )));
344 return;
345 }
346 };
347
348 let len = std::cmp::min(src_slice.len(), dst_owned.len());
349 let bytes = len * std::mem::size_of::<f32>();
350 let dst_stream = match dst_ctx.new_stream() {
352 Ok(s) => s,
353 Err(e) => {
354 let _ = reply.send(Err(GpuError::LibraryError {
355 lib: "driver",
356 msg: format!("dst new_stream: {e}"),
357 }));
358 return;
359 }
360 };
361
362 let last_write_src = src.last_write_stream();
368 let copy_res = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
369 if let Some(src_stream) = last_write_src.as_ref() {
370 let ev =
371 src_stream
372 .record_event(None)
373 .map_err(|e| GpuError::LibraryError {
374 lib: "driver",
375 msg: format!("p2p: src record_event: {e}"),
376 })?;
377 dst_stream.wait(&ev).map_err(|e| GpuError::LibraryError {
381 lib: "driver",
382 msg: format!("p2p: dst wait: {e}"),
383 })?;
384 }
385 let (src_ptr, _g1) = src_slice.device_ptr(&dst_stream);
386 let (dst_ptr, _g2) = dst_owned.device_ptr_mut(&dst_stream);
387 let s = unsafe {
388 driver_sys::cuMemcpyPeerAsync(
389 dst_ptr,
390 dst_ctx.cu_ctx(),
391 src_ptr,
392 src_ctx.cu_ctx(),
393 bytes,
394 dst_stream.cu_stream(),
395 )
396 };
397 drop((_g1, _g2));
398 if s != driver_sys::cudaError_enum::CUDA_SUCCESS {
399 return Err(GpuError::LibraryError {
400 lib: "driver",
401 msg: format!("cuMemcpyPeerAsync: {s:?}"),
402 });
403 }
404 dst_stream
409 .synchronize()
410 .map_err(|e| GpuError::LibraryError {
411 lib: "driver",
412 msg: format!("cudaStreamSynchronize: {e}"),
413 })?;
414 Ok(())
415 }));
416 let result = match copy_res {
417 Ok(r) => r,
418 Err(_) => Err(GpuError::Unrecoverable(
419 "P2pCopy: CUDA driver not loadable".into(),
420 )),
421 };
422 dst.record_write(&dst_stream);
423 let _ = reply.send(result);
424 drop(dst_owned);
425 }
426 P2pMsg::Topology { reply } => {
427 let _ = reply.send(self.graph.clone());
428 }
429 P2pMsg::RefreshDevice {
430 device_idx,
431 new_generation,
432 } => {
433 info!(
434 device_idx,
435 new_generation,
436 "P2pTopology: device context rebuilt — invalidating cached snapshot"
437 );
438 let dev = match self.devices.get(device_idx as usize) {
442 Some(d) => d.clone(),
443 None => return,
444 };
445 let (tx, rx) = oneshot::channel();
446 dev.tell(DeviceMsg::SnapshotContext { reply: tx });
447 let new_ctx = rx.await.unwrap_or_default();
448 {
449 let mut g = self.contexts.lock();
450 if let Some(slot) = g.get_mut(device_idx as usize) {
451 *slot = new_ctx.map(SendCtx);
452 }
453 }
454 self.enabled = false;
459 }
460 }
461 }
462}