1use std::sync::Arc;
22use std::time::Instant;
23
24use async_trait::async_trait;
25use atomr_core::actor::{Actor, ActorRef, Context, Props};
26use parking_lot::Mutex;
27use tokio::sync::oneshot;
28
29#[cfg(feature = "replay")]
30use atomr_persistence::{Journal, PersistentRepr};
31
/// Operating mode of a [`ReplayHarness`].
///
/// The mode is consulted by the actor when handling `ReplayMsg::Record` and by
/// `ReplayHarness::replay_all`; it can be changed at runtime with
/// `ReplayMsg::SetMode`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ReplayMode {
    /// Recording is disabled: `ReplayMsg::Record` messages are discarded.
    Off,
    /// `ReplayMsg::Record` entries are timestamped and appended to the journal.
    Record,
    /// `ReplayMsg::Record` messages are discarded, and this is the only mode in
    /// which `ReplayHarness::replay_all` delivers journal entries to a sink.
    Replay,
}
38
/// One recorded event in a replay journal.
///
/// With the `replay` feature enabled, entries are serde-serializable so they can
/// be written to and decoded from a persistence backend.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "replay", derive(serde::Serialize, serde::Deserialize))]
pub enum JournalEntry {
    /// A device-level command.
    ///
    /// When recorded through the harness, `ts_micros` is overwritten with the
    /// microseconds elapsed since the harness was created. The meaning of
    /// `name` / `payload` is defined by the producer (NOTE(review): confirm
    /// with callers).
    DeviceCmd {
        ts_micros: u64,
        name: String,
        payload: String,
    },
    /// A kernel-level command; `ts_micros` is stamped by the harness on record,
    /// exactly like [`JournalEntry::DeviceCmd`]. `kind` / `payload` are
    /// producer-defined (NOTE(review): confirm with callers).
    KernelCmd {
        ts_micros: u64,
        kind: String,
        payload: String,
    },
    /// An RNG seed associated with the actor at `actor_path`.
    RngSeed {
        actor_path: String,
        seed: u64,
    },
    /// A batch size associated with the actor at `actor_path`.
    BatchSize {
        actor_path: String,
        size: usize,
    },
}
61
62pub trait ReplaySink: Send + 'static {
67 type Msg: Send + 'static;
68 fn make_on_entry(entry: JournalEntry, reply: oneshot::Sender<()>) -> Self::Msg;
69}
70
71pub enum ReplayMsg {
72 Record(JournalEntry),
73 Snapshot {
74 reply: oneshot::Sender<Vec<JournalEntry>>,
75 },
76 SetMode {
77 mode: ReplayMode,
78 },
79 LoadJournal {
82 entries: Vec<JournalEntry>,
83 reply: oneshot::Sender<()>,
84 },
85 ReplayAll,
89 #[cfg(feature = "replay")]
94 LoadFromJournal {
95 from_sequence_nr: u64,
96 max: u64,
97 reply: oneshot::Sender<Result<usize, String>>,
98 },
99}
100
101pub struct ReplayHarness {
102 mode: ReplayMode,
103 journal: Arc<Mutex<Vec<JournalEntry>>>,
104 started_at: Instant,
105 #[cfg(feature = "replay")]
110 persistence: Option<PersistenceState>,
111}
112
/// Connection to a persistence backend used for write-through recording.
#[cfg(feature = "replay")]
struct PersistenceState {
    /// Identifier under which entries are stored and replayed.
    persistence_id: String,
    /// Last sequence number handed out by `write_to_journal`; `0` means the
    /// counter has not yet been seeded from the backend's highest stored
    /// sequence number.
    next_seq: Arc<Mutex<u64>>,
    /// Backend journal that entries are written to and read back from.
    journal: Arc<dyn Journal>,
}
120
121impl ReplayHarness {
122 pub fn props(mode: ReplayMode) -> Props<Self> {
123 Props::create(move || ReplayHarness {
124 mode: mode.clone(),
125 journal: Arc::new(Mutex::new(Vec::new())),
126 started_at: Instant::now(),
127 #[cfg(feature = "replay")]
128 persistence: None,
129 })
130 }
131
132 #[cfg(feature = "replay")]
139 pub fn with_journal(
140 mode: ReplayMode,
141 journal: Arc<dyn Journal>,
142 persistence_id: impl Into<String>,
143 ) -> Props<Self> {
144 let pid = persistence_id.into();
145 Props::create(move || ReplayHarness {
146 mode: mode.clone(),
147 journal: Arc::new(Mutex::new(Vec::new())),
148 started_at: Instant::now(),
149 persistence: Some(PersistenceState {
150 journal: journal.clone(),
151 persistence_id: pid.clone(),
152 next_seq: Arc::new(Mutex::new(0)),
153 }),
154 })
155 }
156
157 pub fn journal(&self) -> Arc<Mutex<Vec<JournalEntry>>> {
159 self.journal.clone()
160 }
161
162 pub async fn replay_all<F>(&self, mut sink_fn: F)
166 where
167 F: FnMut(JournalEntry, oneshot::Sender<()>),
168 {
169 if !matches!(self.mode, ReplayMode::Replay) {
170 return;
171 }
172 let entries = self.journal.lock().clone();
173 for entry in entries {
174 let (tx, rx) = oneshot::channel::<()>();
175 sink_fn(entry, tx);
176 let _ = rx.await;
177 }
178 }
179}
180
181pub fn replay_via_sink<S: ReplaySink>(
184 sink: ActorRef<S::Msg>,
185) -> impl FnMut(JournalEntry, oneshot::Sender<()>) {
186 move |entry, reply| {
187 sink.tell(S::make_on_entry(entry, reply));
188 }
189}
190
191#[async_trait]
192impl Actor for ReplayHarness {
193 type Msg = ReplayMsg;
194
195 async fn handle(&mut self, _ctx: &mut Context<Self>, msg: ReplayMsg) {
196 match msg {
197 ReplayMsg::Record(mut entry) => {
198 if matches!(self.mode, ReplayMode::Record) {
199 let ts = self.started_at.elapsed().as_micros() as u64;
200 if let JournalEntry::DeviceCmd { ts_micros, .. }
201 | JournalEntry::KernelCmd { ts_micros, .. } = &mut entry
202 {
203 *ts_micros = ts;
204 }
205 self.journal.lock().push(entry.clone());
206 #[cfg(feature = "replay")]
207 if let Some(p) = &self.persistence {
208 match write_to_journal(p, &entry).await {
209 Ok(()) => {}
210 Err(e) => {
211 tracing::warn!(
212 error = %e,
213 persistence_id = %p.persistence_id,
214 "ReplayHarness: persistence write failed"
215 );
216 }
217 }
218 }
219 }
220 }
221 ReplayMsg::Snapshot { reply } => {
222 let _ = reply.send(self.journal.lock().clone());
223 }
224 ReplayMsg::SetMode { mode } => {
225 self.mode = mode;
226 }
227 ReplayMsg::LoadJournal { entries, reply } => {
228 *self.journal.lock() = entries;
229 let _ = reply.send(());
230 }
231 ReplayMsg::ReplayAll => {
232 if !matches!(self.mode, ReplayMode::Replay) {
237 return;
238 }
239 let entries = self.journal.lock().clone();
240 for _ in entries {
241 }
246 }
247 #[cfg(feature = "replay")]
248 ReplayMsg::LoadFromJournal {
249 from_sequence_nr,
250 max,
251 reply,
252 } => {
253 let p = match &self.persistence {
254 Some(p) => p,
255 None => {
256 let _ = reply.send(Err("no persistence backend attached".into()));
257 return;
258 }
259 };
260 match p
261 .journal
262 .replay_messages(&p.persistence_id, from_sequence_nr, u64::MAX, max)
263 .await
264 {
265 Ok(reprs) => {
266 let mut decoded = Vec::with_capacity(reprs.len());
267 for r in &reprs {
268 match serde_json::from_slice::<JournalEntry>(&r.payload) {
269 Ok(e) => decoded.push(e),
270 Err(e) => {
271 let _ = reply
272 .send(Err(format!("decode seq={}: {e}", r.sequence_nr)));
273 return;
274 }
275 }
276 }
277 let n = decoded.len();
278 *self.journal.lock() = decoded;
279 let _ = reply.send(Ok(n));
280 }
281 Err(e) => {
282 let _ = reply.send(Err(format!("journal replay failed: {e}")));
283 }
284 }
285 }
286 }
287 }
288}
289
/// Serializes `entry` as JSON and appends it to the attached backend under the
/// next sequence number for `p.persistence_id`.
///
/// The sequence counter is seeded lazily: while `next_seq` is still `0`, the
/// backend's highest stored sequence number is fetched and used as the
/// starting point. The counter is advanced *before* the write, so a failed
/// write still consumes its sequence number.
///
/// # Errors
///
/// Returns a descriptive `String` if serialization, the highest-sequence
/// lookup, or the backend write fails.
#[cfg(feature = "replay")]
async fn write_to_journal(p: &PersistenceState, entry: &JournalEntry) -> Result<(), String> {
    let payload = match serde_json::to_vec(entry) {
        Ok(bytes) => bytes,
        Err(e) => return Err(format!("serde: {e}")),
    };

    // Each counter access gets its own scope so no lock guard is held across
    // an `.await`.
    let unseeded = {
        let counter = p.next_seq.lock();
        *counter == 0
    };
    if unseeded {
        let highest = match p.journal.highest_sequence_nr(&p.persistence_id, 0).await {
            Ok(n) => n,
            Err(e) => return Err(format!("highest_seq: {e}")),
        };
        let mut counter = p.next_seq.lock();
        if *counter == 0 {
            *counter = highest;
        }
    }

    let sequence_nr = {
        let mut counter = p.next_seq.lock();
        *counter += 1;
        *counter
    };

    let record = PersistentRepr {
        persistence_id: p.persistence_id.clone(),
        sequence_nr,
        payload,
        manifest: "atomr_accel_cuda::replay::JournalEntry".into(),
        writer_uuid: "atomr-accel-cuda".into(),
        deleted: false,
        tags: Vec::new(),
    };
    let batch = vec![record];
    p.journal
        .write_messages(batch)
        .await
        .map_err(|e| format!("write_messages: {e}"))
}
327
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_config::Config;
    use atomr_core::actor::ActorSystem;
    use std::time::Duration;

    /// Shorthand for an `RngSeed` entry.
    fn rng_seed(actor_path: &str, seed: u64) -> JournalEntry {
        JournalEntry::RngSeed {
            actor_path: actor_path.into(),
            seed,
        }
    }

    /// Builds a harness directly, bypassing the actor system, so that
    /// `replay_all` can be exercised without a running mailbox.
    fn harness(mode: ReplayMode, entries: Vec<JournalEntry>) -> ReplayHarness {
        ReplayHarness {
            mode,
            journal: Arc::new(Mutex::new(entries)),
            started_at: Instant::now(),
            #[cfg(feature = "replay")]
            persistence: None,
        }
    }

    /// Extracts the seed of every entry, panicking on any other variant.
    fn seeds(entries: &[JournalEntry]) -> Vec<u64> {
        entries
            .iter()
            .map(|e| match e {
                JournalEntry::RngSeed { seed, .. } => *seed,
                other => panic!("unexpected journal entry: {other:?}"),
            })
            .collect()
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn record_then_snapshot() {
        let sys = ActorSystem::create("replay-test", Config::empty())
            .await
            .unwrap();
        let actor = sys
            .actor_of(ReplayHarness::props(ReplayMode::Record), "replay")
            .unwrap();

        actor.tell(ReplayMsg::Record(rng_seed("test/rng", 42)));
        tokio::time::sleep(Duration::from_millis(50)).await;

        let (tx, rx) = oneshot::channel();
        actor.tell(ReplayMsg::Snapshot { reply: tx });
        let entries = tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(seeds(&entries), vec![42]);

        sys.terminate().await;
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn off_mode_drops_records() {
        let sys = ActorSystem::create("replay-off", Config::empty())
            .await
            .unwrap();
        let actor = sys
            .actor_of(ReplayHarness::props(ReplayMode::Off), "replay")
            .unwrap();

        actor.tell(ReplayMsg::Record(rng_seed("test", 1)));
        tokio::time::sleep(Duration::from_millis(50)).await;

        let (tx, rx) = oneshot::channel();
        actor.tell(ReplayMsg::Snapshot { reply: tx });
        let entries = tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(entries.len(), 0);

        sys.terminate().await;
    }

    /// `LoadJournal` replaces the in-memory journal; a following `Snapshot`
    /// returns exactly the loaded entries in order.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn load_journal_then_snapshot() {
        let sys = ActorSystem::create("replay-load", Config::empty())
            .await
            .unwrap();
        let actor = sys
            .actor_of(ReplayHarness::props(ReplayMode::Replay), "replay")
            .unwrap();

        let journal = vec![rng_seed("a", 1), rng_seed("b", 2)];
        let (tx, rx) = oneshot::channel();
        actor.tell(ReplayMsg::LoadJournal {
            entries: journal.clone(),
            reply: tx,
        });
        tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .unwrap()
            .unwrap();

        let (tx_done, rx_done) = oneshot::channel::<Vec<JournalEntry>>();
        actor.tell(ReplayMsg::Snapshot { reply: tx_done });
        let snap = tokio::time::timeout(Duration::from_secs(2), rx_done)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(seeds(&snap), vec![1, 2]);

        sys.terminate().await;
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn replay_all_feeds_sink_in_order() {
        let h = harness(ReplayMode::Replay, vec![rng_seed("a", 1), rng_seed("b", 2)]);
        let mut delivered = Vec::new();
        h.replay_all(|entry, ack| {
            delivered.push(entry);
            let _ = ack.send(());
        })
        .await;
        assert_eq!(seeds(&delivered), vec![1, 2]);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn replay_all_is_noop_outside_replay_mode() {
        let h = harness(ReplayMode::Record, vec![rng_seed("a", 1)]);
        let mut calls = 0usize;
        h.replay_all(|_, _| calls += 1).await;
        assert_eq!(calls, 0);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn replay_all_tolerates_dropped_acks() {
        let h = harness(ReplayMode::Replay, vec![rng_seed("a", 1), rng_seed("b", 2)]);
        let mut calls = 0usize;
        tokio::time::timeout(Duration::from_secs(2), h.replay_all(|_, _| calls += 1))
            .await
            .expect("replay_all must not hang when the sink drops its ack");
        assert_eq!(calls, 2);
    }
}