Skip to main content

atomr_accel_cuda/
gpu_ref.rs

//! `GpuRef<T>` — opaque, message-friendly handle to a GPU buffer (§5.8).
//!
//! `Send + Sync + 'static` with no lifetime parameters, so it composes
//! freely in any actor message type. Validity is checked at runtime by
//! comparing a generation token against `DeviceState.generation`, which
//! is bumped whenever the underlying `CudaContext` is rebuilt (§5.11
//! supervision).
//!
//! Cross-node serialisation (`GpuToken { node_id, device_id, buffer_id,
//! generation }` — §5.5) is intentionally **not** implemented in F1; it
//! lands with the F4 cluster/NCCL story. F1 `GpuRef` is local-only.
12
13use std::sync::{Arc, Weak};
14
15use arc_swap::ArcSwapOption;
16
17use crate::device::DeviceState;
18use crate::error::GpuError;
19
20/// A live device-buffer handle.
21///
22/// Holds a strong `Arc` to the slice (keeping the underlying memory
23/// alive even if the `DeviceActor` has begun shutdown) plus a `Weak` to
24/// the surrounding `DeviceState` (so reference cycles cannot trap the
25/// system in a non-terminating state). Calling [`GpuRef::access`] before
26/// each use validates that the context generation has not advanced.
27pub struct GpuRef<T> {
28    inner: Arc<GpuRefInner<T>>,
29}
30
31struct GpuRefInner<T> {
32    /// Strong-keep on the device buffer.
33    slice: Arc<cudarc::driver::CudaSlice<T>>,
34    /// `DeviceState.generation` at construction time.
35    generation: u64,
36    /// Weak reference back to the device state. Avoids a cycle —
37    /// `DeviceActor` owns the strong `Arc<DeviceState>`.
38    state: Weak<DeviceState>,
39    /// The most recent `CudaStream` that wrote to this buffer. Library
40    /// actors call [`GpuRef::record_write`] after enqueueing a kernel
41    /// that mutates the slice. Cross-stream consumers (`P2pTopology`,
42    /// pipeline stages) read this to inject a `CudaEvent` wait without
43    /// a host roundtrip.
44    last_write_stream: ArcSwapOption<cudarc::driver::CudaStream>,
45}
46
47impl<T> Clone for GpuRef<T> {
48    fn clone(&self) -> Self {
49        Self {
50            inner: self.inner.clone(),
51        }
52    }
53}
54
55impl<T> std::fmt::Debug for GpuRef<T> {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        f.debug_struct("GpuRef")
58            .field("generation", &self.inner.generation)
59            .field("len", &self.inner.slice.len())
60            .finish()
61    }
62}
63
64impl<T> GpuRef<T> {
65    /// Wrap a raw `Arc<CudaSlice<T>>` produced by a `DeviceActor` into a
66    /// `GpuRef<T>`.
67    ///
68    /// Only `DeviceActor` (and code reachable from its dispatcher) should
69    /// call this — outside callers must obtain `GpuRef`s by asking the
70    /// `DeviceActor` to allocate.
71    pub fn new(slice: Arc<cudarc::driver::CudaSlice<T>>, state: &Arc<DeviceState>) -> Self {
72        let generation = state.generation();
73        Self {
74            inner: Arc::new(GpuRefInner {
75                slice,
76                generation,
77                state: Arc::downgrade(state),
78                last_write_stream: ArcSwapOption::empty(),
79            }),
80        }
81    }
82
83    /// Validate the reference and return access to the underlying slice.
84    ///
85    /// Returns [`GpuError::GpuRefStale`] if either:
86    /// - the owning `DeviceState` has been dropped,
87    /// - the device is no longer accepting operations, or
88    /// - the context generation has advanced past the one this ref was
89    ///   minted with (i.e. a poisoned-context rebuild has happened).
90    pub fn access(&self) -> Result<&Arc<cudarc::driver::CudaSlice<T>>, GpuError> {
91        let state = self
92            .inner
93            .state
94            .upgrade()
95            .ok_or(GpuError::GpuRefStale("device state dropped"))?;
96        if !state.accepting_ops() {
97            return Err(GpuError::GpuRefStale("device shutting down"));
98        }
99        if state.generation() != self.inner.generation {
100            return Err(GpuError::GpuRefStale("context rebuilt"));
101        }
102        Ok(&self.inner.slice)
103    }
104
105    /// Generation token at construction. Exposed for tests.
106    pub fn generation(&self) -> u64 {
107        self.inner.generation
108    }
109
110    /// Length in elements of the underlying slice.
111    pub fn len(&self) -> usize {
112        self.inner.slice.len()
113    }
114
115    pub fn is_empty(&self) -> bool {
116        self.inner.slice.is_empty()
117    }
118
119    /// Device id this `GpuRef` was minted on, or `None` if the owning
120    /// [`DeviceState`] has been dropped.
121    pub fn device_id(&self) -> Option<u32> {
122        self.inner.state.upgrade().map(|s| s.device_id())
123    }
124
125    /// Record the stream that most recently wrote to this buffer.
126    /// Library actors (BlasActor, CudnnActor, FftActor, etc.) call this
127    /// after enqueueing a kernel that mutates the slice so that
128    /// downstream consumers can inject a cross-stream wait.
129    pub fn record_write(&self, stream: &Arc<cudarc::driver::CudaStream>) {
130        self.inner.last_write_stream.store(Some(stream.clone()));
131    }
132
133    /// Most recent producing stream, if any. Returns `None` when no
134    /// kernel has been recorded against this buffer.
135    pub fn last_write_stream(&self) -> Option<Arc<cudarc::driver::CudaStream>> {
136        self.inner.last_write_stream.load_full()
137    }
138
139    /// Phase 4.5++ — opaque `CUdeviceptr` (`u64`) for downstream
140    /// raw-pointer FFI APIs (TensorRT `enqueueV3`, `cuStreamWriteValue64`,
141    /// custom CUDA modules that aren't fronted by cudarc).
142    ///
143    /// Validates the `GpuRef` first via [`GpuRef::access`]. The pointer
144    /// is captured against the slice's own associated stream — the
145    /// `_guard` returned by cudarc's `device_ptr()` is dropped before
146    /// the function returns, but the underlying allocation outlives
147    /// this call because the inner `Arc<CudaSlice<T>>` is held by
148    /// `self`. Callers must ensure they don't dispatch the resulting
149    /// pointer on a stream that has already gone out of scope; in
150    /// practice the pointer is consumed immediately by an FFI shim
151    /// (TensorRT enqueueV3, etc.) on a stream the caller owns.
152    ///
153    /// Returns [`GpuError::GpuRefStale`] if the underlying generation
154    /// token is stale or the device is shutting down.
155    pub fn raw_device_ptr(&self) -> Result<u64, GpuError> {
156        use cudarc::driver::DevicePtr;
157        let slice = self.access()?;
158        let stream = slice.stream();
159        let (ptr, _guard) = slice.device_ptr(stream);
160        // `_guard` is a `SyncOnDrop` whose lifetime ties the pointer to
161        // `slice`; we drop it here. The caller is expected to use the
162        // returned `u64` immediately on an FFI call. The underlying
163        // CudaSlice<T> remains alive via the strong Arc held by
164        // `self.inner.slice` for as long as this `GpuRef` lives.
165        Ok(ptr)
166    }
167}
168
#[cfg(test)]
impl<T> GpuRef<T> {
    /// **Test-only** stub constructor for unit tests that don't need a
    /// real CUDA context.
    ///
    /// The returned `GpuRef<T>` wraps a `CudaSlice<T>` whose bytes are
    /// **uninitialized**. Tests may move, clone and drop the handle, but
    /// must never read through it. In particular, do not call `.access()`,
    /// `.len()`, `.is_empty()`, `.raw_device_ptr()` or `Debug` formatting
    /// (`len`/`Debug` read the slice). Do not dispatch it through a kernel
    /// actor or otherwise let it reach cudarc.
    ///
    /// The slice's `Arc` is given one extra strong reference that is never
    /// released. Its count therefore never reaches zero, and
    /// `Drop for CudaSlice<T>` never runs on the uninitialized payload,
    /// even if the test drops every handle. Leaking the surrounding
    /// container is no longer required, but remains harmless. The backing
    /// allocation is deliberately leaked for the life of the test process.
    ///
    /// The `DeviceState` used to mint the handle is dropped before this
    /// function returns, so the handle's `Weak` no longer upgrades.
    pub(crate) fn for_test_no_gpu_leaked() -> Self {
        use std::mem::MaybeUninit;

        // Allocate through `Arc` itself, not `Box`. `Arc::from_raw`
        // expects the strong/weak refcount header to sit in front of the
        // payload. A `Box` allocation has no such header, so every clone
        // or drop of the forged `Arc` would touch memory outside the
        // allocation.
        let uninit: Arc<MaybeUninit<cudarc::driver::CudaSlice<T>>> =
            Arc::new(MaybeUninit::uninit());
        let raw = Arc::into_raw(uninit).cast::<cudarc::driver::CudaSlice<T>>();
        // SAFETY:
        // - `raw` was returned by `Arc::into_raw` on an
        //   `Arc<MaybeUninit<CudaSlice<T>>>`. `MaybeUninit<U>` has the same
        //   size and alignment as `U`, which is what `Arc::from_raw` and
        //   `Arc::increment_strong_count` require of the pointee type.
        // - The strong count is 1 (held by the raw pointer) when we
        //   increment it, so the allocation is live.
        // - The extra strong reference is never given back. The count
        //   therefore never reaches zero, and the uninitialized
        //   `CudaSlice<T>` is never dropped.
        let arc_slice: Arc<cudarc::driver::CudaSlice<T>> = unsafe {
            Arc::increment_strong_count(raw);
            Arc::from_raw(raw)
        };
        let state = Arc::new(crate::device::DeviceState::new(0));
        Self::new(arc_slice, &state)
    }
}
201
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::DeviceState;

    // These tests cover the `DeviceState` signals that `GpuRef::access`
    // consults: the generation token and `accepting_ops`. A real
    // `CudaSlice` cannot be constructed without a GPU. The end-to-end
    // stale-handle paths of `GpuRef::access` are therefore NOT exercised
    // here.

    #[test]
    fn generation_bump_is_observable() {
        let state = Arc::new(DeviceState::new(0));
        assert_eq!(state.generation(), 0);
        state.bump_generation();
        assert_eq!(state.generation(), 1);
        // Bumping the generation alone does not stop the device accepting work.
        assert!(state.accepting_ops());
    }

    #[test]
    fn begin_shutdown_stops_accepting_ops() {
        let state = Arc::new(DeviceState::new(0));
        assert!(state.accepting_ops());
        state.begin_shutdown();
        assert!(!state.accepting_ops());
    }
}