atomr_accel_cuda/memory/
ipc.rs1#![cfg(feature = "cuda-ipc")]
19
20use cudarc::driver::sys as driver_sys;
21
22use crate::error::GpuError;
23use crate::sys::cuda_driver;
24
25#[derive(Clone, Copy)]
28pub struct IpcMemHandle {
29 pub(crate) raw: driver_sys::CUipcMemHandle,
30}
31
32impl IpcMemHandle {
33 pub fn as_bytes(&self) -> [u8; 64] {
34 unsafe { std::mem::transmute::<[std::ffi::c_char; 64], [u8; 64]>(self.raw.reserved) }
36 }
37
38 pub fn from_bytes(bytes: [u8; 64]) -> Self {
39 let raw = driver_sys::CUipcMemHandle_st {
40 reserved: unsafe { std::mem::transmute::<[u8; 64], [std::ffi::c_char; 64]>(bytes) },
42 };
43 Self { raw }
44 }
45}
46
47impl std::fmt::Debug for IpcMemHandle {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 f.debug_struct("IpcMemHandle").finish()
50 }
51}
52
53unsafe impl Send for IpcMemHandle {}
54unsafe impl Sync for IpcMemHandle {}
55
56#[derive(Debug)]
61pub struct OpenedMem {
62 dev_ptr: driver_sys::CUdeviceptr,
63 bytes: usize,
64}
65
66impl OpenedMem {
67 pub fn dev_ptr(&self) -> driver_sys::CUdeviceptr {
68 self.dev_ptr
69 }
70
71 pub fn bytes(&self) -> usize {
72 self.bytes
73 }
74}
75
76impl Drop for OpenedMem {
77 fn drop(&mut self) {
78 if self.dev_ptr != 0 {
79 let _ = cuda_driver::ipc_close_mem_handle(self.dev_ptr);
80 }
81 }
82}
83
84unsafe impl Send for OpenedMem {}
85unsafe impl Sync for OpenedMem {}
86
87pub fn get_mem_handle(dev_ptr: driver_sys::CUdeviceptr) -> Result<IpcMemHandle, GpuError> {
89 cuda_driver::ipc_get_mem_handle(dev_ptr).map(|raw| IpcMemHandle { raw })
90}
91
92pub fn open_mem_handle(handle: IpcMemHandle, bytes: usize) -> Result<OpenedMem, GpuError> {
97 let dev_ptr = cuda_driver::ipc_open_mem_handle_v2(
98 handle.raw, 1,
100 )?;
101 Ok(OpenedMem { dev_ptr, bytes })
102}
103
#[cfg(test)]
mod tests {
    use super::*;

    /// Arbitrary handle bytes survive a `from_bytes` / `as_bytes` round trip.
    #[test]
    fn handle_round_trip() {
        let bytes: [u8; 64] = std::array::from_fn(|i| (i * 3) as u8 ^ 0x55);
        let h = IpcMemHandle::from_bytes(bytes);
        let round = h.as_bytes();
        assert_eq!(round, bytes);
    }

    /// Bytes with the high bit set must come back unchanged after being
    /// stored in the (possibly signed) `c_char` array.
    #[test]
    fn handle_round_trip_high_bit_bytes() {
        for fill in [0x80u8, 0xFF] {
            let bytes = [fill; 64];
            assert_eq!(IpcMemHandle::from_bytes(bytes).as_bytes(), bytes);
        }
    }

    /// Compile-time check that both wrapper types are `Send + Sync`.
    #[test]
    fn handle_types_are_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<IpcMemHandle>();
        assert_send_sync::<OpenedMem>();
    }

    /// Opening an all-zero handle yields either success or one of the expected
    /// error variants; any other error variant fails the test.
    #[test]
    fn open_returns_typed_error_on_no_driver() {
        let h = IpcMemHandle::from_bytes([0u8; 64]);
        let r = open_mem_handle(h, 0);
        assert!(
            matches!(
                r,
                Ok(_) | Err(GpuError::Unrecoverable(_)) | Err(GpuError::LibraryError { .. })
            ),
            "unexpected: {r:?}"
        );
    }
}