Skip to main content

atomr_accel_cuda/completion/
poll.rs

1//! `PolledCompletion` (§5.10) — periodic `cuEventQuery`-style polling.
2//!
3//! Useful where `cuLaunchHostFunc` (used by [`super::HostFnCompletion`])
4//! is forbidden by deployment policy. Trade-off: every outstanding
5//! kernel costs one tokio task waking on a timer.
6//!
7//! Implementation: `await_completion` records a `CudaEvent` on the
8//! supplied stream, then drives a tokio sleep loop calling
9//! `event.is_complete()` at `interval` until it returns true.
10
11use std::sync::Arc;
12use std::time::Duration;
13
14use futures_util::future::BoxFuture;
15use futures_util::FutureExt;
16
17use crate::error::GpuError;
18
19use super::CompletionStrategy;
20
/// Completion strategy that waits for stream work by repeatedly
/// querying a recorded CUDA event on a timer.
#[derive(Clone, Debug)]
pub struct PolledCompletion {
    /// Delay between two consecutive completion checks.
    pub interval: Duration,
    /// Upper bound on the total time spent waiting; `None` disables
    /// the bound. The cap exists so that a stuck driver cannot keep
    /// the poll loop running forever. [`PolledCompletion::new`] sets
    /// it to five minutes.
    pub timeout: Option<Duration>,
}
29
30impl PolledCompletion {
31    pub fn new(interval: Duration) -> Self {
32        Self {
33            interval,
34            timeout: Some(Duration::from_secs(300)),
35        }
36    }
37}
38
39impl Default for PolledCompletion {
40    fn default() -> Self {
41        Self::new(Duration::from_micros(50))
42    }
43}
44
45impl CompletionStrategy for PolledCompletion {
46    fn await_completion(
47        &self,
48        stream: &Arc<cudarc::driver::CudaStream>,
49    ) -> BoxFuture<'static, Result<(), GpuError>> {
50        let stream = stream.clone();
51        let interval = self.interval;
52        let timeout = self.timeout;
53        async move {
54            // Record an event after all currently-queued work on the
55            // stream. Catch panics from the FFI loader on no-driver
56            // hosts and surface them as a typed error.
57            let event_res = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
58                stream.record_event(None)
59            }));
60            let event = match event_res {
61                Ok(Ok(e)) => e,
62                Ok(Err(e)) => {
63                    return Err(GpuError::LibraryError {
64                        lib: "driver",
65                        msg: format!("PolledCompletion: record_event: {e}"),
66                    });
67                }
68                Err(_) => {
69                    return Err(GpuError::Unrecoverable(
70                        "PolledCompletion: CUDA driver not loadable".into(),
71                    ));
72                }
73            };
74            let started = std::time::Instant::now();
75            loop {
76                let complete =
77                    std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| event.is_complete()))
78                        .unwrap_or(false);
79                if complete {
80                    return Ok(());
81                }
82                if let Some(t) = timeout {
83                    if started.elapsed() >= t {
84                        return Err(GpuError::Timeout);
85                    }
86                }
87                tokio::time::sleep(interval).await;
88            }
89        }
90        .boxed()
91    }
92}