// File: atomr_accel_cuda/replay/mod.rs

//! Deterministic-replay harness.
//!
//! - **Record** mode: every `ReplayMsg::Record(JournalEntry)` is appended
//!   to the in-memory journal. `DeviceCmd` / `KernelCmd` entries are
//!   stamped with the microseconds elapsed since the harness was created.
//! - **Replay** mode: fresh `Record` events are ignored. Load the journal
//!   to replay with `ReplayMsg::LoadJournal` (or, with the `replay`
//!   feature, `ReplayMsg::LoadFromJournal`), then drive it with
//!   [`ReplayHarness::replay_all`]. That method hands each entry to a
//!   user-supplied sink (see [`replay_via_sink`]) and waits for the sink
//!   to acknowledge before moving on. The sink is user-supplied so it can
//!   dispatch into the live actor system.
//! - **Off** mode: every `Record` event is dropped.
//!
//! Storage is an in-memory `Vec<JournalEntry>` by default. With the
//! `replay` cargo feature enabled, [`ReplayHarness::with_journal`]
//! attaches an [`atomr_persistence::Journal`] backend (e.g. the
//! in-memory `InMemoryJournal` from
//! `atomr-persistence-query-inmemory` for tests, or an SQL/Redis
//! provider in production). When a backend is attached, every recorded
//! entry is also written to it as a [`PersistentRepr`], and
//! `ReplayMsg::LoadFromJournal` reads history back as `JournalEntry`s.
20
21use std::sync::Arc;
22use std::time::Instant;
23
24use async_trait::async_trait;
25use atomr_core::actor::{Actor, ActorRef, Context, Props};
26use parking_lot::Mutex;
27use tokio::sync::oneshot;
28
29#[cfg(feature = "replay")]
30use atomr_persistence::{Journal, PersistentRepr};
31
/// Operating mode of a [`ReplayHarness`]; switchable at runtime via
/// `ReplayMsg::SetMode`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReplayMode {
    /// Every `Record` event is dropped.
    Off,
    /// `Record` events are appended to the journal (and mirrored to the
    /// persistence backend, if one is attached).
    Record,
    /// Fresh `Record` events are ignored; the loaded journal is the
    /// source for `ReplayHarness::replay_all`.
    Replay,
}
38
/// A single recorded event.
///
/// When recorded via `ReplayMsg::Record` in `ReplayMode::Record`, the
/// `ts_micros` of `DeviceCmd` / `KernelCmd` is overwritten with the
/// microseconds elapsed since the harness was created. `RngSeed` and
/// `BatchSize` are stored exactly as sent. The harness never interprets
/// the string fields.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "replay", derive(serde::Serialize, serde::Deserialize))]
pub enum JournalEntry {
    /// A device-level command. `ts_micros` is stamped by the harness.
    DeviceCmd {
        ts_micros: u64,
        name: String,
        payload: String,
    },
    /// A kernel-level command. `ts_micros` is stamped by the harness.
    KernelCmd {
        ts_micros: u64,
        kind: String,
        payload: String,
    },
    /// A seed associated with the actor at `actor_path` (presumably its
    /// RNG seed — confirm with recorders).
    RngSeed {
        actor_path: String,
        seed: u64,
    },
    /// A batch size associated with the actor at `actor_path`.
    BatchSize {
        actor_path: String,
        size: usize,
    },
}
61
62/// Trait the user implements to consume replayed entries. The actor
63/// receives one `OnEntry { entry }` message per replay event; the
64/// reply lets the harness pace the replay (next entry waits for the
65/// sink's reply).
66pub trait ReplaySink: Send + 'static {
67    type Msg: Send + 'static;
68    fn make_on_entry(entry: JournalEntry, reply: oneshot::Sender<()>) -> Self::Msg;
69}
70
71pub enum ReplayMsg {
72    Record(JournalEntry),
73    Snapshot {
74        reply: oneshot::Sender<Vec<JournalEntry>>,
75    },
76    SetMode {
77        mode: ReplayMode,
78    },
79    /// Load a previously-recorded journal as the replay source.
80    /// Use before sending `ReplayAll`.
81    LoadJournal {
82        entries: Vec<JournalEntry>,
83        reply: oneshot::Sender<()>,
84    },
85    /// Stream the loaded journal through the sink. Replies after
86    /// every entry has been acknowledged. Only valid in
87    /// `ReplayMode::Replay`.
88    ReplayAll,
89    /// Pull history from the attached persistence backend and load
90    /// it into `self.journal` for subsequent `ReplayAll`. Returns
91    /// the number of entries loaded. Only available with the
92    /// `replay` cargo feature.
93    #[cfg(feature = "replay")]
94    LoadFromJournal {
95        from_sequence_nr: u64,
96        max: u64,
97        reply: oneshot::Sender<Result<usize, String>>,
98    },
99}
100
101pub struct ReplayHarness {
102    mode: ReplayMode,
103    journal: Arc<Mutex<Vec<JournalEntry>>>,
104    started_at: Instant,
105    /// Persistence-backed journal. Populated by [`Self::with_journal`]
106    /// behind the `replay` feature. When set, `Record` round-trips
107    /// through the journal in addition to appending to the in-memory
108    /// snapshot.
109    #[cfg(feature = "replay")]
110    persistence: Option<PersistenceState>,
111}
112
/// Connection details used to mirror recorded entries into an
/// `atomr-persistence` journal.
#[cfg(feature = "replay")]
struct PersistenceState {
    /// Backend that entries are written to and read back from.
    journal: Arc<dyn Journal>,
    /// Persistence id every entry is written under.
    persistence_id: String,
    /// Last sequence number written by this harness.
    ///
    /// `0` means the counter has not been synced with the backend yet.
    /// `write_to_journal` seeds it from `highest_sequence_nr` and
    /// increments it before each write, so the next write uses
    /// `*next_seq + 1`.
    next_seq: Arc<Mutex<u64>>,
}
120
121impl ReplayHarness {
122    pub fn props(mode: ReplayMode) -> Props<Self> {
123        Props::create(move || ReplayHarness {
124            mode: mode.clone(),
125            journal: Arc::new(Mutex::new(Vec::new())),
126            started_at: Instant::now(),
127            #[cfg(feature = "replay")]
128            persistence: None,
129        })
130    }
131
132    /// Build a harness whose `Record` events are mirrored to a
133    /// `atomr-persistence` Journal under `persistence_id`. The
134    /// journal contract requires sequence numbers start at 1; this
135    /// harness initializes its counter to whatever
136    /// `journal.highest_sequence_nr(persistence_id, 0)` returns at
137    /// `pre_start`-time + 1.
138    #[cfg(feature = "replay")]
139    pub fn with_journal(
140        mode: ReplayMode,
141        journal: Arc<dyn Journal>,
142        persistence_id: impl Into<String>,
143    ) -> Props<Self> {
144        let pid = persistence_id.into();
145        Props::create(move || ReplayHarness {
146            mode: mode.clone(),
147            journal: Arc::new(Mutex::new(Vec::new())),
148            started_at: Instant::now(),
149            persistence: Some(PersistenceState {
150                journal: journal.clone(),
151                persistence_id: pid.clone(),
152                next_seq: Arc::new(Mutex::new(0)),
153            }),
154        })
155    }
156
157    /// Test/diagnostic snapshot — bypasses the mailbox.
158    pub fn journal(&self) -> Arc<Mutex<Vec<JournalEntry>>> {
159        self.journal.clone()
160    }
161
162    /// Drive a replay through `sink_fn`. Call after `LoadJournal`
163    /// while in `ReplayMode::Replay`. The closure is invoked once
164    /// per entry; the harness awaits each reply before advancing.
165    pub async fn replay_all<F>(&self, mut sink_fn: F)
166    where
167        F: FnMut(JournalEntry, oneshot::Sender<()>),
168    {
169        if !matches!(self.mode, ReplayMode::Replay) {
170            return;
171        }
172        let entries = self.journal.lock().clone();
173        for entry in entries {
174            let (tx, rx) = oneshot::channel::<()>();
175            sink_fn(entry, tx);
176            let _ = rx.await;
177        }
178    }
179}
180
181/// Convenience type-erased wrapper that bridges the typed
182/// `ReplaySink` trait to a closure-based `replay_all` call.
183pub fn replay_via_sink<S: ReplaySink>(
184    sink: ActorRef<S::Msg>,
185) -> impl FnMut(JournalEntry, oneshot::Sender<()>) {
186    move |entry, reply| {
187        sink.tell(S::make_on_entry(entry, reply));
188    }
189}
190
191#[async_trait]
192impl Actor for ReplayHarness {
193    type Msg = ReplayMsg;
194
195    async fn handle(&mut self, _ctx: &mut Context<Self>, msg: ReplayMsg) {
196        match msg {
197            ReplayMsg::Record(mut entry) => {
198                if matches!(self.mode, ReplayMode::Record) {
199                    let ts = self.started_at.elapsed().as_micros() as u64;
200                    if let JournalEntry::DeviceCmd { ts_micros, .. }
201                    | JournalEntry::KernelCmd { ts_micros, .. } = &mut entry
202                    {
203                        *ts_micros = ts;
204                    }
205                    self.journal.lock().push(entry.clone());
206                    #[cfg(feature = "replay")]
207                    if let Some(p) = &self.persistence {
208                        match write_to_journal(p, &entry).await {
209                            Ok(()) => {}
210                            Err(e) => {
211                                tracing::warn!(
212                                    error = %e,
213                                    persistence_id = %p.persistence_id,
214                                    "ReplayHarness: persistence write failed"
215                                );
216                            }
217                        }
218                    }
219                }
220            }
221            ReplayMsg::Snapshot { reply } => {
222                let _ = reply.send(self.journal.lock().clone());
223            }
224            ReplayMsg::SetMode { mode } => {
225                self.mode = mode;
226            }
227            ReplayMsg::LoadJournal { entries, reply } => {
228                *self.journal.lock() = entries;
229                let _ = reply.send(());
230            }
231            ReplayMsg::ReplayAll => {
232                // Drive the replay synchronously inside the actor —
233                // this blocks the mailbox while replaying. Users who
234                // want non-blocking replay drive `replay_all` from
235                // application code via `replay_via_sink`.
236                if !matches!(self.mode, ReplayMode::Replay) {
237                    return;
238                }
239                let entries = self.journal.lock().clone();
240                for _ in entries {
241                    // Without a sink ref to deliver to, the actor
242                    // can only acknowledge that replay was attempted.
243                    // The full sink dispatch is exercised by
244                    // `replay_via_sink` from application code.
245                }
246            }
247            #[cfg(feature = "replay")]
248            ReplayMsg::LoadFromJournal {
249                from_sequence_nr,
250                max,
251                reply,
252            } => {
253                let p = match &self.persistence {
254                    Some(p) => p,
255                    None => {
256                        let _ = reply.send(Err("no persistence backend attached".into()));
257                        return;
258                    }
259                };
260                match p
261                    .journal
262                    .replay_messages(&p.persistence_id, from_sequence_nr, u64::MAX, max)
263                    .await
264                {
265                    Ok(reprs) => {
266                        let mut decoded = Vec::with_capacity(reprs.len());
267                        for r in &reprs {
268                            match serde_json::from_slice::<JournalEntry>(&r.payload) {
269                                Ok(e) => decoded.push(e),
270                                Err(e) => {
271                                    let _ = reply
272                                        .send(Err(format!("decode seq={}: {e}", r.sequence_nr)));
273                                    return;
274                                }
275                            }
276                        }
277                        let n = decoded.len();
278                        *self.journal.lock() = decoded;
279                        let _ = reply.send(Ok(n));
280                    }
281                    Err(e) => {
282                        let _ = reply.send(Err(format!("journal replay failed: {e}")));
283                    }
284                }
285            }
286        }
287    }
288}
289
/// Write `entry` to the attached backend as a [`PersistentRepr`].
///
/// `p.next_seq` holds the last sequence number this harness wrote, where
/// `0` means "not yet synced with the backend". On first use it is seeded
/// from `highest_sequence_nr`; every write then claims the next number.
///
/// # Errors
///
/// Returns a human-readable message if serialization, the
/// highest-sequence lookup, or the write itself fails. After a failed
/// write the counter is reset to `0`, so the next call re-syncs from the
/// backend. Otherwise a failed write would leave a gap (write never
/// landed) or reuse a number (write landed but reported failure).
#[cfg(feature = "replay")]
async fn write_to_journal(p: &PersistenceState, entry: &JournalEntry) -> Result<(), String> {
    let payload = serde_json::to_vec(entry).map_err(|e| format!("serde: {e}"))?;
    // parking_lot guards are !Send and must not live across an `.await`:
    // peek under the lock, await the lookup unlocked, then re-check.
    let needs_init = { *p.next_seq.lock() == 0 };
    if needs_init {
        let highest = p
            .journal
            .highest_sequence_nr(&p.persistence_id, 0)
            .await
            .map_err(|e| format!("highest_seq: {e}"))?;
        let mut s = p.next_seq.lock();
        if *s == 0 {
            *s = highest;
        }
    }
    let seq = {
        let mut s = p.next_seq.lock();
        *s += 1;
        *s
    };
    let repr = PersistentRepr {
        persistence_id: p.persistence_id.clone(),
        sequence_nr: seq,
        payload,
        manifest: "atomr_accel_cuda::replay::JournalEntry".into(),
        writer_uuid: "atomr-accel-cuda".into(),
        deleted: false,
        tags: Vec::new(),
    };
    let result = p
        .journal
        .write_messages(vec![repr])
        .await
        .map_err(|e| format!("write_messages: {e}"));
    if result.is_err() {
        // We cannot tell whether the write landed; force a re-sync from
        // `highest_sequence_nr` on the next write instead of guessing.
        *p.next_seq.lock() = 0;
    }
    result
}
327
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_config::Config;
    use atomr_core::actor::ActorSystem;
    use std::time::Duration;

    /// Two-entry journal shared by the load/replay tests.
    fn sample_journal() -> Vec<JournalEntry> {
        vec![
            JournalEntry::RngSeed {
                actor_path: "a".into(),
                seed: 1,
            },
            JournalEntry::RngSeed {
                actor_path: "b".into(),
                seed: 2,
            },
        ]
    }

    /// Build a harness value directly (no actor system) so that
    /// `ReplayHarness::replay_all` can be driven from a test.
    fn harness_with(mode: ReplayMode, entries: Vec<JournalEntry>) -> ReplayHarness {
        ReplayHarness {
            mode,
            journal: Arc::new(Mutex::new(entries)),
            started_at: Instant::now(),
            #[cfg(feature = "replay")]
            persistence: None,
        }
    }

    /// Minimal sink actor: stores every replayed entry and acknowledges it
    /// immediately.
    struct CollectSink {
        seen: Arc<Mutex<Vec<JournalEntry>>>,
    }

    enum SinkMsg {
        OnEntry(JournalEntry, oneshot::Sender<()>),
    }

    #[async_trait]
    impl Actor for CollectSink {
        type Msg = SinkMsg;

        async fn handle(&mut self, _ctx: &mut Context<Self>, msg: SinkMsg) {
            let SinkMsg::OnEntry(entry, reply) = msg;
            self.seen.lock().push(entry);
            let _ = reply.send(());
        }
    }

    impl ReplaySink for CollectSink {
        type Msg = SinkMsg;

        fn make_on_entry(entry: JournalEntry, reply: oneshot::Sender<()>) -> SinkMsg {
            SinkMsg::OnEntry(entry, reply)
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn record_then_snapshot() {
        let sys = ActorSystem::create("replay-test", Config::empty())
            .await
            .unwrap();
        let actor = sys
            .actor_of(ReplayHarness::props(ReplayMode::Record), "replay")
            .unwrap();

        actor.tell(ReplayMsg::Record(JournalEntry::RngSeed {
            actor_path: "test/rng".into(),
            seed: 42,
        }));
        tokio::time::sleep(Duration::from_millis(50)).await;

        let (tx, rx) = oneshot::channel();
        actor.tell(ReplayMsg::Snapshot { reply: tx });
        let entries = tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(entries.len(), 1);
        assert!(matches!(&entries[0], JournalEntry::RngSeed { seed: 42, .. }));

        sys.terminate().await;
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn off_mode_drops_records() {
        let sys = ActorSystem::create("replay-off", Config::empty())
            .await
            .unwrap();
        let actor = sys
            .actor_of(ReplayHarness::props(ReplayMode::Off), "replay")
            .unwrap();

        actor.tell(ReplayMsg::Record(JournalEntry::RngSeed {
            actor_path: "test".into(),
            seed: 1,
        }));
        tokio::time::sleep(Duration::from_millis(50)).await;

        let (tx, rx) = oneshot::channel();
        actor.tell(ReplayMsg::Snapshot { reply: tx });
        let entries = tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(entries.len(), 0);

        sys.terminate().await;
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn replay_mode_ignores_records() {
        let sys = ActorSystem::create("replay-ignore", Config::empty())
            .await
            .unwrap();
        let actor = sys
            .actor_of(ReplayHarness::props(ReplayMode::Replay), "replay")
            .unwrap();

        actor.tell(ReplayMsg::Record(JournalEntry::RngSeed {
            actor_path: "test".into(),
            seed: 9,
        }));
        tokio::time::sleep(Duration::from_millis(50)).await;

        let (tx, rx) = oneshot::channel();
        actor.tell(ReplayMsg::Snapshot { reply: tx });
        let entries = tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .unwrap()
            .unwrap();
        assert!(entries.is_empty());

        sys.terminate().await;
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn load_journal_then_snapshot() {
        // Load a small journal into a Replay-mode harness and read it back
        // through the mailbox with `Snapshot`.
        let sys = ActorSystem::create("replay-load", Config::empty())
            .await
            .unwrap();
        let actor = sys
            .actor_of(ReplayHarness::props(ReplayMode::Replay), "replay")
            .unwrap();

        let (tx, rx) = oneshot::channel();
        actor.tell(ReplayMsg::LoadJournal {
            entries: sample_journal(),
            reply: tx,
        });
        tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .unwrap()
            .unwrap();

        let (tx_done, rx_done) = oneshot::channel::<Vec<JournalEntry>>();
        actor.tell(ReplayMsg::Snapshot { reply: tx_done });
        let snap = tokio::time::timeout(Duration::from_secs(2), rx_done)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(snap.len(), 2);
        match (&snap[0], &snap[1]) {
            (JournalEntry::RngSeed { seed: s0, .. }, JournalEntry::RngSeed { seed: s1, .. }) => {
                assert_eq!(*s0, 1);
                assert_eq!(*s1, 2);
            }
            _ => panic!("unexpected journal contents"),
        }

        sys.terminate().await;
    }

    #[tokio::test]
    async fn replay_all_feeds_sink_in_order() {
        let harness = harness_with(ReplayMode::Replay, sample_journal());
        let mut seen = Vec::new();
        harness
            .replay_all(|entry, reply| {
                seen.push(entry);
                let _ = reply.send(());
            })
            .await;
        assert_eq!(seen.len(), 2);
        assert!(matches!(&seen[0], JournalEntry::RngSeed { seed: 1, .. }));
        assert!(matches!(&seen[1], JournalEntry::RngSeed { seed: 2, .. }));
    }

    #[tokio::test]
    async fn replay_all_is_noop_outside_replay_mode() {
        for mode in [ReplayMode::Off, ReplayMode::Record] {
            let harness = harness_with(mode, sample_journal());
            let mut calls = 0usize;
            harness
                .replay_all(|_, reply| {
                    calls += 1;
                    let _ = reply.send(());
                })
                .await;
            assert_eq!(calls, 0);
        }
    }

    #[tokio::test]
    async fn replay_all_advances_when_reply_is_dropped() {
        let harness = harness_with(ReplayMode::Replay, sample_journal());
        let mut calls = 0usize;
        let run = harness.replay_all(|_, reply| {
            calls += 1;
            drop(reply);
        });
        tokio::time::timeout(Duration::from_secs(2), run)
            .await
            .unwrap();
        assert_eq!(calls, 2);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn replay_via_sink_delivers_to_actor() {
        let sys = ActorSystem::create("replay-sink", Config::empty())
            .await
            .unwrap();
        let seen = Arc::new(Mutex::new(Vec::new()));
        let seen_in_actor = seen.clone();
        let sink = sys
            .actor_of(
                Props::create(move || CollectSink {
                    seen: seen_in_actor.clone(),
                }),
                "sink",
            )
            .unwrap();

        let harness = harness_with(ReplayMode::Replay, sample_journal());
        tokio::time::timeout(
            Duration::from_secs(2),
            harness.replay_all(replay_via_sink::<CollectSink>(sink)),
        )
        .await
        .unwrap();

        let delivered = seen.lock().clone();
        assert_eq!(delivered.len(), 2);
        assert!(matches!(&delivered[0], JournalEntry::RngSeed { seed: 1, .. }));
        assert!(matches!(&delivered[1], JournalEntry::RngSeed { seed: 2, .. }));

        sys.terminate().await;
    }
}