Skip to main content

atomr_accel_cuda/p2p/
mod.rs

1//! P2P (peer-to-peer) topology + cross-device async memcpy.
2//!
3//! cudarc 0.19 exposes peer access only at the `sys` layer:
4//! `cuDeviceCanAccessPeer`, `cuCtxEnablePeerAccess`,
5//! `cuMemcpyPeerAsync`. This module wraps those with explicit
6//! `unsafe` blocks behind an actor surface.
7//!
8//! Lifecycle:
9//! 1. Construct with N `ActorRef<DeviceMsg>` siblings.
10//! 2. Send `EnableAll` — actor snapshots each device's
11//!    `Arc<CudaContext>`, probes `cuDeviceCanAccessPeer` for every
12//!    pair, calls `cuCtxEnablePeerAccess` on directions that
13//!    succeed, and replies with the resulting [`P2pGraph`].
14//! 3. Send `CopyF32 { src, src_device, dst, dst_device }` — actor
15//!    issues `cuMemcpyPeerAsync` on a fresh destination-side stream
16//!    and replies after `cudaStreamSynchronize`.
17
18#![allow(clippy::needless_range_loop)]
19
20use std::collections::HashSet;
21use std::sync::Arc;
22
23use async_trait::async_trait;
24use atomr_core::actor::{Actor, ActorRef, Context, Props};
25use cudarc::driver::sys as driver_sys;
26use cudarc::driver::CudaContext;
27use cudarc::driver::DevicePtr;
28use cudarc::driver::DevicePtrMut;
29use parking_lot::Mutex;
30use tokio::sync::oneshot;
31use tracing::info;
32
33use crate::device::DeviceMsg;
34use crate::error::GpuError;
35use crate::gpu_ref::GpuRef;
36
/// Directed peer-access adjacency matrix over `device_count` devices.
///
/// `edges[a][b]` is `true` when device `a` may access device `b`.
/// Both fields are public, so `edges` is not guaranteed to stay
/// `device_count × device_count`; every query below is bounds-checked
/// and treats a missing entry as "no access".
#[derive(Debug, Clone)]
pub struct P2pGraph {
    pub edges: Vec<Vec<bool>>,
    pub device_count: u32,
}

impl P2pGraph {
    /// Creates a graph for `n` devices with every edge cleared.
    pub fn new(n: u32) -> Self {
        Self {
            edges: (0..n).map(|_| vec![false; n as usize]).collect(),
            device_count: n,
        }
    }

    /// Bounds-checked read of `edges[from][to]`; absent entries are `false`.
    fn edge(&self, from: usize, to: usize) -> bool {
        self.edges
            .get(from)
            .and_then(|row| row.get(to))
            .copied()
            .unwrap_or(false)
    }

    /// Returns whether the edge `a -> b` is set.
    ///
    /// Out-of-range indices yield `false` instead of panicking, so a
    /// caller-supplied device index can never bring down the caller.
    pub fn can_pair(&self, a: u32, b: u32) -> bool {
        self.edge(a as usize, b as usize)
    }

