atomr_accel_cuda/gpu_ref.rs
1//! `GpuRef<T>` — opaque, message-friendly handle to a GPU buffer (§5.8).
2//!
3//! `Send + Sync + 'static` with no lifetime parameters, so it composes
4//! freely in any actor message type. Validity is checked at runtime by
5//! comparing a generation token against `DeviceState.generation`, which
6//! is bumped whenever the underlying `CudaContext` is rebuilt (§5.11
7//! supervision).
8//!
9//! Cross-node serialisation (`GpuToken { node_id, device_id, buffer_id,
10//! generation }` — §5.5) is intentionally **not** implemented in F1; it
11//! lands with the F4 cluster/NCCL story. F1 `GpuRef` is local-only.
12
13use std::sync::{Arc, Weak};
14
15use arc_swap::ArcSwapOption;
16
17use crate::device::DeviceState;
18use crate::error::GpuError;
19
20/// A live device-buffer handle.
21///
22/// Holds a strong `Arc` to the slice (keeping the underlying memory
23/// alive even if the `DeviceActor` has begun shutdown) plus a `Weak` to
24/// the surrounding `DeviceState` (so reference cycles cannot trap the
25/// system in a non-terminating state). Calling [`GpuRef::access`] before
26/// each use validates that the context generation has not advanced.
27pub struct GpuRef<T> {
28 inner: Arc<GpuRefInner<T>>,
29}
30
31struct GpuRefInner<T> {
32 /// Strong-keep on the device buffer.
33 slice: Arc<cudarc::driver::CudaSlice<T>>,
34 /// `DeviceState.generation` at construction time.
35 generation: u64,
36 /// Weak reference back to the device state. Avoids a cycle —
37 /// `DeviceActor` owns the strong `Arc<DeviceState>`.
38 state: Weak<DeviceState>,
39 /// The most recent `CudaStream` that wrote to this buffer. Library
40 /// actors call [`GpuRef::record_write`] after enqueueing a kernel
41 /// that mutates the slice. Cross-stream consumers (`P2pTopology`,
42 /// pipeline stages) read this to inject a `CudaEvent` wait without
43 /// a host roundtrip.
44 last_write_stream: ArcSwapOption<cudarc::driver::CudaStream>,
45}
46
47impl<T> Clone for GpuRef<T> {
48 fn clone(&self) -> Self {
49 Self {
50 inner: self.inner.clone(),
51 }
52 }
53}
54
55impl<T> std::fmt::Debug for GpuRef<T> {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 f.debug_struct("GpuRef")
58 .field("generation", &self.inner.generation)
59 .field("len", &self.inner.slice.len())
60 .finish()
61 }
62}
63
64impl<T> GpuRef<T> {
65 /// Wrap a raw `Arc<CudaSlice<T>>` produced by a `DeviceActor` into a
66 /// `GpuRef<T>`.
67 ///
68 /// Only `DeviceActor` (and code reachable from its dispatcher) should
69 /// call this — outside callers must obtain `GpuRef`s by asking the
70 /// `DeviceActor` to allocate.
71 pub fn new(slice: Arc<cudarc::driver::CudaSlice<T>>, state: &Arc<DeviceState>) -> Self {
72 let generation = state.generation();
73 Self {
74 inner: Arc::new(GpuRefInner {
75 slice,
76 generation,
77 state: Arc::downgrade(state),
78 last_write_stream: ArcSwapOption::empty(),
79 }),
80 }
81 }
82
83 /// Validate the reference and return access to the underlying slice.
84 ///
85 /// Returns [`GpuError::GpuRefStale`] if either:
86 /// - the owning `DeviceState` has been dropped,
87 /// - the device is no longer accepting operations, or
88 /// - the context generation has advanced past the one this ref was
89 /// minted with (i.e. a poisoned-context rebuild has happened).
90 pub fn access(&self) -> Result<&Arc<cudarc::driver::CudaSlice<T>>, GpuError> {
91 let state = self
92 .inner
93 .state
94 .upgrade()
95 .ok_or(GpuError::GpuRefStale("device state dropped"))?;
96 if !state.accepting_ops() {
97 return Err(GpuError::GpuRefStale("device shutting down"));
98 }
99 if state.generation() != self.inner.generation {
100 return Err(GpuError::GpuRefStale("context rebuilt"));
101 }
102 Ok(&self.inner.slice)
103 }
104
105 /// Generation token at construction. Exposed for tests.
106 pub fn generation(&self) -> u64 {
107 self.inner.generation
108 }
109
110 /// Length in elements of the underlying slice.
111 pub fn len(&self) -> usize {
112 self.inner.slice.len()
113 }
114
115 pub fn is_empty(&self) -> bool {
116 self.inner.slice.is_empty()
117 }
118
119 /// Device id this `GpuRef` was minted on, or `None` if the owning
120 /// [`DeviceState`] has been dropped.
121 pub fn device_id(&self) -> Option<u32> {
122 self.inner.state.upgrade().map(|s| s.device_id())
123 }
124
125 /// Record the stream that most recently wrote to this buffer.
126 /// Library actors (BlasActor, CudnnActor, FftActor, etc.) call this
127 /// after enqueueing a kernel that mutates the slice so that
128 /// downstream consumers can inject a cross-stream wait.
129 pub fn record_write(&self, stream: &Arc<cudarc::driver::CudaStream>) {
130 self.inner.last_write_stream.store(Some(stream.clone()));
131 }
132
133 /// Most recent producing stream, if any. Returns `None` when no
134 /// kernel has been recorded against this buffer.
135 pub fn last_write_stream(&self) -> Option<Arc<cudarc::driver::CudaStream>> {
136 self.inner.last_write_stream.load_full()
137 }
138
139 /// Phase 4.5++ — opaque `CUdeviceptr` (`u64`) for downstream
140 /// raw-pointer FFI APIs (TensorRT `enqueueV3`, `cuStreamWriteValue64`,
141 /// custom CUDA modules that aren't fronted by cudarc).
142 ///
143 /// Validates the `GpuRef` first via [`GpuRef::access`]. The pointer
144 /// is captured against the slice's own associated stream — the
145 /// `_guard` returned by cudarc's `device_ptr()` is dropped before
146 /// the function returns, but the underlying allocation outlives
147 /// this call because the inner `Arc<CudaSlice<T>>` is held by
148 /// `self`. Callers must ensure they don't dispatch the resulting
149 /// pointer on a stream that has already gone out of scope; in
150 /// practice the pointer is consumed immediately by an FFI shim
151 /// (TensorRT enqueueV3, etc.) on a stream the caller owns.
152 ///
153 /// Returns [`GpuError::GpuRefStale`] if the underlying generation
154 /// token is stale or the device is shutting down.
155 pub fn raw_device_ptr(&self) -> Result<u64, GpuError> {
156 use cudarc::driver::DevicePtr;
157 let slice = self.access()?;
158 let stream = slice.stream();
159 let (ptr, _guard) = slice.device_ptr(stream);
160 // `_guard` is a `SyncOnDrop` whose lifetime ties the pointer to
161 // `slice`; we drop it here. The caller is expected to use the
162 // returned `u64` immediately on an FFI call. The underlying
163 // CudaSlice<T> remains alive via the strong Arc held by
164 // `self.inner.slice` for as long as this `GpuRef` lives.
165 Ok(ptr)
166 }
167}
168
#[cfg(test)]
impl<T> GpuRef<T> {
    /// **Test-only** stub constructor for unit tests that don't need a
    /// real CUDA context.
    ///
    /// The returned `GpuRef<T>` wraps a `CudaSlice<T>` whose memory is
    /// **uninitialized**. A test must never let that slice be read:
    /// do not dispatch the handle through a kernel actor, and do not
    /// call [`GpuRef::len`], [`GpuRef::is_empty`],
    /// [`GpuRef::raw_device_ptr`] or format the handle with `Debug`,
    /// since all of them reach into the slice.
    ///
    /// The `DeviceState` created here is dropped before the function
    /// returns, so [`GpuRef::access`] reports
    /// `GpuRefStale("device state dropped")` without touching the slice
    /// and [`GpuRef::device_id`] returns `None`.
    ///
    /// # Safety contract
    ///
    /// cudarc's `Drop for CudaSlice<T>` must never run on the
    /// uninitialized value. This function enforces that itself by
    /// leaking one strong reference, so the strong count can never
    /// reach zero. The slice allocation is therefore leaked on purpose
    /// on every call. Leaking the surrounding container as well (e.g.
    /// via `Box::leak`) remains harmless but is no longer required for
    /// soundness of the drop path.
    pub(crate) fn for_test_no_gpu_leaked() -> Self {
        use std::mem::MaybeUninit;

        // The allocation has to be made *by `Arc`*: `Arc::from_raw` is
        // only valid for pointers that came out of `Arc::into_raw`,
        // because the reference counts live in front of the payload.
        // A pointer obtained from `Box::into_raw` has no such header.
        let uninit: Arc<MaybeUninit<cudarc::driver::CudaSlice<T>>> =
            Arc::new(MaybeUninit::uninit());
        let raw = Arc::into_raw(uninit).cast::<cudarc::driver::CudaSlice<T>>();

        // SAFETY: `raw` was produced by `Arc::into_raw` just above and
        // is consumed exactly once. `MaybeUninit<U>` is guaranteed to
        // have the same size and alignment as `U`, which is what
        // `Arc::from_raw` requires when the pointee type is changed.
        // The payload is uninitialized; the contract documented on this
        // function forbids reading it, and the leaked strong count
        // below prevents it from ever being dropped.
        let arc_slice: Arc<cudarc::driver::CudaSlice<T>> = unsafe { Arc::from_raw(raw) };

        // Pin the strong count at >= 1 forever so `Drop for CudaSlice`
        // can never observe the uninitialized bytes.
        std::mem::forget(Arc::clone(&arc_slice));

        let state = Arc::new(crate::device::DeviceState::new(0));
        Self::new(arc_slice, &state)
    }
}
201
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::DeviceState;

    /// Exercises the two `DeviceState` signals that `GpuRef::access`
    /// consults: the generation counter and the accepting-ops flag.
    ///
    /// No `GpuRef` is built here, because a real `CudaSlice` needs a
    /// GPU; the test never touches CUDA memory.
    #[test]
    fn generation_mismatch_fails_validate() {
        let device = Arc::new(DeviceState::new(0));

        // A fresh state starts at generation 0 and a bump moves it to 1.
        let initial = device.generation();
        assert_eq!(initial, 0);
        device.bump_generation();
        let bumped = device.generation();
        assert_eq!(bumped, 1);

        // Operations are accepted until shutdown begins.
        let open_before_shutdown = device.accepting_ops();
        assert!(open_before_shutdown);
        device.begin_shutdown();
        let open_after_shutdown = device.accepting_ops();
        assert!(!open_after_shutdown);
    }
}