1use cudarc::driver::sys as driver_sys;
15use cudarc::runtime::sys as runtime_sys;
16
17use crate::error::GpuError;
18
/// Library tag stored in `GpuError::LibraryError::lib` when a CUDA driver API call fails.
const LIB_DRIVER: &str = "driver";
/// Library tag stored in `GpuError::LibraryError::lib` when a CUDA runtime API call fails.
const LIB_RUNTIME: &str = "runtime";
21
22fn driver_check(s: driver_sys::CUresult, op: &str) -> Result<(), GpuError> {
23 if s == driver_sys::cudaError_enum::CUDA_SUCCESS {
24 Ok(())
25 } else {
26 Err(GpuError::LibraryError {
27 lib: LIB_DRIVER,
28 msg: format!("{op}: {s:?}"),
29 })
30 }
31}
32
33fn runtime_check(s: runtime_sys::cudaError_t, op: &str) -> Result<(), GpuError> {
34 if s == runtime_sys::cudaError::cudaSuccess {
35 Ok(())
36 } else {
37 Err(GpuError::LibraryError {
38 lib: LIB_RUNTIME,
39 msg: format!("{op}: {s:?}"),
40 })
41 }
42}
43
44fn guarded<F, R>(op: &'static str, f: F) -> Result<R, GpuError>
48where
49 F: FnOnce() -> Result<R, GpuError>,
50{
51 match std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)) {
52 Ok(r) => r,
53 Err(_) => Err(GpuError::Unrecoverable(format!(
54 "{op}: CUDA driver not loadable"
55 ))),
56 }
57}
58
59pub fn mem_prefetch_async_v2(
67 dev_ptr: driver_sys::CUdeviceptr,
68 count: usize,
69 location: driver_sys::CUmemLocation,
70 flags: u32,
71 stream: driver_sys::CUstream,
72) -> Result<(), GpuError> {
73 guarded("cuMemPrefetchAsync_v2", || {
74 let s =
76 unsafe { driver_sys::cuMemPrefetchAsync_v2(dev_ptr, count, location, flags, stream) };
77 driver_check(s, "cuMemPrefetchAsync_v2")
78 })
79}
80
81pub fn mem_advise_v2(
88 dev_ptr: driver_sys::CUdeviceptr,
89 count: usize,
90 advice: driver_sys::CUmem_advise,
91 location: driver_sys::CUmemLocation,
92) -> Result<(), GpuError> {
93 guarded("cuMemAdvise_v2", || {
94 let s = unsafe { driver_sys::cuMemAdvise_v2(dev_ptr, count, advice, location) };
96 driver_check(s, "cuMemAdvise_v2")
97 })
98}
99
#[cfg(feature = "cuda-ipc")]
/// Calls the CUDA driver API `cuIpcGetMemHandle` for `dev_ptr` and returns the handle it fills in.
///
/// The handle starts zero-initialised and is returned only when the driver
/// reports success.
///
/// # Errors
/// - A non-success `CUresult` becomes `GpuError::LibraryError` (via `driver_check`).
/// - A panic during the call becomes `GpuError::Unrecoverable` (via `guarded`).
pub fn ipc_get_mem_handle(
    dev_ptr: driver_sys::CUdeviceptr,
) -> Result<driver_sys::CUipcMemHandle, GpuError> {
    const OP: &str = "cuIpcGetMemHandle";
    guarded(OP, || {
        let mut out = driver_sys::CUipcMemHandle_st { reserved: [0; 64] };
        // SAFETY: `out` is a live, writable local for the whole call;
        // `&mut out` coerces to the `*mut` the driver expects.
        let status = unsafe { driver_sys::cuIpcGetMemHandle(&mut out, dev_ptr) };
        driver_check(status, OP).map(|()| out)
    })
}
118
#[cfg(feature = "cuda-ipc")]
/// Calls the CUDA driver API `cuIpcOpenMemHandle_v2` and returns the device pointer it writes.
///
/// The output pointer starts at `0` and is returned only when the driver
/// reports success.
///
/// # Errors
/// - A non-success `CUresult` becomes `GpuError::LibraryError` (via `driver_check`).
/// - A panic during the call becomes `GpuError::Unrecoverable` (via `guarded`).
pub fn ipc_open_mem_handle_v2(
    handle: driver_sys::CUipcMemHandle,
    flags: u32,
) -> Result<driver_sys::CUdeviceptr, GpuError> {
    const OP: &str = "cuIpcOpenMemHandle_v2";
    guarded(OP, move || {
        let mut opened: driver_sys::CUdeviceptr = 0;
        // SAFETY: `opened` is a live, writable local for the whole call;
        // `handle` and `flags` are passed by value.
        let status = unsafe { driver_sys::cuIpcOpenMemHandle_v2(&mut opened, handle, flags) };
        driver_check(status, OP).map(|()| opened)
    })
}
132
#[cfg(feature = "cuda-ipc")]
/// Calls the CUDA driver API `cuIpcCloseMemHandle` on `dev_ptr`.
///
/// # Errors
/// - A non-success `CUresult` becomes `GpuError::LibraryError` (via `driver_check`).
/// - A panic during the call becomes `GpuError::Unrecoverable` (via `guarded`).
///
/// NOTE(review): safe `fn`; presumably `dev_ptr` must come from
/// `ipc_open_mem_handle_v2` — confirm with callers.
pub fn ipc_close_mem_handle(dev_ptr: driver_sys::CUdeviceptr) -> Result<(), GpuError> {
    const OP: &str = "cuIpcCloseMemHandle";
    guarded(OP, move || {
        // SAFETY: `dev_ptr` is passed by value; its validity is the caller's
        // responsibility.
        let rc = unsafe { driver_sys::cuIpcCloseMemHandle(dev_ptr) };
        driver_check(rc, OP)
    })
}
141
#[cfg(feature = "cuda-ipc")]
/// Calls the CUDA driver API `cuIpcGetEventHandle` for `event` and returns the handle it fills in.
///
/// The handle starts zero-initialised and is returned only when the driver
/// reports success.
///
/// # Errors
/// - A non-success `CUresult` becomes `GpuError::LibraryError` (via `driver_check`).
/// - A panic during the call becomes `GpuError::Unrecoverable` (via `guarded`).
pub fn ipc_get_event_handle(
    event: driver_sys::CUevent,
) -> Result<driver_sys::CUipcEventHandle, GpuError> {
    const OP: &str = "cuIpcGetEventHandle";
    guarded(OP, move || {
        let mut out = driver_sys::CUipcEventHandle_st { reserved: [0; 64] };
        // SAFETY: `out` is a live, writable local for the whole call; `event`
        // is passed by value and its validity is the caller's responsibility.
        let status = unsafe { driver_sys::cuIpcGetEventHandle(&mut out, event) };
        driver_check(status, OP).map(|()| out)
    })
}
156
#[cfg(feature = "cuda-ipc")]
/// Calls the CUDA driver API `cuIpcOpenEventHandle` and returns the event it writes.
///
/// The output event starts as a null pointer and is returned only when the
/// driver reports success.
///
/// # Errors
/// - A non-success `CUresult` becomes `GpuError::LibraryError` (via `driver_check`).
/// - A panic during the call becomes `GpuError::Unrecoverable` (via `guarded`).
pub fn ipc_open_event_handle(
    handle: driver_sys::CUipcEventHandle,
) -> Result<driver_sys::CUevent, GpuError> {
    const OP: &str = "cuIpcOpenEventHandle";
    guarded(OP, move || {
        let mut opened: driver_sys::CUevent = std::ptr::null_mut();
        // SAFETY: `opened` is a live, writable local for the whole call;
        // `handle` is passed by value.
        let status = unsafe { driver_sys::cuIpcOpenEventHandle(&mut opened, handle) };
        driver_check(status, OP)?;
        Ok(opened)
    })
}
169
170pub fn module_load_data(image: *const std::ffi::c_void) -> Result<driver_sys::CUmodule, GpuError> {
179 guarded("cuModuleLoadData", || {
180 let mut m: driver_sys::CUmodule = std::ptr::null_mut();
181 let s = unsafe { driver_sys::cuModuleLoadData(&mut m as *mut _, image) };
183 driver_check(s, "cuModuleLoadData")?;
184 Ok(m)
185 })
186}
187
188pub fn module_unload(m: driver_sys::CUmodule) -> Result<(), GpuError> {
189 guarded("cuModuleUnload", || {
190 let s = unsafe { driver_sys::cuModuleUnload(m) };
192 driver_check(s, "cuModuleUnload")
193 })
194}
195
196pub fn module_get_function(
197 m: driver_sys::CUmodule,
198 name: &std::ffi::CStr,
199) -> Result<driver_sys::CUfunction, GpuError> {
200 guarded("cuModuleGetFunction", || {
201 let mut f: driver_sys::CUfunction = std::ptr::null_mut();
202 let s = unsafe { driver_sys::cuModuleGetFunction(&mut f as *mut _, m, name.as_ptr()) };
204 driver_check(s, "cuModuleGetFunction")?;
205 Ok(f)
206 })
207}
208
209#[allow(clippy::too_many_arguments)]
214pub fn launch_kernel(
215 f: driver_sys::CUfunction,
216 grid: (u32, u32, u32),
217 block: (u32, u32, u32),
218 shared_bytes: u32,
219 stream: driver_sys::CUstream,
220 kernel_params: *mut *mut std::ffi::c_void,
221) -> Result<(), GpuError> {
222 guarded("cuLaunchKernel", || {
223 let s = unsafe {
226 driver_sys::cuLaunchKernel(
227 f,
228 grid.0,
229 grid.1,
230 grid.2,
231 block.0,
232 block.1,
233 block.2,
234 shared_bytes,
235 stream,
236 kernel_params,
237 std::ptr::null_mut(),
238 )
239 };
240 driver_check(s, "cuLaunchKernel")
241 })
242}
243
244#[allow(clippy::too_many_arguments)]
245pub fn launch_cooperative_kernel(
246 f: driver_sys::CUfunction,
247 grid: (u32, u32, u32),
248 block: (u32, u32, u32),
249 shared_bytes: u32,
250 stream: driver_sys::CUstream,
251 kernel_params: *mut *mut std::ffi::c_void,
252) -> Result<(), GpuError> {
253 guarded("cuLaunchCooperativeKernel", || {
254 let s = unsafe {
257 driver_sys::cuLaunchCooperativeKernel(
258 f,
259 grid.0,
260 grid.1,
261 grid.2,
262 block.0,
263 block.1,
264 block.2,
265 shared_bytes,
266 stream,
267 kernel_params,
268 )
269 };
270 driver_check(s, "cuLaunchCooperativeKernel")
271 })
272}
273
#[cfg(feature = "cuda-ipc")]
/// Calls the CUDA runtime API `cudaIpcGetMemHandle` for `dev_ptr` and returns the handle it fills in.
///
/// The handle starts zero-initialised and is returned only when the runtime
/// reports success.
///
/// # Errors
/// - A non-success `cudaError_t` becomes `GpuError::LibraryError` tagged with
///   the runtime library (via `runtime_check`).
/// - A panic during the call becomes `GpuError::Unrecoverable` (via `guarded`).
pub fn runtime_ipc_get_mem_handle(
    dev_ptr: *mut std::ffi::c_void,
) -> Result<runtime_sys::cudaIpcMemHandle_t, GpuError> {
    const OP: &str = "cudaIpcGetMemHandle";
    guarded(OP, move || {
        let mut out = runtime_sys::cudaIpcMemHandle_st { reserved: [0; 64] };
        // SAFETY: `out` is a live, writable local for the whole call. The
        // validity of `dev_ptr` is the caller's responsibility.
        let status = unsafe { runtime_sys::cudaIpcGetMemHandle(&mut out, dev_ptr) };
        runtime_check(status, OP).map(|()| out)
    })
}