    /// Connected components (NVLink islands).
    ///
    /// Edges are treated as undirected: two devices share an island when
    /// either direction is set. Islands are returned in order of their
    /// lowest-numbered member, and every device in `0..device_count`
    /// appears in exactly one island (isolated devices form singletons).
    pub fn islands(&self) -> Vec<HashSet<u32>> {
        let n = self.device_count as usize;
        let mut visited = vec![false; n];
        let mut out = Vec::new();
        for start in 0..n {
            if visited[start] {
                continue;
            }
            // Iterative depth-first walk from `start`.
            let mut stack = vec![start];
            let mut island = HashSet::new();
            while let Some(cur) = stack.pop() {
                if visited[cur] {
                    continue;
                }
                visited[cur] = true;
                island.insert(cur as u32);
                stack.extend((0..n).filter(|&next| {
                    !visited[next] && (self.edge(cur, next) || self.edge(next, cur))
                }));
            }
            out.push(island);
        }
        out
    }
}
83
84pub enum P2pMsg {
85    EnableAll {
86        reply: oneshot::Sender<Result<P2pGraph, GpuError>>,
87    },
88    CanAccess {
89        from: u32,
90        to: u32,
91        reply: oneshot::Sender<bool>,
92    },
93    /// Async peer copy from `src` (on `src_device`) to `dst`
94    /// (on `dst_device`). Both `GpuRef`s must be valid; copy size
95    /// is `min(src.len, dst.len)` × sizeof(f32).
96    CopyF32 {
97        src: GpuRef<f32>,
98        src_device: u32,
99        dst: GpuRef<f32>,
100        dst_device: u32,
101        reply: oneshot::Sender<Result<(), GpuError>>,
102    },
103    Topology {
104        reply: oneshot::Sender<P2pGraph>,
105    },
106    /// Internal: a device's generation_watch fired, meaning its
107    /// `ContextActor` rebuilt the underlying `CudaContext`. The
108    /// cached `Arc<CudaContext>` in `self.contexts` is now stale —
109    /// re-snapshot it. Peer-access mappings persist across rebuilds
110    /// only if the new context lives in the same primary slot, so
111    /// the topology marks itself disabled and a follow-up
112    /// `EnableAll` is required to refresh `graph` + re-issue
113    /// `cuCtxEnablePeerAccess`.
114    RefreshDevice {
115        device_idx: u32,
116        new_generation: u64,
117    },
118}
119
120struct SendCtx(Arc<CudaContext>);
121unsafe impl Send for SendCtx {}
122unsafe impl Sync for SendCtx {}
123
124pub struct P2pTopology {
125    devices: Vec<ActorRef<DeviceMsg>>,
126    contexts: Mutex<Vec<Option<SendCtx>>>,
127    graph: P2pGraph,
128    enabled: bool,
129}
130
131impl P2pTopology {
132    pub fn props(devices: Vec<ActorRef<DeviceMsg>>) -> Props<Self> {
133        let n = devices.len() as u32;
134        Props::create(move || P2pTopology {
135            devices: devices.clone(),
136            contexts: Mutex::new((0..n).map(|_| None).collect()),
137            graph: P2pGraph::new(n),
138            enabled: false,
139        })
140    }
141}
142
143#[async_trait]
144impl Actor for P2pTopology {
145    type Msg = P2pMsg;
146
147    async fn pre_start(&mut self, ctx: &mut Context<Self>) {
148        // Subscribe to each sibling DeviceActor's generation_watch so a
149        // ContextActor rebuild invalidates this topology's cached
150        // contexts. The bridge spawns one task per device that asks
151        // the device for its watch::Receiver and forwards every change
152        // back to us as a RefreshDevice message.
153        let self_ref = ctx.self_ref().clone();
154        for (idx, dev) in self.devices.iter().enumerate() {
155            let topo_ref = self_ref.clone();
156            let dev_ref = dev.clone();
157            tokio::spawn(async move {
158                let watch_rx_res = dev_ref
159                    .ask_with(
160                        move |tx| DeviceMsg::WatchGeneration { reply: tx },
161                        std::time::Duration::from_secs(5),
162                    )
163                    .await;
164                let mut rx = match watch_rx_res {
165                    Ok(rx) => rx,
166                    Err(_) => return,
167                };
168                let mut last = *rx.borrow();
169                while rx.changed().await.is_ok() {
170                    let gen = *rx.borrow();
171                    if gen != last {
172                        last = gen;
173                        topo_ref.tell(P2pMsg::RefreshDevice {
174                            device_idx: idx as u32,
175                            new_generation: gen,
176                        });
177                    }
178                }
179            });
180        }
181    }
182
183    async fn handle(&mut self, _ctx: &mut Context<Self>, msg: P2pMsg) {
184        match msg {
185            P2pMsg::EnableAll { reply } => {
186                let n = self.devices.len();
187                // Snapshot each device's context. Returns None until
188                // ContextActor::Init completes.
189                let mut snaps: Vec<Option<Arc<CudaContext>>> = Vec::with_capacity(n);
190                for d in &self.devices {
191                    let (tx, rx) = oneshot::channel();
192                    d.tell(DeviceMsg::SnapshotContext { reply: tx });
193                    match rx.await {
194                        Ok(c) => snaps.push(c),
195                        Err(_) => snaps.push(None),
196                    }
197                }
198                {
199                    let mut g = self.contexts.lock();
200                    for (i, s) in snaps.iter().enumerate() {
201                        g[i] = s.clone().map(SendCtx);
202                    }
203                }
204
205                let mut graph = P2pGraph::new(n as u32);
206                let any_unloadable = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
207                    for i in 0..n {
208                        let Some(ctx_a) = snaps[i].as_ref() else {
209                            continue;
210                        };
211                        for j in 0..n {
212                            if i == j {
213                                graph.edges[i][j] = true;
214                                continue;
215                            }
216                            let Some(_) = snaps[j].as_ref() else { continue };
217                            let mut can = 0i32;
218                            // cuDeviceCanAccessPeer takes ordinals.
219                            let s = unsafe {
220                                driver_sys::cuDeviceCanAccessPeer(
221                                    &mut can as *mut _,
222                                    ctx_a.cu_device(),
223                                    snaps[j].as_ref().unwrap().cu_device(),
224                                )
225                            };
226                            if s == driver_sys::cudaError_enum::CUDA_SUCCESS && can == 1 {
227                                graph.edges[i][j] = true;
228                            }
229                        }
230                    }
231                    Ok::<(), GpuError>(())
232                }));
233                if any_unloadable.is_err() {
234                    let _ = reply.send(Err(GpuError::Unrecoverable(
235                        "P2pTopology::EnableAll: CUDA driver not loadable".into(),
236                    )));
237                    return;
238                }
239
240                // Enable peer access in each direction where probe
241                // succeeded. cuCtxEnablePeerAccess must be called from
242                // the source context (set current) targeting the peer.
243                let enable_res = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
244                    for i in 0..n {
245                        let Some(ctx_a) = snaps[i].as_ref() else {
246                            continue;
247                        };
248                        let _ = ctx_a.bind_to_thread();
249                        for j in 0..n {
250                            if i == j || !graph.edges[i][j] {
251                                continue;
252                            }
253                            let peer = snaps[j].as_ref().unwrap();
254                            let s = unsafe { driver_sys::cuCtxEnablePeerAccess(peer.cu_ctx(), 0) };
255                            // CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED is fine.
256                            if s != driver_sys::cudaError_enum::CUDA_SUCCESS
257                                && s
258                                    != driver_sys::cudaError_enum::CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED
259                            {
260                                graph.edges[i][j] = false;
261                            }
262                        }
263                    }
264                }));
265                let _ = enable_res; // partial enables are best-effort
266
267                self.graph = graph.clone();
268                self.enabled = true;
269                info!(devices = n, "P2pTopology::EnableAll done");
270                let _ = reply.send(Ok(graph));
271            }
272            P2pMsg::CanAccess { from, to, reply } => {
273                let v = if from == to {
274                    true
275                } else {
276                    self.graph
277                        .edges
278                        .get(from as usize)
279                        .and_then(|row| row.get(to as usize).copied())
280                        .unwrap_or(false)
281                };
282                let _ = reply.send(v);
283            }
284            P2pMsg::CopyF32 {
285                src,
286                src_device,
287                dst,
288                dst_device,
289                reply,
290            } => {
291                if !self.enabled {
292                    let _ = reply.send(Err(GpuError::Unrecoverable(
293                        "P2pTopology: call EnableAll first".into(),
294                    )));
295                    return;
296                }
297                if !self.graph.can_pair(src_device, dst_device) {
298                    let _ = reply.send(Err(GpuError::Unrecoverable(format!(
299                        "P2pTopology: device {src_device} cannot peer-access {dst_device}"
300                    ))));
301                    return;
302                }
303                let ctxs = self.contexts.lock();
304                let src_ctx = match ctxs.get(src_device as usize).and_then(|c| c.as_ref()) {
305                    Some(c) => c.0.clone(),
306                    None => {
307                        let _ = reply.send(Err(GpuError::Unrecoverable(format!(
308                            "P2pTopology: src device {src_device} context not available"
309                        ))));
310                        return;
311                    }
312                };
313                let dst_ctx = match ctxs.get(dst_device as usize).and_then(|c| c.as_ref()) {
314                    Some(c) => c.0.clone(),
315                    None => {
316                        let _ = reply.send(Err(GpuError::Unrecoverable(format!(
317                            "P2pTopology: dst device {dst_device} context not available"
318                        ))));
319                        return;
320                    }
321                };
322                drop(ctxs);
323
324                let src_slice = match src.access() {
325                    Ok(s) => s.clone(),
326                    Err(e) => {
327                        let _ = reply.send(Err(e));
328                        return;
329                    }
330                };
331                let dst_slice = match dst.access() {
332                    Ok(s) => s.clone(),
333                    Err(e) => {
334                        let _ = reply.send(Err(e));
335                        return;
336                    }
337                };
338                let mut dst_owned = match Arc::try_unwrap(dst_slice) {
339                    Ok(s) => s,
340                    Err(_) => {
341                        let _ = reply.send(Err(GpuError::Unrecoverable(
342                            "P2pCopy: dst has multiple live references".into(),
343                        )));
344                        return;
345                    }
346                };
347
348                let len = std::cmp::min(src_slice.len(), dst_owned.len());
349                let bytes = len * std::mem::size_of::<f32>();
350                // Mint a destination-side stream for the copy.
351                let dst_stream = match dst_ctx.new_stream() {
352                    Ok(s) => s,
353                    Err(e) => {
354                        let _ = reply.send(Err(GpuError::LibraryError {
355                            lib: "driver",
356                            msg: format!("dst new_stream: {e}"),
357                        }));
358                        return;
359                    }
360                };
361
362                // F9.2: if the source `GpuRef` carries a recorded
363                // last-write stream (set by upstream BlasActor /
364                // CudnnActor / etc.), inject a cross-stream event
365                // wait so the peer copy doesn't race with in-flight
366                // writes — and we don't have to host-synchronize.
367                let last_write_src = src.last_write_stream();
368                let copy_res = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
369                    if let Some(src_stream) = last_write_src.as_ref() {
370                        let ev =
371                            src_stream
372                                .record_event(None)
373                                .map_err(|e| GpuError::LibraryError {
374                                    lib: "driver",
375                                    msg: format!("p2p: src record_event: {e}"),
376                                })?;
377                        // Wait on the destination side. Cross-context
378                        // event waits work because cuCtxEnablePeerAccess
379                        // was already called in EnableAll.
380                        dst_stream.wait(&ev).map_err(|e| GpuError::LibraryError {
381                            lib: "driver",
382                            msg: format!("p2p: dst wait: {e}"),
383                        })?;
384                    }
385                    let (src_ptr, _g1) = src_slice.device_ptr(&dst_stream);
386                    let (dst_ptr, _g2) = dst_owned.device_ptr_mut(&dst_stream);
387                    let s = unsafe {
388                        driver_sys::cuMemcpyPeerAsync(
389                            dst_ptr,
390                            dst_ctx.cu_ctx(),
391                            src_ptr,
392                            src_ctx.cu_ctx(),
393                            bytes,
394                            dst_stream.cu_stream(),
395                        )
396                    };
397                    drop((_g1, _g2));
398                    if s != driver_sys::cudaError_enum::CUDA_SUCCESS {
399                        return Err(GpuError::LibraryError {
400                            lib: "driver",
401                            msg: format!("cuMemcpyPeerAsync: {s:?}"),
402                        });
403                    }
404                    // Synchronize as the host-visible barrier. A
405                    // future improvement (F10) would replace this
406                    // with a HostFnCompletion-style callback so the
407                    // actor never blocks the OS thread.
408                    dst_stream
409                        .synchronize()
410                        .map_err(|e| GpuError::LibraryError {
411                            lib: "driver",
412                            msg: format!("cudaStreamSynchronize: {e}"),
413                        })?;
414                    Ok(())
415                }));
416                let result = match copy_res {
417                    Ok(r) => r,
418                    Err(_) => Err(GpuError::Unrecoverable(
419                        "P2pCopy: CUDA driver not loadable".into(),
420                    )),
421                };
422                dst.record_write(&dst_stream);
423                let _ = reply.send(result);
424                drop(dst_owned);
425            }
426            P2pMsg::Topology { reply } => {
427                let _ = reply.send(self.graph.clone());
428            }
429            P2pMsg::RefreshDevice {
430                device_idx,
431                new_generation,
432            } => {
433                info!(
434                    device_idx,
435                    new_generation,
436                    "P2pTopology: device context rebuilt — invalidating cached snapshot"
437                );
438                // Re-snapshot just the affected device. Other devices'
439                // cached contexts remain valid until they themselves
440                // bump their generation.
441                let dev = match self.devices.get(device_idx as usize) {
442                    Some(d) => d.clone(),
443                    None => return,
444                };
445                let (tx, rx) = oneshot::channel();
446                dev.tell(DeviceMsg::SnapshotContext { reply: tx });
447                let new_ctx = rx.await.unwrap_or_default();
448                {
449                    let mut g = self.contexts.lock();
450                    if let Some(slot) = g.get_mut(device_idx as usize) {
451                        *slot = new_ctx.map(SendCtx);
452                    }
453                }
454                // Peer-access on the rebuilt context isn't enabled
455                // automatically. Force callers to re-issue EnableAll
456                // before the next CopyF32 — otherwise the copy would
457                // surface a CUDA_ERROR_INVALID_VALUE.
458                self.enabled = false;
459            }
460        }
461    }
462}