atomr_accel_cuda/
dispatcher.rs1use std::sync::Arc;
18use std::thread;
19
20use atomr_core::dispatch::{DefaultDispatcher, Dispatcher, DispatcherHandle};
21use futures_util::future::BoxFuture;
22use tokio::sync::oneshot;
23
24pub struct GpuDispatcher {
25 inner: Arc<GpuDispatcherInner>,
26}
27
28struct GpuDispatcherInner {
29 delegate: DefaultDispatcher,
33 _join: Option<thread::JoinHandle<()>>,
35 shutdown_tx: parking_lot::Mutex<Option<oneshot::Sender<()>>>,
36}
37
38impl GpuDispatcher {
39 pub fn new() -> std::io::Result<Self> {
42 let (handle_tx, handle_rx) = std::sync::mpsc::sync_channel(1);
43 let (shutdown_tx, shutdown_rx) = oneshot::channel();
44
45 let join = thread::Builder::new()
46 .name("atomr-accel-cuda-gpu".into())
47 .spawn(move || {
48 let rt = match tokio::runtime::Builder::new_multi_thread()
52 .worker_threads(1)
53 .thread_name("atomr-accel-cuda-gpu-worker")
54 .enable_all()
55 .build()
56 {
57 Ok(rt) => rt,
58 Err(e) => {
59 let _ = handle_tx.send(Err(e));
60 return;
61 }
62 };
63 let _ = handle_tx.send(Ok(rt.handle().clone()));
64 rt.block_on(async move {
65 let _ = shutdown_rx.await;
66 });
67 })?;
68
69 let rt_handle = match handle_rx.recv() {
70 Ok(Ok(h)) => h,
71 Ok(Err(e)) => return Err(e),
72 Err(_) => {
73 return Err(std::io::Error::other(
74 "GpuDispatcher thread died before yielding its runtime handle",
75 ));
76 }
77 };
78
79 Ok(Self {
80 inner: Arc::new(GpuDispatcherInner {
81 delegate: DefaultDispatcher::new(rt_handle, 16),
82 _join: Some(join),
83 shutdown_tx: parking_lot::Mutex::new(Some(shutdown_tx)),
84 }),
85 })
86 }
87}
88
89impl Dispatcher for GpuDispatcher {
90 fn spawn_task(&self, task: BoxFuture<'static, ()>) -> DispatcherHandle {
91 self.inner.delegate.spawn_task(task)
92 }
93
94 fn throughput(&self) -> u32 {
95 self.inner.delegate.throughput()
96 }
97}
98
99impl Drop for GpuDispatcherInner {
100 fn drop(&mut self) {
101 if let Some(tx) = self.shutdown_tx.lock().take() {
102 let _ = tx.send(());
103 }
104 }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Spawns three tasks and checks that they all report the same thread
    /// id, and that this id is not the id of the test thread.
    #[test]
    fn pinned_dispatcher_runs_on_dedicated_thread() {
        let d = GpuDispatcher::new().expect("spawn dispatcher");
        let (tx, rx) = std::sync::mpsc::channel::<thread::ThreadId>();

        (0..3).for_each(|_| {
            let reporter = tx.clone();
            d.spawn_task(Box::pin(async move {
                let _ = reporter.send(thread::current().id());
            }));
        });

        // Bounded wait so a stuck dispatcher fails the test instead of
        // hanging it.
        let ids: Vec<thread::ThreadId> = (0..3)
            .map(|_| rx.recv_timeout(Duration::from_secs(2)).unwrap())
            .collect();

        let first = ids[0];
        assert!(
            ids.iter().all(|id| *id == first),
            "tasks ran on different threads: {:?}",
            ids
        );
        assert_ne!(first, thread::current().id());
    }
}