atomr_accel_cuda/completion/
poll.rs1use std::sync::Arc;
12use std::time::Duration;
13
14use futures_util::future::BoxFuture;
15use futures_util::FutureExt;
16
17use crate::error::GpuError;
18
19use super::CompletionStrategy;
20
/// Completion strategy that detects the end of queued GPU work by polling a CUDA event.
///
/// A fresh event is recorded on the stream and queried repeatedly, sleeping
/// [`interval`](Self::interval) between queries. Polling stops when the event
/// reports completion or [`timeout`](Self::timeout) elapses.
#[derive(Clone, Debug)]
pub struct PolledCompletion {
    /// Delay between successive event queries.
    ///
    /// The polling loop sleeps with `tokio::time::sleep`, which is documented as
    /// millisecond-granular. Sub-millisecond values are therefore effectively rounded up.
    pub interval: Duration,
    /// Maximum time to wait for completion before failing with a timeout error.
    /// `None` waits indefinitely.
    pub timeout: Option<Duration>,
}

impl PolledCompletion {
    /// Poll interval used by [`Default`]: 50 µs.
    pub const DEFAULT_INTERVAL: Duration = Duration::from_micros(50);

    /// Timeout applied by [`PolledCompletion::new`]: 300 s.
    pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(300);

    /// Creates a poller that queries every `interval`.
    ///
    /// The timeout is set to [`Self::DEFAULT_TIMEOUT`] (300 s).
    pub fn new(interval: Duration) -> Self {
        Self {
            interval,
            timeout: Some(Self::DEFAULT_TIMEOUT),
        }
    }

    /// Returns `self` with its timeout replaced. Pass `None` to wait without limit.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Option<Duration>) -> Self {
        self.timeout = timeout;
        self
    }
}

impl Default for PolledCompletion {
    /// Polls every [`PolledCompletion::DEFAULT_INTERVAL`] (50 µs).
    ///
    /// The timeout is [`PolledCompletion::DEFAULT_TIMEOUT`] (300 s).
    fn default() -> Self {
        Self::new(Self::DEFAULT_INTERVAL)
    }
}
44
45impl CompletionStrategy for PolledCompletion {
46 fn await_completion(
47 &self,
48 stream: &Arc<cudarc::driver::CudaStream>,
49 ) -> BoxFuture<'static, Result<(), GpuError>> {
50 let stream = stream.clone();
51 let interval = self.interval;
52 let timeout = self.timeout;
53 async move {
54 let event_res = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
58 stream.record_event(None)
59 }));
60 let event = match event_res {
61 Ok(Ok(e)) => e,
62 Ok(Err(e)) => {
63 return Err(GpuError::LibraryError {
64 lib: "driver",
65 msg: format!("PolledCompletion: record_event: {e}"),
66 });
67 }
68 Err(_) => {
69 return Err(GpuError::Unrecoverable(
70 "PolledCompletion: CUDA driver not loadable".into(),
71 ));
72 }
73 };
74 let started = std::time::Instant::now();
75 loop {
76 let complete =
77 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| event.is_complete()))
78 .unwrap_or(false);
79 if complete {
80 return Ok(());
81 }
82 if let Some(t) = timeout {
83 if started.elapsed() >= t {
84 return Err(GpuError::Timeout);
85 }
86 }
87 tokio::time::sleep(interval).await;
88 }
89 }
90 .boxed()
91 }
92}