atomr_accel_cuda/device/state.rs
1//! `DeviceState` — the shared state that survives `ContextActor` restarts
2//! (§5.11 outer/inner-tier split, §5.8 `GpuRef` validity).
3//!
4//! `Arc<DeviceState>` is held by:
5//! - the outer `DeviceActor` (lifetime: ActorSystem)
6//! - each `ContextActor` incarnation (replaced on restart)
7//! - every live `GpuRef<T>` (via `Weak`)
8//!
//! On context rebuild, [`DeviceState::install_context`] swaps in the new
//! `Arc<CudaContext>`; the generation is bumped separately via
//! [`DeviceState::bump_generation`]. `GpuRef`s minted against the old
//! generation will fail their next `access()`.
12
13use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
14use std::sync::Arc;
15
16use arc_swap::ArcSwapOption;
17use tokio::sync::watch;
18
19#[cfg(feature = "cuda-runtime-tests")]
20type ContextHandle = cudarc::driver::CudaContext;
21
22#[cfg(not(feature = "cuda-runtime-tests"))]
23type ContextHandle = cudarc::driver::CudaContext;
24
25pub struct DeviceState {
26 device_id: u32,
27 /// Bumped on every context rebuild (§5.8). Acquire/Release to pair
28 /// with the `ArcSwap` of `current_ctx`.
29 generation: AtomicU64,
30 /// Set false at the start of `DeviceActor::post_stop`. Outstanding
31 /// `GpuRef::access()` calls will fail fast with `GpuRefStale`.
32 accepting_ops: AtomicBool,
33 /// The current live context, swapped in by `ContextActor::pre_start`.
34 /// `None` between rebuilds and at startup.
35 current_ctx: ArcSwapOption<ContextHandle>,
36 /// Watch channel that publishes the new generation each time
37 /// [`DeviceState::bump_generation`] is called. Allows top-level
38 /// observers (`P2pTopology`, `NcclWorldActor`, `PlacementActor`,
39 /// `ReplayHarness`) to subscribe to context rebuilds without
40 /// polling.
41 generation_tx: watch::Sender<u64>,
42}
43
44impl DeviceState {
45 pub fn new(device_id: u32) -> Self {
46 let (tx, _rx) = watch::channel(0u64);
47 Self {
48 device_id,
49 generation: AtomicU64::new(0),
50 accepting_ops: AtomicBool::new(true),
51 current_ctx: ArcSwapOption::empty(),
52 generation_tx: tx,
53 }
54 }
55
56 pub fn device_id(&self) -> u32 {
57 self.device_id
58 }
59
60 pub fn generation(&self) -> u64 {
61 self.generation.load(Ordering::Acquire)
62 }
63
64 /// Bump the generation. Called by `ContextActor` after building a
65 /// new `CudaContext` and before spawning library children.
66 pub fn bump_generation(&self) -> u64 {
67 // fetch_add returns the previous value; the new generation is
68 // that+1.
69 let new = self.generation.fetch_add(1, Ordering::AcqRel) + 1;
70 // Best-effort publish to subscribers. If no receivers are
71 // attached the send is a no-op.
72 let _ = self.generation_tx.send(new);
73 new
74 }
75
76 /// Subscribe to generation changes. Receivers see every
77 /// [`DeviceState::bump_generation`] call. Used by top-level
78 /// observers that need to react to context rebuilds.
79 pub fn generation_watch(&self) -> watch::Receiver<u64> {
80 self.generation_tx.subscribe()
81 }
82
83 pub fn accepting_ops(&self) -> bool {
84 self.accepting_ops.load(Ordering::Acquire)
85 }
86
87 /// Mark that the `DeviceActor` is winding down. Any subsequent
88 /// `GpuRef::access()` returns `GpuRefStale`.
89 pub fn begin_shutdown(&self) {
90 self.accepting_ops.store(false, Ordering::Release);
91 }
92
93 /// Install a freshly built CUDA context into the shared state.
94 /// Called from `ContextActor::pre_start` (and the post-restart path).
95 pub fn install_context(&self, ctx: Arc<ContextHandle>) {
96 self.current_ctx.store(Some(ctx));
97 }
98
99 /// Drop the current context reference held by the shared state.
100 /// Called from `ContextActor::post_stop` so a poisoned context can be
101 /// torn down before the new incarnation builds its replacement.
102 pub fn clear_context(&self) {
103 self.current_ctx.store(None);
104 }
105
106 /// Snapshot of the current `CudaContext`, if any. `KernelActor`s use
107 /// this in their own `pre_start` to acquire the handle they need.
108 pub fn current_context(&self) -> Option<Arc<ContextHandle>> {
109 self.current_ctx.load_full()
110 }
111}
112
113impl std::fmt::Debug for DeviceState {
114 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115 f.debug_struct("DeviceState")
116 .field("device_id", &self.device_id)
117 .field("generation", &self.generation())
118 .field("accepting_ops", &self.accepting_ops())
119 .field("has_context", &self.current_ctx.load().is_some())
120 .finish()
121 }
122}
123
#[cfg(test)]
mod tests {
    use super::DeviceState;

    /// A fresh state reports generation 0 and each bump returns the next
    /// integer, which `generation()` then reflects.
    #[test]
    fn generation_starts_zero_and_bumps_monotonically() {
        let state = DeviceState::new(0);
        assert_eq!(state.generation(), 0);

        for expected in 1..=2u64 {
            assert_eq!(state.bump_generation(), expected);
        }

        assert_eq!(state.generation(), 2);
    }

    /// `begin_shutdown` turns `accepting_ops` from `true` to `false`.
    #[test]
    fn shutdown_flips_accepting_ops() {
        let state = DeviceState::new(0);
        let before = state.accepting_ops();
        assert!(before);

        state.begin_shutdown();

        let after = state.accepting_ops();
        assert!(!after);
    }
}
144}