atomr_accel_cuda/completion/host_fn.rs
1//! `HostFnCompletion` (§5.10) — callback-based completion detection.
//! After enqueueing a kernel onto a `CudaStream`, register a
//! `cuLaunchHostFunc` callback that runs once all preceding work on the
//! stream has completed. The callback fulfils a
//! `oneshot::Sender<Result<(), GpuError>>`, whose receiver the actor awaits
//! through the future returned by `await_completion`.
7//!
//! The CUDA-side callback context is severely restricted (no CUDA API
//! calls, no blocking, and most locks are off-limits). The trampoline below
//! does only the following:
//!
//! - reclaims the `Box<oneshot::Sender>` from the raw pointer CUDA passes back,
//! - sends `Ok(())` on it,
//!
//! all of which are permitted.
15
16use std::ffi::c_void;
17use std::sync::Arc;
18
19use futures_util::future::BoxFuture;
20use futures_util::FutureExt;
21use tokio::sync::oneshot;
22
23use crate::error::GpuError;
24
25use super::CompletionStrategy;
26
/// Completion strategy that detects stream completion through a
/// `cuLaunchHostFunc` host callback.
///
/// The type carries no state: each call to `await_completion` registers its
/// own callback and creates its own one-shot channel. As a result, copies are
/// interchangeable.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct HostFnCompletion;

impl HostFnCompletion {
    /// Creates the (stateless) host-function completion strategy.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
35
36/// Trampoline invoked by CUDA on the stream callback worker thread. It
37/// reconstructs the boxed `oneshot::Sender` and signals completion.
38///
39/// SAFETY: `data` must have been produced by `Box::into_raw(Box::new(slot))`
40/// in [`HostFnCompletion::await_completion`]; CUDA invokes this exactly
41/// once per `cuLaunchHostFunc` registration.
42unsafe extern "C" fn wake_trampoline(data: *mut c_void) {
43 if data.is_null() {
44 return;
45 }
46 // Reclaim the box and drop, which fulfils the oneshot.
47 let slot: Box<oneshot::Sender<Result<(), GpuError>>> = Box::from_raw(data.cast());
48 let _ = slot.send(Ok(()));
49}
50
51impl CompletionStrategy for HostFnCompletion {
52 fn await_completion(
53 &self,
54 stream: &Arc<cudarc::driver::CudaStream>,
55 ) -> BoxFuture<'static, Result<(), GpuError>> {
56 let stream = stream.clone();
57 let (tx, rx) = oneshot::channel::<Result<(), GpuError>>();
58 let boxed = Box::new(tx);
59 let arg = Box::into_raw(boxed) as *mut c_void;
60
61 // Use the documented `result::launch_host_function` wrapper.
62 // This is unsafe because the callback signature is unconstrained
63 // — we satisfy the requirements by hand: the trampoline does no
64 // CUDA work and does not block.
65 let launch_res = unsafe {
66 cudarc::driver::result::stream::launch_host_function(
67 stream.cu_stream(),
68 wake_trampoline,
69 arg,
70 )
71 };
72
73 if let Err(e) = launch_res {
74 // Reclaim the box so we don't leak the sender; this also
75 // drops the oneshot, so `rx.await` returns a closed-channel
76 // error that we map to a typed failure.
77 unsafe {
78 drop(Box::from_raw(
79 arg as *mut oneshot::Sender<Result<(), GpuError>>,
80 ));
81 }
82 let msg = format!("cuLaunchHostFunc failed: {e}");
83 return async move { Err(GpuError::Driver(msg)) }.boxed();
84 }
85
86 async move {
87 match rx.await {
88 Ok(r) => r,
89 Err(_) => Err(GpuError::Driver(
90 "host-function callback dropped without firing".into(),
91 )),
92 }
93 }
94 .boxed()
95 }
96}