Skip to main content

atomr_accel_cuda/sys/
cuda_driver.rs

//! Thin, panic-safe wrappers around CUDA driver-API (and runtime-API)
//! entry points that cudarc 0.19 only exposes at the `sys` level.
//!
//! Every public wrapper in this module runs the raw `unsafe extern "C"`
//! call inside `std::panic::catch_unwind`, because cudarc's
//! dynamic-loader path panics if `libcuda.so` isn't present at
//! runtime. The wrappers convert any such panic into
//! [`crate::error::GpuError::Unrecoverable`] so kernel actors stay
//! alive on no-GPU hosts.
//!
//! All pointer / handle arguments are forwarded as-is — the caller
//! is responsible for validity.
13
14use cudarc::driver::sys as driver_sys;
15use cudarc::runtime::sys as runtime_sys;
16
17use crate::error::GpuError;
18
/// Library tag placed in `GpuError::LibraryError::lib` by `driver_check`.
const LIB_DRIVER: &str = "driver";
/// Library tag placed in `GpuError::LibraryError::lib` by `runtime_check`.
const LIB_RUNTIME: &str = "runtime";
21
22fn driver_check(s: driver_sys::CUresult, op: &str) -> Result<(), GpuError> {
23    if s == driver_sys::cudaError_enum::CUDA_SUCCESS {
24        Ok(())
25    } else {
26        Err(GpuError::LibraryError {
27            lib: LIB_DRIVER,
28            msg: format!("{op}: {s:?}"),
29        })
30    }
31}
32
33fn runtime_check(s: runtime_sys::cudaError_t, op: &str) -> Result<(), GpuError> {
34    if s == runtime_sys::cudaError::cudaSuccess {
35        Ok(())
36    } else {
37        Err(GpuError::LibraryError {
38            lib: LIB_RUNTIME,
39            msg: format!("{op}: {s:?}"),
40        })
41    }
42}
43
44/// Invoke `f`, mapping any panic from the cudarc dynamic loader into
45/// `Unrecoverable`. Used when `libcuda.so` may not be loadable on the
46/// host (CI, dev laptops without an NVIDIA driver).
47fn guarded<F, R>(op: &'static str, f: F) -> Result<R, GpuError>
48where
49    F: FnOnce() -> Result<R, GpuError>,
50{
51    match std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)) {
52        Ok(r) => r,
53        Err(_) => Err(GpuError::Unrecoverable(format!(
54            "{op}: CUDA driver not loadable"
55        ))),
56    }
57}
58
59// ---------------------------------------------------------------------------
60// cuMemPrefetchAsync (driver-API, used by `memory::prefetch`).
61// ---------------------------------------------------------------------------
62
63/// Prefetch `[dev_ptr .. dev_ptr+count)` to a target memory location
64/// on `stream`. Wraps `cuMemPrefetchAsync_v2` (the v2 shape that
65/// takes a `CUmemLocation`).
66pub fn mem_prefetch_async_v2(
67    dev_ptr: driver_sys::CUdeviceptr,
68    count: usize,
69    location: driver_sys::CUmemLocation,
70    flags: u32,
71    stream: driver_sys::CUstream,
72) -> Result<(), GpuError> {
73    guarded("cuMemPrefetchAsync_v2", || {
74        // SAFETY: pointer + length validity is the caller's contract.
75        let s =
76            unsafe { driver_sys::cuMemPrefetchAsync_v2(dev_ptr, count, location, flags, stream) };
77        driver_check(s, "cuMemPrefetchAsync_v2")
78    })
79}
80
81// ---------------------------------------------------------------------------
82// cuMemAdvise (driver-API, used by `memory::advise`).
83// ---------------------------------------------------------------------------
84
85/// Apply a memory advisory hint to a managed-memory range. Wraps
86/// `cuMemAdvise_v2` (the v2 shape that takes a `CUmemLocation`).
87pub fn mem_advise_v2(
88    dev_ptr: driver_sys::CUdeviceptr,
89    count: usize,
90    advice: driver_sys::CUmem_advise,
91    location: driver_sys::CUmemLocation,
92) -> Result<(), GpuError> {
93    guarded("cuMemAdvise_v2", || {
94        // SAFETY: caller-supplied pointer and length must be valid.
95        let s = unsafe { driver_sys::cuMemAdvise_v2(dev_ptr, count, advice, location) };
96        driver_check(s, "cuMemAdvise_v2")
97    })
98}
99
100// ---------------------------------------------------------------------------
101// cuIpc* (driver-API, used by `memory::ipc` and `event::ipc`).
102// ---------------------------------------------------------------------------
103
/// Export an inter-process handle for the device allocation at
/// `dev_ptr` (`cuIpcGetMemHandle`).
///
/// The handle starts zero-filled and is overwritten by the driver on
/// success. Only compiled with the `cuda-ipc` feature.
#[cfg(feature = "cuda-ipc")]
pub fn ipc_get_mem_handle(
    dev_ptr: driver_sys::CUdeviceptr,
) -> Result<driver_sys::CUipcMemHandle, GpuError> {
    guarded("cuIpcGetMemHandle", || {
        let mut exported = driver_sys::CUipcMemHandle_st { reserved: [0; 64] };
        // SAFETY: `exported` is a live local the driver writes into;
        // `dev_ptr` validity is the caller's contract.
        let status = unsafe { driver_sys::cuIpcGetMemHandle(&mut exported, dev_ptr) };
        driver_check(status, "cuIpcGetMemHandle").map(|()| exported)
    })
}
118
/// Map an allocation exported by another process into this one and
/// return the local device pointer (`cuIpcOpenMemHandle_v2`).
///
/// `flags` is forwarded unchanged. Release the mapping with
/// `ipc_close_mem_handle`. Only compiled with the `cuda-ipc` feature.
#[cfg(feature = "cuda-ipc")]
pub fn ipc_open_mem_handle_v2(
    handle: driver_sys::CUipcMemHandle,
    flags: u32,
) -> Result<driver_sys::CUdeviceptr, GpuError> {
    guarded("cuIpcOpenMemHandle_v2", || {
        let mut mapped: driver_sys::CUdeviceptr = 0;
        // SAFETY: `mapped` is a live local the driver writes into;
        // `handle` is forwarded by value as supplied by the caller.
        let status = unsafe { driver_sys::cuIpcOpenMemHandle_v2(&mut mapped, handle, flags) };
        driver_check(status, "cuIpcOpenMemHandle_v2").and(Ok(mapped))
    })
}
132
/// Unmap a device pointer previously returned by
/// `ipc_open_mem_handle_v2` (`cuIpcCloseMemHandle`).
///
/// Only compiled with the `cuda-ipc` feature.
#[cfg(feature = "cuda-ipc")]
pub fn ipc_close_mem_handle(dev_ptr: driver_sys::CUdeviceptr) -> Result<(), GpuError> {
    const OP: &str = "cuIpcCloseMemHandle";
    guarded(OP, || {
        // SAFETY: caller guarantees `dev_ptr` came from a prior
        // cuIpcOpenMemHandle_v2 and has not been closed yet.
        driver_check(unsafe { driver_sys::cuIpcCloseMemHandle(dev_ptr) }, OP)
    })
}
141
/// Export an inter-process handle for `event` (`cuIpcGetEventHandle`).
///
/// NOTE(review): per the CUDA docs, the event must have been created with
/// `CU_EVENT_INTERPROCESS | CU_EVENT_DISABLE_TIMING`. This wrapper does not
/// check that. Only compiled with the `cuda-ipc` feature.
#[cfg(feature = "cuda-ipc")]
pub fn ipc_get_event_handle(
    event: driver_sys::CUevent,
) -> Result<driver_sys::CUipcEventHandle, GpuError> {
    guarded("cuIpcGetEventHandle", || {
        let mut exported = driver_sys::CUipcEventHandle_st { reserved: [0; 64] };
        let out_ptr: *mut driver_sys::CUipcEventHandle = &mut exported;
        // SAFETY: `out_ptr` points at the live local `exported`; `event`
        // validity is the caller's contract.
        let status = unsafe { driver_sys::cuIpcGetEventHandle(out_ptr, event) };
        driver_check(status, "cuIpcGetEventHandle")?;
        Ok(exported)
    })
}
156
/// Open an event handle exported by another process and return the
/// local `CUevent` (`cuIpcOpenEventHandle`).
///
/// Only compiled with the `cuda-ipc` feature.
#[cfg(feature = "cuda-ipc")]
pub fn ipc_open_event_handle(
    handle: driver_sys::CUipcEventHandle,
) -> Result<driver_sys::CUevent, GpuError> {
    guarded("cuIpcOpenEventHandle", || {
        let mut imported: driver_sys::CUevent = std::ptr::null_mut();
        // SAFETY: `imported` is a live local the driver writes into;
        // `handle` bytes are forwarded as supplied by the caller.
        let status = unsafe { driver_sys::cuIpcOpenEventHandle(&mut imported, handle) };
        driver_check(status, "cuIpcOpenEventHandle").map(|()| imported)
    })
}
169
170// ---------------------------------------------------------------------------
171// cuModule* (driver-API, used by `module`).
172// ---------------------------------------------------------------------------
173
174/// Load a cubin/fatbin/PTX image from a memory buffer. The buffer must
175/// outlive the returned `CUmodule` for the duration of any pending
176/// kernel launch — the driver may keep references to embedded
177/// strings.
178pub fn module_load_data(image: *const std::ffi::c_void) -> Result<driver_sys::CUmodule, GpuError> {
179    guarded("cuModuleLoadData", || {
180        let mut m: driver_sys::CUmodule = std::ptr::null_mut();
181        // SAFETY: out-pointer; image is a caller-owned slice of bytes.
182        let s = unsafe { driver_sys::cuModuleLoadData(&mut m as *mut _, image) };
183        driver_check(s, "cuModuleLoadData")?;
184        Ok(m)
185    })
186}
187
188pub fn module_unload(m: driver_sys::CUmodule) -> Result<(), GpuError> {
189    guarded("cuModuleUnload", || {
190        // SAFETY: m was returned by a prior `cuModuleLoad*`.
191        let s = unsafe { driver_sys::cuModuleUnload(m) };
192        driver_check(s, "cuModuleUnload")
193    })
194}
195
196pub fn module_get_function(
197    m: driver_sys::CUmodule,
198    name: &std::ffi::CStr,
199) -> Result<driver_sys::CUfunction, GpuError> {
200    guarded("cuModuleGetFunction", || {
201        let mut f: driver_sys::CUfunction = std::ptr::null_mut();
202        // SAFETY: out-pointer; name is a caller-owned C string.
203        let s = unsafe { driver_sys::cuModuleGetFunction(&mut f as *mut _, m, name.as_ptr()) };
204        driver_check(s, "cuModuleGetFunction")?;
205        Ok(f)
206    })
207}
208
209// ---------------------------------------------------------------------------
210// cuLaunchKernel / cuLaunchCooperativeKernel (driver-API, used by `module`).
211// ---------------------------------------------------------------------------
212
213#[allow(clippy::too_many_arguments)]
214pub fn launch_kernel(
215    f: driver_sys::CUfunction,
216    grid: (u32, u32, u32),
217    block: (u32, u32, u32),
218    shared_bytes: u32,
219    stream: driver_sys::CUstream,
220    kernel_params: *mut *mut std::ffi::c_void,
221) -> Result<(), GpuError> {
222    guarded("cuLaunchKernel", || {
223        // SAFETY: the kernel-params array's lifetime is the caller's
224        // responsibility; the driver consumes it synchronously.
225        let s = unsafe {
226            driver_sys::cuLaunchKernel(
227                f,
228                grid.0,
229                grid.1,
230                grid.2,
231                block.0,
232                block.1,
233                block.2,
234                shared_bytes,
235                stream,
236                kernel_params,
237                std::ptr::null_mut(),
238            )
239        };
240        driver_check(s, "cuLaunchKernel")
241    })
242}
243
244#[allow(clippy::too_many_arguments)]
245pub fn launch_cooperative_kernel(
246    f: driver_sys::CUfunction,
247    grid: (u32, u32, u32),
248    block: (u32, u32, u32),
249    shared_bytes: u32,
250    stream: driver_sys::CUstream,
251    kernel_params: *mut *mut std::ffi::c_void,
252) -> Result<(), GpuError> {
253    guarded("cuLaunchCooperativeKernel", || {
254        // SAFETY: see `launch_kernel`. Cooperative launches additionally
255        // require the kernel to fit on the device's SM count.
256        let s = unsafe {
257            driver_sys::cuLaunchCooperativeKernel(
258                f,
259                grid.0,
260                grid.1,
261                grid.2,
262                block.0,
263                block.1,
264                block.2,
265                shared_bytes,
266                stream,
267                kernel_params,
268            )
269        };
270        driver_check(s, "cuLaunchCooperativeKernel")
271    })
272}
273
274// ---------------------------------------------------------------------------
275// runtime-API IPC (matches `cudaIpc*`, used as an alternative path on
276// systems where the driver-API bindings aren't available — the v2 shape
277// only ships on CUDA 12+).
278// ---------------------------------------------------------------------------
279
/// Runtime-API counterpart of `ipc_get_mem_handle`
/// (`cudaIpcGetMemHandle`).
///
/// Status codes are checked with `runtime_check`. A panic is still
/// reported through `guarded`, whose message mentions the CUDA driver.
/// Only compiled with the `cuda-ipc` feature.
#[cfg(feature = "cuda-ipc")]
pub fn runtime_ipc_get_mem_handle(
    dev_ptr: *mut std::ffi::c_void,
) -> Result<runtime_sys::cudaIpcMemHandle_t, GpuError> {
    guarded("cudaIpcGetMemHandle", || {
        let mut exported = runtime_sys::cudaIpcMemHandle_st { reserved: [0; 64] };
        // SAFETY: `exported` is a live local the runtime writes into;
        // `dev_ptr` validity is the caller's contract.
        let status = unsafe { runtime_sys::cudaIpcGetMemHandle(&mut exported, dev_ptr) };
        runtime_check(status, "cudaIpcGetMemHandle").map(|()| exported)
    })
}