Skip to main content

atomr_accel_cuda/completion/
host_fn.rs

//! `HostFnCompletion` (§5.10) — callback-based completion detection.
//!
//! After enqueueing a kernel onto a `CudaStream`, register a
//! `cuLaunchHostFunc` callback that runs after all preceding work on the
//! stream completes. The callback fulfils a `oneshot::Sender<Result>`,
//! whose receiver is awaited by the actor's `await_completion` future.
//!
//! The CUDA-side callback context is severely restricted: no CUDA API
//! calls, no synchronization that could depend on outstanding CUDA work,
//! and no blocking. The trampoline below therefore does only the following:
//!
//! - reconstruct a `Box<oneshot::Sender>` from a raw pointer,
//! - send `Ok(())` on it.
//!
//! Both operations are permitted in that context.
15
16use std::ffi::c_void;
17use std::sync::Arc;
18
19use futures_util::future::BoxFuture;
20use futures_util::FutureExt;
21use tokio::sync::oneshot;
22
23use crate::error::GpuError;
24
25use super::CompletionStrategy;
26
/// Completion strategy that detects when enqueued stream work has finished
/// by registering a `cuLaunchHostFunc` callback on the stream.
///
/// The type carries no state, so it is `Copy` and all instances are
/// interchangeable.
#[derive(Clone, Copy, Debug, Default)]
pub struct HostFnCompletion;

impl HostFnCompletion {
    /// Creates a new (stateless) host-function completion strategy.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
35
36/// Trampoline invoked by CUDA on the stream callback worker thread. It
37/// reconstructs the boxed `oneshot::Sender` and signals completion.
38///
39/// SAFETY: `data` must have been produced by `Box::into_raw(Box::new(slot))`
40/// in [`HostFnCompletion::await_completion`]; CUDA invokes this exactly
41/// once per `cuLaunchHostFunc` registration.
42unsafe extern "C" fn wake_trampoline(data: *mut c_void) {
43    if data.is_null() {
44        return;
45    }
46    // Reclaim the box and drop, which fulfils the oneshot.
47    let slot: Box<oneshot::Sender<Result<(), GpuError>>> = Box::from_raw(data.cast());
48    let _ = slot.send(Ok(()));
49}
50
51impl CompletionStrategy for HostFnCompletion {
52    fn await_completion(
53        &self,
54        stream: &Arc<cudarc::driver::CudaStream>,
55    ) -> BoxFuture<'static, Result<(), GpuError>> {
56        let stream = stream.clone();
57        let (tx, rx) = oneshot::channel::<Result<(), GpuError>>();
58        let boxed = Box::new(tx);
59        let arg = Box::into_raw(boxed) as *mut c_void;
60
61        // Use the documented `result::launch_host_function` wrapper.
62        // This is unsafe because the callback signature is unconstrained
63        // — we satisfy the requirements by hand: the trampoline does no
64        // CUDA work and does not block.
65        let launch_res = unsafe {
66            cudarc::driver::result::stream::launch_host_function(
67                stream.cu_stream(),
68                wake_trampoline,
69                arg,
70            )
71        };
72
73        if let Err(e) = launch_res {
74            // Reclaim the box so we don't leak the sender; this also
75            // drops the oneshot, so `rx.await` returns a closed-channel
76            // error that we map to a typed failure.
77            unsafe {
78                drop(Box::from_raw(
79                    arg as *mut oneshot::Sender<Result<(), GpuError>>,
80                ));
81            }
82            let msg = format!("cuLaunchHostFunc failed: {e}");
83            return async move { Err(GpuError::Driver(msg)) }.boxed();
84        }
85
86        async move {
87            match rx.await {
88                Ok(r) => r,
89                Err(_) => Err(GpuError::Driver(
90                    "host-function callback dropped without firing".into(),
91                )),
92            }
93        }
94        .boxed()
95    }
96}