atomr_accel_cuda/sys/
cufft.rs1#![allow(non_camel_case_types, non_snake_case, dead_code)]
17
18use core::ffi::{c_int, c_void};
19use std::sync::OnceLock;
20
21use cudarc::cufft::sys::{cufftHandle, cufftResult, cufftResult_t};
22
/// Callback slot selector for `cufftXtSetCallback` / `cufftXtClearCallback`.
///
/// The discriminants mirror the C `cufftXtCallbackType` enumerators and are
/// handed to cuFFT as a plain integer (`cb_type as c_int`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(i32)]
pub enum CufftXtCallbackType {
    /// Load callback, single-precision complex (`CUFFT_CB_LD_COMPLEX`).
    LoadComplex = 0x0,
    /// Load callback, double-precision complex (`CUFFT_CB_LD_COMPLEX_DOUBLE`).
    LoadComplexDouble = 0x1,
    /// Load callback, single-precision real (`CUFFT_CB_LD_REAL`).
    LoadReal = 0x2,
    /// Load callback, double-precision real (`CUFFT_CB_LD_REAL_DOUBLE`).
    LoadRealDouble = 0x3,
    /// Store callback, single-precision complex (`CUFFT_CB_ST_COMPLEX`).
    StoreComplex = 0x4,
    /// Store callback, double-precision complex (`CUFFT_CB_ST_COMPLEX_DOUBLE`).
    StoreComplexDouble = 0x5,
    /// Store callback, single-precision real (`CUFFT_CB_ST_REAL`).
    StoreReal = 0x6,
    /// Store callback, double-precision real (`CUFFT_CB_ST_REAL_DOUBLE`).
    StoreRealDouble = 0x7,
}
44
45type CufftXtSetCallbackFn = unsafe extern "C" fn(
46 plan: cufftHandle,
47 callback_routine: *mut *mut c_void,
48 cb_type: i32,
49 caller_info: *mut *mut c_void,
50) -> cufftResult;
51
52type CufftXtClearCallbackFn = unsafe extern "C" fn(plan: cufftHandle, cb_type: i32) -> cufftResult;
53
/// cuFFT Xt callback entry points resolved from a dynamically loaded libcufft.
///
/// Each symbol is `None` when the loaded library does not export it.
struct XtSyms {
    // Both symbol lifetimes are widened to `'static` by `load_xt_syms`. They
    // stay valid only because `_lib` keeps the library loaded. Fields drop in
    // declaration order, so `_lib` (declared last) outlives the symbols.
    set_cb: Option<libloading::Symbol<'static, CufftXtSetCallbackFn>>,
    clear_cb: Option<libloading::Symbol<'static, CufftXtClearCallbackFn>>,
    // Never read directly; held solely to own the library handle.
    _lib: libloading::Library,
}

// SAFETY: the struct holds only resolved function pointers plus the library
// handle that owns them. Nothing in the code shown mutates it after
// construction; it is stored in `XT_SYMS` and only read afterwards.
// NOTE(review): libloading's `Symbol`/`Library` may already be Send + Sync,
// making these impls redundant — confirm before removing.
unsafe impl Send for XtSyms {}
// SAFETY: as for `Send` above; shared access only reads the symbols.
unsafe impl Sync for XtSyms {}

/// Cached outcome of [`load_xt_syms`], filled by `xt_syms` via `get_or_init`.
/// The `Err` arm keeps the load diagnostic.
static XT_SYMS: OnceLock<Result<XtSyms, String>> = OnceLock::new();
67
/// Shared-library names tried in order by `load_xt_syms` (Linux).
///
/// NOTE(review): plan handles created through cudarc are only meaningful to the
/// same libcufft instance cudarc loaded. Keep this order compatible with
/// cudarc's own search so both resolve to the same file — confirm.
#[cfg(target_os = "linux")]
const CUFFT_LIB_CANDIDATES: &[&str] = &[
    "libcufft.so",
    "libcufft.so.11",
    "libcufft.so.10",
];

/// Shared-library names tried in order by `load_xt_syms` (macOS).
#[cfg(target_os = "macos")]
const CUFFT_LIB_CANDIDATES: &[&str] = &[
    "libcufft.dylib",
];

/// DLL names tried in order by `load_xt_syms` (Windows).
#[cfg(target_os = "windows")]
const CUFFT_LIB_CANDIDATES: &[&str] = &[
    "cufft64_11.dll",
    "cufft64_10.dll",
    "cufft64_9.dll",
];

