Skip to main content

atomr_accel_cuda/sys/
cufft.rs

1//! Local FFI bindings for cuFFT entry points cudarc 0.19.4 doesn't
2//! expose.
3//!
4//! Specifically: `cufftXtSetCallback` / `cufftXtClearCallback`. These
5//! aren't part of cudarc's `cufft::sys` surface, so we resolve them
6//! ourselves via `libloading` against `libcufft.so` (Linux),
7//! `cufft64_*.dll` (Windows), or `libcufft.dylib` (macOS).
8//!
9//! Why not `extern "C"`? cudarc's default `fallback-dynamic-loading`
10//! build does **not** emit a `-lcufft` link directive — symbols are
11//! resolved at runtime through `dlopen`. Adding a hard `extern "C"`
12//! reference here would break the link on hosts without libcufft on
13//! the link path (which is the whole point of `fallback-dynamic-loading`).
14//! Mirroring cudarc's strategy keeps us link-clean.
15
16#![allow(non_camel_case_types, non_snake_case, dead_code)]
17
18use core::ffi::{c_int, c_void};
19use std::sync::OnceLock;
20
21use cudarc::cufft::sys::{cufftHandle, cufftResult, cufftResult_t};
22
/// Rust mirror of the C enum `cufftXtCallbackType` declared in `cufftXt.h`.
///
/// A variant picks the data path the callback hooks into (`Load*` = reading
/// the transform input, `Store*` = writing the transform output) and the
/// element type the callback sees. The discriminant is what gets handed to
/// cuFFT, so each value below must stay identical to the header.
#[repr(i32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CufftXtCallbackType {
    /// `CUFFT_CB_LD_COMPLEX` — load, single-precision complex (`cufftComplex`).
    LoadComplex = 0,
    /// `CUFFT_CB_LD_COMPLEX_DOUBLE` — load, double-precision complex (`cufftDoubleComplex`).
    LoadComplexDouble = 1,
    /// `CUFFT_CB_LD_REAL` — load, single-precision real (`cufftReal`).
    LoadReal = 2,
    /// `CUFFT_CB_LD_REAL_DOUBLE` — load, double-precision real (`cufftDoubleReal`).
    LoadRealDouble = 3,
    /// `CUFFT_CB_ST_COMPLEX` — store, single-precision complex (`cufftComplex`).
    StoreComplex = 4,
    /// `CUFFT_CB_ST_COMPLEX_DOUBLE` — store, double-precision complex (`cufftDoubleComplex`).
    StoreComplexDouble = 5,
    /// `CUFFT_CB_ST_REAL` — store, single-precision real (`cufftReal`).
    StoreReal = 6,
    /// `CUFFT_CB_ST_REAL_DOUBLE` — store, double-precision real (`cufftDoubleReal`).
    StoreRealDouble = 7,
}
44
45type CufftXtSetCallbackFn = unsafe extern "C" fn(
46    plan: cufftHandle,
47    callback_routine: *mut *mut c_void,
48    cb_type: i32,
49    caller_info: *mut *mut c_void,
50) -> cufftResult;
51
52type CufftXtClearCallbackFn = unsafe extern "C" fn(plan: cufftHandle, cb_type: i32) -> cufftResult;
53
54struct XtSyms {
55    set_cb: Option<libloading::Symbol<'static, CufftXtSetCallbackFn>>,
56    clear_cb: Option<libloading::Symbol<'static, CufftXtClearCallbackFn>>,
57    _lib: libloading::Library,
58}
59
60// SAFETY: `libloading::Symbol` borrows from a `Library`, but we leak
61// the library (via `OnceLock`) so the borrow is effectively `'static`.
62// The function pointers themselves are thread-safe to invoke.
63unsafe impl Send for XtSyms {}
64unsafe impl Sync for XtSyms {}
65
66static XT_SYMS: OnceLock<Result<XtSyms, String>> = OnceLock::new();
67
68#[cfg(target_os = "linux")]
69const CUFFT_LIB_CANDIDATES: &[&str] = &["libcufft.so", "libcufft.so.11", "libcufft.so.10"];
70
71#[cfg(target_os = "macos")]
72const CUFFT_LIB_CANDIDATES: &[&str] = &["libcufft.dylib"];
73
74#[cfg(target_os = "windows")]
75const CUFFT_LIB_CANDIDATES: &[&str] = &["cufft64_11.dll", "cufft64_10.dll", "cufft64_9.dll"];
76
77#[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
78const CUFFT_LIB_CANDIDATES: &[&str] = &[];
79
80fn load_xt_syms() -> Result<XtSyms, String> {
81    let mut last_err: Option<String> = None;
82    for cand in CUFFT_LIB_CANDIDATES {
83        match unsafe { libloading::Library::new(*cand) } {
84            Ok(lib) => {
85                // `Symbol` borrows from `lib`. We move `lib` into the
86                // `XtSyms` struct and transmute the `'_` borrow to
87                // `'static` via the leaked `OnceLock` — sound only
88                // because `XT_SYMS` is initialized exactly once and
89                // never dropped before program exit.
90                let set_cb = unsafe {
91                    lib.get::<CufftXtSetCallbackFn>(b"cufftXtSetCallback\0")
92                        .ok()
93                        .map(|s| {
94                            std::mem::transmute::<
95                                libloading::Symbol<'_, CufftXtSetCallbackFn>,
96                                libloading::Symbol<'static, CufftXtSetCallbackFn>,
97                            >(s)
98                        })
99                };
100                let clear_cb = unsafe {
101                    lib.get::<CufftXtClearCallbackFn>(b"cufftXtClearCallback\0")
102                        .ok()
103                        .map(|s| {
104                            std::mem::transmute::<
105                                libloading::Symbol<'_, CufftXtClearCallbackFn>,
106                                libloading::Symbol<'static, CufftXtClearCallbackFn>,
107                            >(s)
108                        })
109                };
110                return Ok(XtSyms {
111                    set_cb,
112                    clear_cb,
113                    _lib: lib,
114                });
115            }
116            Err(e) => {
117                last_err = Some(format!("{cand}: {e}"));
118            }
119        }
120    }
121    Err(last_err.unwrap_or_else(|| "no libcufft candidates configured".into()))
122}
123
124fn xt_syms() -> Result<&'static XtSyms, &'static str> {
125    let cell = XT_SYMS.get_or_init(load_xt_syms);
126    match cell {
127        Ok(s) => Ok(s),
128        Err(_) => Err("cuFFT shared library not loadable on this host"),
129    }
130}
131
132/// Result that doubles as a transport for "library missing" cases on
133/// hosts where libcufft can't be dlopened (the no-GPU CI runner).
134fn fail_not_supported() -> cufftResult {
135    cufftResult_t::CUFFT_NOT_SUPPORTED
136}
137
138/// Install a load/store callback on a cuFFT plan.
139///
140/// `cb` must point to a CUDA *device* function with the signature
141/// matching `cb_type` (see the cuFFT docs). `caller_info` is passed
142/// straight through to the device callback at call time.
143///
144/// Returns `CUFFT_NOT_SUPPORTED` if `libcufft` couldn't be opened or
145/// doesn't export `cufftXtSetCallback` (older toolkits before the Xt
146/// API was carved out). All other return codes are passed through.
147///
148/// # Safety
149/// - `plan` must be a live cuFFT handle.
150/// - `cb` must point to a device-resident function with the
151///   appropriate signature.
152/// - `caller_info` must outlive every kernel launch on `plan`.
153pub unsafe fn xt_set_callback(
154    plan: cufftHandle,
155    cb: *mut c_void,
156    cb_type: CufftXtCallbackType,
157    caller_info: *mut c_void,
158) -> cufftResult {
159    let syms = match xt_syms() {
160        Ok(s) => s,
161        Err(_) => return fail_not_supported(),
162    };
163    let f = match &syms.set_cb {
164        Some(f) => f,
165        None => return fail_not_supported(),
166    };
167    let mut routine = cb;
168    let mut info = caller_info;
169    f(plan, &mut routine, cb_type as c_int, &mut info)
170}
171
172/// Clear a previously-installed callback.
173///
174/// # Safety
175/// `plan` must be a live cuFFT handle.
176pub unsafe fn xt_clear_callback(plan: cufftHandle, cb_type: CufftXtCallbackType) -> cufftResult {
177    let syms = match xt_syms() {
178        Ok(s) => s,
179        Err(_) => return fail_not_supported(),
180    };
181    let f = match &syms.clear_cb {
182        Some(f) => f,
183        None => return fail_not_supported(),
184    };
185    f(plan, cb_type as c_int)
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    /// Every callback kind paired with its `cufftXtCallbackType` value in `cufftXt.h`.
    const HEADER_VALUES: [(CufftXtCallbackType, i32); 8] = [
        (CufftXtCallbackType::LoadComplex, 0x0),
        (CufftXtCallbackType::LoadComplexDouble, 0x1),
        (CufftXtCallbackType::LoadReal, 0x2),
        (CufftXtCallbackType::LoadRealDouble, 0x3),
        (CufftXtCallbackType::StoreComplex, 0x4),
        (CufftXtCallbackType::StoreComplexDouble, 0x5),
        (CufftXtCallbackType::StoreReal, 0x6),
        (CufftXtCallbackType::StoreRealDouble, 0x7),
    ];

    #[test]
    fn callback_kinds_are_distinct() {
        // All eight kinds — load/store, real/complex, f32/f64 — must map to
        // different raw values.
        let unique: std::collections::HashSet<i32> =
            HEADER_VALUES.iter().map(|&(kind, _)| kind as i32).collect();
        assert_eq!(unique.len(), HEADER_VALUES.len());
    }

    #[test]
    fn callback_kinds_match_cufftxt_header() {
        // The discriminant is passed to cuFFT verbatim, so any drift from the
        // C header would silently select a different callback kind.
        for (kind, raw) in HEADER_VALUES {
            assert_eq!(kind as i32, raw, "{kind:?}");
        }
    }

    #[test]
    fn xt_set_callback_is_safe_to_call_without_gpu() {
        let result = unsafe {
            xt_set_callback(
                0,
                std::ptr::null_mut(),
                CufftXtCallbackType::LoadComplex,
                std::ptr::null_mut(),
            )
        };
        // Without a usable `cufftXtSetCallback` the wrapper must report
        // NOT_SUPPORTED rather than panicking. With a real libcufft the status
        // comes from cuFFT and depends on the host, so it isn't pinned.
        let entry_point_available = matches!(xt_syms(), Ok(syms) if syms.set_cb.is_some());
        if !entry_point_available {
            assert!(matches!(result, cufftResult_t::CUFFT_NOT_SUPPORTED));
        }
    }

    #[test]
    fn xt_clear_callback_reports_not_supported_without_cufft() {
        // Only exercised when the entry point is unavailable, so the test
        // never calls into a real libcufft with an arbitrary plan handle.
        if matches!(xt_syms(), Ok(syms) if syms.clear_cb.is_some()) {
            return;
        }
        let result = unsafe { xt_clear_callback(0, CufftXtCallbackType::LoadComplex) };
        assert!(matches!(result, cufftResult_t::CUFFT_NOT_SUPPORTED));
    }
}