Skip to main content

atomr_accel_cuda/
dispatcher.rs

//! `GpuDispatcher` (§5.1) — a dispatcher backed by a dedicated OS thread,
//! intended to keep the actor's CUDA context current on one OS thread for
//! the actor's whole lifetime.
//!
//! Tokio's default work-stealing scheduler moves tasks between worker
//! threads, which would break the "context is current on this thread"
//! invariant. This dispatcher instead owns a dedicated OS thread, builds
//! its own Tokio runtime there, and forwards `Dispatcher::spawn_task` to
//! that runtime through a [`DefaultDispatcher`] composed at construction
//! time. The runtime configuration (flavor and thread naming) lives in
//! `GpuDispatcher::new`.
//!
//! Library actors that share a context with their `DeviceActor` should
//! use the same `GpuDispatcher`. F1 wires the dispatcher
//! programmatically; atomr-config integration is deferred to F2.
16
17use std::sync::Arc;
18use std::thread;
19
20use atomr_core::dispatch::{DefaultDispatcher, Dispatcher, DispatcherHandle};
21use futures_util::future::BoxFuture;
22use tokio::sync::oneshot;
23
24pub struct GpuDispatcher {
25    inner: Arc<GpuDispatcherInner>,
26}
27
28struct GpuDispatcherInner {
29    /// Wraps the runtime handle. We compose rather than re-implement so
30    /// we don't have to construct `DispatcherHandle` from outside
31    /// atomr-core (its inner field is `pub(crate)`).
32    delegate: DefaultDispatcher,
33    /// Held to keep the runtime thread alive until drop.
34    _join: Option<thread::JoinHandle<()>>,
35    shutdown_tx: parking_lot::Mutex<Option<oneshot::Sender<()>>>,
36}
37
38impl GpuDispatcher {
39    /// Spawn the dedicated thread and its runtime, returning a ready-to-use
40    /// dispatcher.
41    pub fn new() -> std::io::Result<Self> {
42        let (handle_tx, handle_rx) = std::sync::mpsc::sync_channel(1);
43        let (shutdown_tx, shutdown_rx) = oneshot::channel();
44
45        let join = thread::Builder::new()
46            .name("atomr-accel-cuda-gpu".into())
47            .spawn(move || {
48                // worker_threads(1) — exactly one tokio worker, on this
49                // OS thread (the one we just spawned). All tasks
50                // submitted via the runtime handle land here.
51                let rt = match tokio::runtime::Builder::new_multi_thread()
52                    .worker_threads(1)
53                    .thread_name("atomr-accel-cuda-gpu-worker")
54                    .enable_all()
55                    .build()
56                {
57                    Ok(rt) => rt,
58                    Err(e) => {
59                        let _ = handle_tx.send(Err(e));
60                        return;
61                    }
62                };
63                let _ = handle_tx.send(Ok(rt.handle().clone()));
64                rt.block_on(async move {
65                    let _ = shutdown_rx.await;
66                });
67            })?;
68
69        let rt_handle = match handle_rx.recv() {
70            Ok(Ok(h)) => h,
71            Ok(Err(e)) => return Err(e),
72            Err(_) => {
73                return Err(std::io::Error::other(
74                    "GpuDispatcher thread died before yielding its runtime handle",
75                ));
76            }
77        };
78
79        Ok(Self {
80            inner: Arc::new(GpuDispatcherInner {
81                delegate: DefaultDispatcher::new(rt_handle, 16),
82                _join: Some(join),
83                shutdown_tx: parking_lot::Mutex::new(Some(shutdown_tx)),
84            }),
85        })
86    }
87}
88
89impl Dispatcher for GpuDispatcher {
90    fn spawn_task(&self, task: BoxFuture<'static, ()>) -> DispatcherHandle {
91        self.inner.delegate.spawn_task(task)
92    }
93
94    fn throughput(&self) -> u32 {
95        self.inner.delegate.throughput()
96    }
97}
98
99impl Drop for GpuDispatcherInner {
100    fn drop(&mut self) {
101        if let Some(tx) = self.shutdown_tx.lock().take() {
102            let _ = tx.send(());
103        }
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;
    use std::time::Duration;

    #[test]
    fn pinned_dispatcher_runs_on_dedicated_thread() {
        const TASKS: usize = 3;

        let dispatcher = GpuDispatcher::new().expect("spawn dispatcher");
        let (id_tx, id_rx) = mpsc::channel::<thread::ThreadId>();

        // Each task reports the id of the thread that polled it.
        for _ in 0..TASKS {
            let id_tx = id_tx.clone();
            dispatcher.spawn_task(Box::pin(async move {
                let _ = id_tx.send(thread::current().id());
            }));
        }

        let ids: Vec<thread::ThreadId> = (0..TASKS)
            .map(|_| id_rx.recv_timeout(Duration::from_secs(2)).unwrap())
            .collect();

        // Every task ran on one and the same thread...
        assert!(
            ids.windows(2).all(|w| w[0] == w[1]),
            "tasks ran on different threads: {:?}",
            ids
        );
        // ...and that thread is not the test's own thread.
        assert_ne!(ids[0], thread::current().id());
    }
}