Skip to main content

atomr_accel_cuda/device/
state.rs

1//! `DeviceState` — the shared state that survives `ContextActor` restarts
2//! (§5.11 outer/inner-tier split, §5.8 `GpuRef` validity).
3//!
4//! `Arc<DeviceState>` is held by:
5//! - the outer `DeviceActor` (lifetime: ActorSystem)
6//! - each `ContextActor` incarnation (replaced on restart)
7//! - every live `GpuRef<T>` (via `Weak`)
8//!
//! On context rebuild, [`DeviceState::install_context`] swaps in the new
//! `Arc<CudaContext>`; the generation counter is advanced separately via
//! [`DeviceState::bump_generation`]. `GpuRef`s minted against the old
//! generation will fail their next `access()`.
12
13use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
14use std::sync::Arc;
15
16use arc_swap::ArcSwapOption;
17use tokio::sync::watch;
18
19#[cfg(feature = "cuda-runtime-tests")]
20type ContextHandle = cudarc::driver::CudaContext;
21
22#[cfg(not(feature = "cuda-runtime-tests"))]
23type ContextHandle = cudarc::driver::CudaContext;
24
25pub struct DeviceState {
26    device_id: u32,
27    /// Bumped on every context rebuild (§5.8). Acquire/Release to pair
28    /// with the `ArcSwap` of `current_ctx`.
29    generation: AtomicU64,
30    /// Set false at the start of `DeviceActor::post_stop`. Outstanding
31    /// `GpuRef::access()` calls will fail fast with `GpuRefStale`.
32    accepting_ops: AtomicBool,
33    /// The current live context, swapped in by `ContextActor::pre_start`.
34    /// `None` between rebuilds and at startup.
35    current_ctx: ArcSwapOption<ContextHandle>,
36    /// Watch channel that publishes the new generation each time
37    /// [`DeviceState::bump_generation`] is called. Allows top-level
38    /// observers (`P2pTopology`, `NcclWorldActor`, `PlacementActor`,
39    /// `ReplayHarness`) to subscribe to context rebuilds without
40    /// polling.
41    generation_tx: watch::Sender<u64>,
42}
43
44impl DeviceState {
45    pub fn new(device_id: u32) -> Self {
46        let (tx, _rx) = watch::channel(0u64);
47        Self {
48            device_id,
49            generation: AtomicU64::new(0),
50            accepting_ops: AtomicBool::new(true),
51            current_ctx: ArcSwapOption::empty(),
52            generation_tx: tx,
53        }
54    }
55
56    pub fn device_id(&self) -> u32 {
57        self.device_id
58    }
59
60    pub fn generation(&self) -> u64 {
61        self.generation.load(Ordering::Acquire)
62    }
63
64    /// Bump the generation. Called by `ContextActor` after building a
65    /// new `CudaContext` and before spawning library children.
66    pub fn bump_generation(&self) -> u64 {
67        // fetch_add returns the previous value; the new generation is
68        // that+1.
69        let new = self.generation.fetch_add(1, Ordering::AcqRel) + 1;
70        // Best-effort publish to subscribers. If no receivers are
71        // attached the send is a no-op.
72        let _ = self.generation_tx.send(new);
73        new
74    }
75
76    /// Subscribe to generation changes. Receivers see every
77    /// [`DeviceState::bump_generation`] call. Used by top-level
78    /// observers that need to react to context rebuilds.
79    pub fn generation_watch(&self) -> watch::Receiver<u64> {
80        self.generation_tx.subscribe()
81    }
82
83    pub fn accepting_ops(&self) -> bool {
84        self.accepting_ops.load(Ordering::Acquire)
85    }
86
87    /// Mark that the `DeviceActor` is winding down. Any subsequent
88    /// `GpuRef::access()` returns `GpuRefStale`.
89    pub fn begin_shutdown(&self) {
90        self.accepting_ops.store(false, Ordering::Release);
91    }
92
93    /// Install a freshly built CUDA context into the shared state.
94    /// Called from `ContextActor::pre_start` (and the post-restart path).
95    pub fn install_context(&self, ctx: Arc<ContextHandle>) {
96        self.current_ctx.store(Some(ctx));
97    }
98
99    /// Drop the current context reference held by the shared state.
100    /// Called from `ContextActor::post_stop` so a poisoned context can be
101    /// torn down before the new incarnation builds its replacement.
102    pub fn clear_context(&self) {
103        self.current_ctx.store(None);
104    }
105
106    /// Snapshot of the current `CudaContext`, if any. `KernelActor`s use
107    /// this in their own `pre_start` to acquire the handle they need.
108    pub fn current_context(&self) -> Option<Arc<ContextHandle>> {
109        self.current_ctx.load_full()
110    }
111}
112
113impl std::fmt::Debug for DeviceState {
114    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115        f.debug_struct("DeviceState")
116            .field("device_id", &self.device_id)
117            .field("generation", &self.generation())
118            .field("accepting_ops", &self.accepting_ops())
119            .field("has_context", &self.current_ctx.load().is_some())
120            .finish()
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh state reports generation 0, and each bump returns the next
    /// consecutive value.
    #[test]
    fn generation_starts_zero_and_bumps_monotonically() {
        let state = DeviceState::new(0);
        assert_eq!(state.generation(), 0);
        for expected in 1..=2u64 {
            assert_eq!(state.bump_generation(), expected);
        }
        assert_eq!(state.generation(), 2);
    }

    /// `begin_shutdown` turns the accepting-ops flag from true to false.
    #[test]
    fn shutdown_flips_accepting_ops() {
        let state = DeviceState::new(0);
        assert!(state.accepting_ops());
        state.begin_shutdown();
        assert!(!state.accepting_ops());
    }
}