/// No known cuFFT library names on other targets, so loading always fails there.
#[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
const CUFFT_LIB_CANDIDATES: &[&str] = &[];
79
80fn load_xt_syms() -> Result<XtSyms, String> {
81 let mut last_err: Option<String> = None;
82 for cand in CUFFT_LIB_CANDIDATES {
83 match unsafe { libloading::Library::new(*cand) } {
84 Ok(lib) => {
85 let set_cb = unsafe {
91 lib.get::<CufftXtSetCallbackFn>(b"cufftXtSetCallback\0")
92 .ok()
93 .map(|s| {
94 std::mem::transmute::<
95 libloading::Symbol<'_, CufftXtSetCallbackFn>,
96 libloading::Symbol<'static, CufftXtSetCallbackFn>,
97 >(s)
98 })
99 };
100 let clear_cb = unsafe {
101 lib.get::<CufftXtClearCallbackFn>(b"cufftXtClearCallback\0")
102 .ok()
103 .map(|s| {
104 std::mem::transmute::<
105 libloading::Symbol<'_, CufftXtClearCallbackFn>,
106 libloading::Symbol<'static, CufftXtClearCallbackFn>,
107 >(s)
108 })
109 };
110 return Ok(XtSyms {
111 set_cb,
112 clear_cb,
113 _lib: lib,
114 });
115 }
116 Err(e) => {
117 last_err = Some(format!("{cand}: {e}"));
118 }
119 }
120 }
121 Err(last_err.unwrap_or_else(|| "no libcufft candidates configured".into()))
122}
123
124fn xt_syms() -> Result<&'static XtSyms, &'static str> {
125 let cell = XT_SYMS.get_or_init(load_xt_syms);
126 match cell {
127 Ok(s) => Ok(s),
128 Err(_) => Err("cuFFT shared library not loadable on this host"),
129 }
130}
131
132fn fail_not_supported() -> cufftResult {
135 cufftResult_t::CUFFT_NOT_SUPPORTED
136}
137
138pub unsafe fn xt_set_callback(
154 plan: cufftHandle,
155 cb: *mut c_void,
156 cb_type: CufftXtCallbackType,
157 caller_info: *mut c_void,
158) -> cufftResult {
159 let syms = match xt_syms() {
160 Ok(s) => s,
161 Err(_) => return fail_not_supported(),
162 };
163 let f = match &syms.set_cb {
164 Some(f) => f,
165 None => return fail_not_supported(),
166 };
167 let mut routine = cb;
168 let mut info = caller_info;
169 f(plan, &mut routine, cb_type as c_int, &mut info)
170}
171
172pub unsafe fn xt_clear_callback(plan: cufftHandle, cb_type: CufftXtCallbackType) -> cufftResult {
177 let syms = match xt_syms() {
178 Ok(s) => s,
179 Err(_) => return fail_not_supported(),
180 };
181 let f = match &syms.clear_cb {
182 Some(f) => f,
183 None => return fail_not_supported(),
184 };
185 f(plan, cb_type as c_int)
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant, in discriminant order.
    const ALL_KINDS: [CufftXtCallbackType; 8] = [
        CufftXtCallbackType::LoadComplex,
        CufftXtCallbackType::LoadComplexDouble,
        CufftXtCallbackType::LoadReal,
        CufftXtCallbackType::LoadRealDouble,
        CufftXtCallbackType::StoreComplex,
        CufftXtCallbackType::StoreComplexDouble,
        CufftXtCallbackType::StoreReal,
        CufftXtCallbackType::StoreRealDouble,
    ];

    /// Checks every pair of variants, not just a hand-picked few.
    #[test]
    fn callback_kinds_are_distinct() {
        for (i, a) in ALL_KINDS.iter().enumerate() {
            for b in &ALL_KINDS[i + 1..] {
                assert_ne!(*a as i32, *b as i32, "{a:?} and {b:?} share a discriminant");
            }
        }
    }

    /// Pins the discriminants to the `cufftXtCallbackType` values in cufftXt.h
    /// (`CUFFT_CB_LD_COMPLEX = 0x0` through `CUFFT_CB_ST_REAL_DOUBLE = 0x7`).
    #[test]
    fn callback_kinds_match_cufftxt_header() {
        for (expected, kind) in ALL_KINDS.iter().enumerate() {
            assert_eq!(*kind as i32, expected as i32, "{kind:?}");
        }
    }

    /// Must not panic whether or not libcufft is present. The status itself
    /// depends on the host, so it is not asserted.
    #[test]
    fn xt_set_callback_is_safe_to_call_without_gpu() {
        // SAFETY: null pointers and handle 0 are either rejected by cuFFT or
        // never reach it when the library is absent.
        let result = unsafe {
            xt_set_callback(
                0,
                std::ptr::null_mut(),
                CufftXtCallbackType::LoadComplex,
                std::ptr::null_mut(),
            )
        };
        let _ = result;
    }
}