Skip to main content

atomr_accel_cuda/kernel/
fft.rs

//! `FftActor` — wraps [`cudarc::cufft::CudaFft`] with an LRU cache of
//! plans keyed by shape + transform kind + dtype + batch.
//!
//! cuFFT plan creation can take milliseconds; the cache amortizes
//! that cost across many transforms of the same shape.
//!
//! # Phase 1 surface
//!
//! Phase 1 of the cuFFT slice expands the actor from the F2 sketch
//! (1D R2C/C2R/C2C f32 + 2D R2C f32) to:
//!
//! * full 1D / 2D / **3D** transform ranks,
//! * **f32** (R2C / C2R / C2C) **and f64** (D2Z / Z2D / Z2Z),
//! * a true [`cufftPlanMany`]-style batched plan builder
//!   ([`FftPlanMany`]) — arbitrary `(rank, dims, in_embed, in_stride,
//!   in_dist, out_embed, out_stride, out_dist, batch)`,
//! * an optional callback hook ([`FftCallbackKind`]) plumbed through
//!   `cufftXtSetCallback` (defined in `crate::sys::cufft`). PTX/cubin
//!   provisioning of the device-side callback is left to the
//!   caller; this layer only forwards the function pointer to cuFFT.
//!
//! [`cufftPlanMany`]: https://docs.nvidia.com/cuda/cufft/index.html#function-cufftplanmany
//!
//! # Message API: Option C hybrid
//!
//! Per Phase 0.2 the canonical typed API runs through
//! [`FftMsg::Exec`], whose payload is a `Box<dyn FftDispatch>`. Each typed
//! [`FftRequest<T>`] carries the dtype-aware `GpuRef` payload and is
//! type-erased at the enum boundary, so `FftActor::Msg` stays a single
//! non-generic enum (one mailbox per actor) while the public surface
//! remains dtype-typed.
//!
//! The legacy F2 variants (`Forward1dR2C`, `Inverse1dC2R`,
//! `Exec1dC2C`, `Forward2dR2C`) are kept as `#[deprecated]` variants
//! so existing examples and external callers still compile.
//!
//! **Inverse normalization:** cuFFT does NOT normalize inverse
//! transforms by 1/N — that is the caller's responsibility (typically
//! folded into a downstream kernel).
40
41use std::any::Any;
42use std::ffi::c_void;
43use std::num::NonZeroUsize;
44use std::sync::Arc;
45
46use async_trait::async_trait;
47use atomr_core::actor::{Actor, Context, Props};
48use cudarc::cufft::sys as cufft_sys;
49use cudarc::cufft::{CudaFft, FftDirection as CudarcFftDirection};
50use lru::LruCache;
51use parking_lot::Mutex;
52use tokio::sync::oneshot;
53
54use crate::completion::CompletionStrategy;
55use crate::device::DeviceState;
56use crate::dtype::{DType, FftSupported};
57use crate::error::GpuError;
58use crate::gpu_ref::GpuRef;
59use crate::kernel::dispatch::{FftDispatch, FftDispatchCtx};
60use crate::kernel::envelope;
61use crate::stream::StreamAllocator;
62use crate::sys::cufft as sys_cufft;
63
/// Library tag used in every `GpuError::LibraryError` this module builds
/// and passed to `envelope::run_kernel`.
const LIB: &str = "cufft";
/// Capacity of each actor's plan LRU (number of cached cuFFT plans).
const DEFAULT_CACHE_SIZE: usize = 64;
66
67// ---------------------------------------------------------------------
68// Public types
69// ---------------------------------------------------------------------
70
/// Direction of a complex transform.
///
/// A crate-local mirror of [`cudarc::cufft::FftDirection`], so callers can
/// name it as `atomr_accel_cuda::FftDirection`. Only the complex-to-complex
/// kinds consult it; real↔complex kinds have an implied direction.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum FftDirection {
    /// Forward transform.
    Forward,
    /// Inverse transform (cuFFT does not apply the 1/N normalization).
    Inverse,
}
79
80impl FftDirection {
81    pub(crate) fn cudarc(self) -> CudarcFftDirection {
82        match self {
83            FftDirection::Forward => CudarcFftDirection::Forward,
84            FftDirection::Inverse => CudarcFftDirection::Inverse,
85        }
86    }
87}
88
/// Transform kind: one variant per cuFFT `cufftType` code this module
/// dispatches on.
///
/// `R2C` / `C2R` / `C2C` are the single-precision (f32) lanes and
/// `D2Z` / `Z2D` / `Z2Z` are the double-precision (f64) lanes; see
/// [`FftKind::scalar_dtype`].
///
/// For source compatibility with the F2 naming, the associated constants
/// [`FftKind::R2cF32`], [`FftKind::C2rF32`] and [`FftKind::C2cF32`] alias
/// the f32 variants. Only f32 legacy aliases exist; the f64 lanes are
/// named solely by their `D2Z` / `Z2D` / `Z2Z` variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FftKind {
    /// Real → complex, f32 (`CUFFT_R2C`).
    R2C,
    /// Complex → real, f32 (`CUFFT_C2R`).
    C2R,
    /// Complex → complex, f32 (`CUFFT_C2C`).
    C2C,
    /// Real → complex, f64 (`CUFFT_D2Z`).
    D2Z,
    /// Complex → real, f64 (`CUFFT_Z2D`).
    Z2D,
    /// Complex → complex, f64 (`CUFFT_Z2Z`).
    Z2Z,
}
105
106impl FftKind {
107    /// Convenience constants matching the F2 naming.
108    #[allow(non_upper_case_globals)]
109    pub const R2cF32: FftKind = FftKind::R2C;
110    #[allow(non_upper_case_globals)]
111    pub const C2rF32: FftKind = FftKind::C2R;
112    #[allow(non_upper_case_globals)]
113    pub const C2cF32: FftKind = FftKind::C2C;
114
115    pub fn cufft_type(self) -> cufft_sys::cufftType {
116        match self {
117            FftKind::R2C => cufft_sys::cufftType::CUFFT_R2C,
118            FftKind::C2R => cufft_sys::cufftType::CUFFT_C2R,
119            FftKind::C2C => cufft_sys::cufftType::CUFFT_C2C,
120            FftKind::D2Z => cufft_sys::cufftType::CUFFT_D2Z,
121            FftKind::Z2D => cufft_sys::cufftType::CUFFT_Z2D,
122            FftKind::Z2Z => cufft_sys::cufftType::CUFFT_Z2Z,
123        }
124    }
125
126    /// Dtype of the **scalar lane** for this transform kind. R2C/C2R/C2C
127    /// are f32; D2Z/Z2D/Z2Z are f64.
128    pub fn scalar_dtype(self) -> DType {
129        match self {
130            FftKind::R2C | FftKind::C2R | FftKind::C2C => DType::F32,
131            FftKind::D2Z | FftKind::Z2D | FftKind::Z2Z => DType::F64,
132        }
133    }
134}
135
136/// Plan-cache key. Captures everything cuFFT cares about for both
137/// the simple `cufftPlan{1,2,3}d` constructors and the advanced
138/// `cufftPlanMany` builder. `dims[i] == 0` for `i >= rank` (unused
139/// dimensions are zeroed for hashability).
140#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
141pub struct PlanKey {
142    pub rank: u32,
143    pub dims: [i32; 3],
144    pub kind: FftKind,
145    pub dtype: DType,
146    pub batch: i32,
147    /// `Some(seed)` ⇒ this is a `plan_many` plan; the seed is a hash
148    /// of the embed/stride/dist tuples so two `plan_many`s with
149    /// different layouts hash distinctly. `None` ⇒ simple
150    /// `cufftPlan{1,2,3}d`.
151    pub many_layout: Option<u64>,
152}
153
154impl PlanKey {
155    /// Convenience constructor for a simple 1D plan.
156    pub fn plan_1d(n: i32, kind: FftKind, batch: i32) -> Self {
157        Self {
158            rank: 1,
159            dims: [n, 0, 0],
160            kind,
161            dtype: kind.scalar_dtype(),
162            batch,
163            many_layout: None,
164        }
165    }
166
167    /// Convenience constructor for a simple 2D plan.
168    pub fn plan_2d(nx: i32, ny: i32, kind: FftKind) -> Self {
169        Self {
170            rank: 2,
171            dims: [nx, ny, 0],
172            kind,
173            dtype: kind.scalar_dtype(),
174            batch: 1,
175            many_layout: None,
176        }
177    }
178
179    /// Convenience constructor for a simple 3D plan.
180    pub fn plan_3d(nx: i32, ny: i32, nz: i32, kind: FftKind) -> Self {
181        Self {
182            rank: 3,
183            dims: [nx, ny, nz],
184            kind,
185            dtype: kind.scalar_dtype(),
186            batch: 1,
187            many_layout: None,
188        }
189    }
190}
191
192/// Builder for advanced batched plans. Mirrors `cufftPlanMany`'s
193/// argument list: arbitrary `rank` (1, 2, or 3), per-dim sizes,
194/// optional in/out embed dims with strides and per-batch distances.
195///
196/// Use [`FftPlanMany::build`] (resolves through the LRU cache) or
197/// [`FftActor::ensure_plan`] to materialize an [`FftPlan`].
198#[derive(Debug, Clone)]
199pub struct FftPlanMany {
200    pub rank: u32,
201    pub dims: [i32; 3],
202    pub in_embed: Option<[i32; 3]>,
203    pub in_stride: i32,
204    pub in_dist: i32,
205    pub out_embed: Option<[i32; 3]>,
206    pub out_stride: i32,
207    pub out_dist: i32,
208    pub kind: FftKind,
209    pub batch: i32,
210}
211
212impl FftPlanMany {
213    /// Hash the embed/stride/dist tuples into a 64-bit seed used as
214    /// the [`PlanKey::many_layout`] discriminator.
215    pub fn layout_seed(&self) -> u64 {
216        use std::collections::hash_map::DefaultHasher;
217        use std::hash::{Hash, Hasher};
218        let mut h = DefaultHasher::new();
219        self.in_embed.hash(&mut h);
220        self.in_stride.hash(&mut h);
221        self.in_dist.hash(&mut h);
222        self.out_embed.hash(&mut h);
223        self.out_stride.hash(&mut h);
224        self.out_dist.hash(&mut h);
225        h.finish()
226    }
227
228    /// Plan-cache key derived from this layout description.
229    pub fn key(&self) -> PlanKey {
230        PlanKey {
231            rank: self.rank,
232            dims: self.dims,
233            kind: self.kind,
234            dtype: self.kind.scalar_dtype(),
235            batch: self.batch,
236            many_layout: Some(self.layout_seed()),
237        }
238    }
239}
240
/// Which cuFFT callback slot a device function is installed into, and the
/// element type that callback handles.
///
/// Load callbacks run while cuFFT reads its input; store callbacks run
/// while it writes its output. Each variant maps one-to-one onto a
/// `CufftXtCallbackType` in `crate::sys::cufft`.
#[derive(Clone, Copy, Debug)]
pub enum FftCallbackKind {
    /// Load callback, single-precision complex elements.
    LoadComplex,
    /// Load callback, double-precision complex elements.
    LoadComplexDouble,
    /// Load callback, single-precision real elements.
    LoadReal,
    /// Load callback, double-precision real elements.
    LoadRealDouble,
    /// Store callback, single-precision complex elements.
    StoreComplex,
    /// Store callback, double-precision complex elements.
    StoreComplexDouble,
    /// Store callback, single-precision real elements.
    StoreReal,
    /// Store callback, double-precision real elements.
    StoreRealDouble,
}
256
257impl FftCallbackKind {
258    fn sys(self) -> sys_cufft::CufftXtCallbackType {
259        use sys_cufft::CufftXtCallbackType as T;
260        match self {
261            FftCallbackKind::LoadComplex => T::LoadComplex,
262            FftCallbackKind::LoadComplexDouble => T::LoadComplexDouble,
263            FftCallbackKind::LoadReal => T::LoadReal,
264            FftCallbackKind::LoadRealDouble => T::LoadRealDouble,
265            FftCallbackKind::StoreComplex => T::StoreComplex,
266            FftCallbackKind::StoreComplexDouble => T::StoreComplexDouble,
267            FftCallbackKind::StoreReal => T::StoreReal,
268            FftCallbackKind::StoreRealDouble => T::StoreRealDouble,
269        }
270    }
271}
272
273/// Opaque handle to a cuFFT plan (already materialized through the
274/// LRU cache). Callers obtain one via [`FftActor::ensure_plan`] or by
275/// caching the [`PlanKey`] returned from [`FftPlanMany::key`].
276#[derive(Clone)]
277pub struct FftPlan {
278    pub key: PlanKey,
279    inner: Arc<CudaFft>,
280}
281
282impl std::fmt::Debug for FftPlan {
283    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
284        f.debug_struct("FftPlan").field("key", &self.key).finish()
285    }
286}
287
288impl FftPlan {
289    pub fn key(&self) -> PlanKey {
290        self.key
291    }
292
293    /// Install a load/store callback on this plan via
294    /// `cufftXtSetCallback`. Returns `Err` if the runtime
295    /// `libcufft` doesn't expose the Xt API or the call fails.
296    ///
297    /// # Safety
298    /// `cb` must be a valid CUDA *device* function pointer of the
299    /// signature matching `kind`. `caller_info` (if non-null) must
300    /// outlive every launch on this plan.
301    pub unsafe fn with_callback(
302        &self,
303        kind: FftCallbackKind,
304        cb: *mut c_void,
305        caller_info: *mut c_void,
306    ) -> Result<(), GpuError> {
307        let res = sys_cufft::xt_set_callback(self.inner.handle(), cb, kind.sys(), caller_info);
308        match res.result() {
309            Ok(()) => Ok(()),
310            Err(e) => Err(GpuError::LibraryError {
311                lib: LIB,
312                msg: format!("cufftXtSetCallback({kind:?}): {e:?}"),
313            }),
314        }
315    }
316
317    /// Convenience wrapper: install a load callback.
318    ///
319    /// # Safety
320    /// See [`FftPlan::with_callback`].
321    pub unsafe fn with_load_callback(
322        &self,
323        kind: FftCallbackKind,
324        cb: *mut c_void,
325        caller_info: *mut c_void,
326    ) -> Result<(), GpuError> {
327        debug_assert!(matches!(
328            kind,
329            FftCallbackKind::LoadComplex
330                | FftCallbackKind::LoadComplexDouble
331                | FftCallbackKind::LoadReal
332                | FftCallbackKind::LoadRealDouble
333        ));
334        self.with_callback(kind, cb, caller_info)
335    }
336
337    /// Convenience wrapper: install a store callback.
338    ///
339    /// # Safety
340    /// See [`FftPlan::with_callback`].
341    pub unsafe fn with_store_callback(
342        &self,
343        kind: FftCallbackKind,
344        cb: *mut c_void,
345        caller_info: *mut c_void,
346    ) -> Result<(), GpuError> {
347        debug_assert!(matches!(
348            kind,
349            FftCallbackKind::StoreComplex
350                | FftCallbackKind::StoreComplexDouble
351                | FftCallbackKind::StoreReal
352                | FftCallbackKind::StoreRealDouble
353        ));
354        self.with_callback(kind, cb, caller_info)
355    }
356}
357
358// ---------------------------------------------------------------------
359// Typed request → boxed dispatch
360// ---------------------------------------------------------------------
361
362/// Typed cuFFT request — the canonical Phase-1 entry point.
363///
364/// Type parameters
365/// ---------------
366/// - `T` is the *scalar* dtype of the transform (`f32` for the float
367///   lane, `f64` for the double lane). Used for `T::KIND` reporting
368///   on the [`FftDispatch`] trait.
369/// - `I` is the input buffer's element type (default `u8` for the
370///   raw-byte Path B). Phase 1.5++ adds typed Path A by using e.g.
371///   `I = f32` (R2C input), `I = C32` (C2R / C2C input).
372/// - `O` is the output buffer's element type (default `u8`). Path A
373///   uses `O = C32` (R2C / C2C output) or `O = f32` (C2R output).
374///
375/// Plan resolution is performed by the actor on receipt of the
376/// `FftMsg::Exec` message — the request only carries a [`PlanKey`].
377/// Repeat calls with the same key hit the LRU cache on the actor.
378///
379/// In-place transforms: `output` may alias `input` (cuFFT supports
380/// this when shapes line up — only meaningful when `I = O`).
381pub struct FftRequest<T: FftSupported, I = u8, O = u8> {
382    pub plan_key: PlanKey,
383    pub direction: FftDirection,
384    pub input: GpuRef<I>,
385    pub output: GpuRef<O>,
386    pub reply: oneshot::Sender<Result<(), GpuError>>,
387    _scalar: std::marker::PhantomData<T>,
388}
389
390impl<T: FftSupported, I, O> FftRequest<T, I, O> {
391    /// Construct a request. For Path B, `I = O = u8` and the buffers
392    /// are byte-cast on the caller side. For Path A, `I` and `O` are
393    /// the per-side element types (e.g. `f32`/`C32` for R2C).
394    pub fn new(
395        plan_key: PlanKey,
396        direction: FftDirection,
397        input: GpuRef<I>,
398        output: GpuRef<O>,
399        reply: oneshot::Sender<Result<(), GpuError>>,
400    ) -> Self {
401        Self {
402            plan_key,
403            direction,
404            input,
405            output,
406            reply,
407            _scalar: std::marker::PhantomData,
408        }
409    }
410}
411
/// Type-erased execution of a typed [`FftRequest`].
///
/// The actor's `FftMsg::Exec` handler resolves `plan_key()` through its
/// plan cache and passes the result as `ctx.plan` (an
/// `Arc<dyn Any + Send + Sync>`). If plan creation fails, the handler passes
/// a sentinel `Arc<()>` instead. The downcast below then fails and the
/// caller receives `GpuError::Unrecoverable`; the original plan-build error
/// text is not forwarded.
impl<T, I, O> FftDispatch for FftRequest<T, I, O>
where
    T: FftSupported,
    I: Send + Sync + 'static,
    O: Send + Sync + 'static,
{
    /// Scalar dtype of the request (`T::KIND`).
    fn dtype_kind(&self) -> DType {
        T::KIND
    }

    /// Key the actor uses to look up / build the plan.
    fn plan_key(&self) -> PlanKey {
        self.plan_key
    }

    /// Validate the buffers, mark the destination as written on the actor's
    /// stream, and enqueue the matching `cufftExec*` call via
    /// `envelope::run_kernel`. Every failure path replies on `self.reply`.
    fn dispatch(self: Box<Self>, ctx: &FftDispatchCtx<'_>) {
        // Downcast the type-erased plan back to Arc<CudaFft>. The
        // actor populates ctx.plan from the same PlanKey it pulled
        // off `self.plan_key` via the trait method; a failed downcast
        // means the handler substituted its sentinel plan.
        let plan = match ctx.plan.clone().downcast::<CudaFft>() {
            Ok(p) => p,
            Err(_) => {
                let _ = self.reply.send(Err(GpuError::Unrecoverable(
                    "FftDispatchCtx.plan downcast to CudaFft failed".into(),
                )));
                return;
            }
        };

        // Two stream handles: one borrowed by run_kernel / record_write,
        // one moved into the exec closure.
        let stream = ctx.stream.clone();
        let stream_for_exec = stream.clone();
        let completion = ctx.completion.clone();
        let kind = self.plan_key.kind;
        let direction = self.direction;

        // Resolve both GpuRefs to their backing device slices. Unlike the
        // legacy variants, the Arcs are NOT unwrapped here: exec_kernel
        // only needs raw device pointers, so the shared Arcs are moved into
        // the exec closure and returned from it on success.
        let (src_arc, dst_arc) = match envelope::access_all_2(&self.input, &self.output) {
            Ok(t) => t,
            Err(e) => {
                let _ = self.reply.send(Err(e));
                return;
            }
        };

        // Mark write on the destination so cross-stream consumers can
        // serialize on it. Only `output` is marked.
        // NOTE(review): cuFFT documents that out-of-place C2R / Z2D may
        // overwrite its input as well — confirm whether `input` should
        // also be marked for those kinds.
        self.output.record_write(&stream);
        let reply = self.reply;

        envelope::run_kernel(LIB, &stream, &completion, (), reply, move || {
            // SAFETY: cuFFT exec entry points take typed `*mut`
            // pointers. The closure owns Arcs to the underlying
            // CudaSlice<I> / CudaSlice<O> and hands them (plus the plan)
            // back in its Ok value; the keep-alive is presumed to be held
            // by `run_kernel` until completion — see its contract. The
            // dtype matches because the plan was created with `kind`'s
            // cufftType, and exec_kernel picks the entry point from the
            // same `kind`. The reinterpret to `*mut c_void` happens inside
            // `exec_kernel`; the element type is opaque there.
            let res = unsafe {
                exec_kernel(&plan, &src_arc, &dst_arc, kind, direction, &stream_for_exec)
            };
            res.map(|_| (src_arc, dst_arc, plan))
                .map_err(|e| GpuError::LibraryError {
                    lib: LIB,
                    msg: format!("exec_{:?}: {:?}", kind, e),
                })
        });
    }
}
485
486/// Run the appropriate `cufftExec*` entry point for `kind`. Hand-rolled
487/// rather than going through cudarc's typed `exec_r2c` / `exec_c2c` etc.
488/// so we can dispatch off a runtime [`FftKind`] without a separate
489/// trait dispatch per dtype.
490///
491/// Generic over `I` / `O` so both Path B (`u8`) and Path A (typed
492/// `f32` / `C32` / `f64` / `C64`) buffers feed the same exec entry.
493/// The element type is opaque here — only the device pointer matters.
494///
495/// # Safety
496/// The plan must have been created with `kind`'s `cufftType`. `src`
497/// and `dst` must point into device memory of the appropriate sizes
498/// for that transform kind.
499unsafe fn exec_kernel<I, O>(
500    plan: &Arc<CudaFft>,
501    src: &Arc<cudarc::driver::CudaSlice<I>>,
502    dst: &Arc<cudarc::driver::CudaSlice<O>>,
503    kind: FftKind,
504    direction: FftDirection,
505    stream: &Arc<cudarc::driver::CudaStream>,
506) -> Result<(), cudarc::cufft::result::CufftError> {
507    use cudarc::driver::DevicePtr;
508
509    let (src_ptr, _src_rec) = src.device_ptr(stream);
510    let (dst_ptr, _dst_rec) = dst.device_ptr(stream);
511    let src_ptr = src_ptr as *mut c_void;
512    let dst_ptr = dst_ptr as *mut c_void;
513    let h = plan.handle();
514    use cudarc::cufft::sys as s;
515
516    let r = match kind {
517        FftKind::R2C => s::cufftExecR2C(
518            h,
519            src_ptr as *mut s::cufftReal,
520            dst_ptr as *mut s::cufftComplex,
521        ),
522        FftKind::C2R => s::cufftExecC2R(
523            h,
524            src_ptr as *mut s::cufftComplex,
525            dst_ptr as *mut s::cufftReal,
526        ),
527        FftKind::C2C => s::cufftExecC2C(
528            h,
529            src_ptr as *mut s::cufftComplex,
530            dst_ptr as *mut s::cufftComplex,
531            direction.cudarc() as i32,
532        ),
533        FftKind::D2Z => s::cufftExecD2Z(
534            h,
535            src_ptr as *mut s::cufftDoubleReal,
536            dst_ptr as *mut s::cufftDoubleComplex,
537        ),
538        FftKind::Z2D => s::cufftExecZ2D(
539            h,
540            src_ptr as *mut s::cufftDoubleComplex,
541            dst_ptr as *mut s::cufftDoubleReal,
542        ),
543        FftKind::Z2Z => s::cufftExecZ2Z(
544            h,
545            src_ptr as *mut s::cufftDoubleComplex,
546            dst_ptr as *mut s::cufftDoubleComplex,
547            direction.cudarc() as i32,
548        ),
549    };
550    r.result()
551}
552
553// ---------------------------------------------------------------------
554// Actor message + state
555// ---------------------------------------------------------------------
556
557#[allow(deprecated)]
558pub enum FftMsg {
559    /// Generic typed FFT — the canonical Phase-1 entry point.
560    Exec(Box<dyn FftDispatch>),
561
562    /// 1D real → complex forward transform (f32 → complex32).
563    #[deprecated(note = "use FftMsg::Exec with FftRequest<f32> { kind: R2C, .. }")]
564    Forward1dR2C {
565        n: i32,
566        batch: i32,
567        src: GpuRef<f32>,
568        dst: GpuRef<cufft_sys::float2>,
569        reply: oneshot::Sender<Result<(), GpuError>>,
570    },
571    /// 1D complex → real inverse transform (complex32 → f32).
572    /// Caller is responsible for 1/N normalization.
573    #[deprecated(note = "use FftMsg::Exec with FftRequest<f32> { kind: C2R, .. }")]
574    Inverse1dC2R {
575        n: i32,
576        batch: i32,
577        src: GpuRef<cufft_sys::float2>,
578        dst: GpuRef<f32>,
579        reply: oneshot::Sender<Result<(), GpuError>>,
580    },
581    /// 1D complex ↔ complex transform.
582    #[deprecated(note = "use FftMsg::Exec with FftRequest<f32> { kind: C2C, .. }")]
583    Exec1dC2C {
584        n: i32,
585        batch: i32,
586        direction: CudarcFftDirection,
587        src: GpuRef<cufft_sys::float2>,
588        dst: GpuRef<cufft_sys::float2>,
589        reply: oneshot::Sender<Result<(), GpuError>>,
590    },
591    /// 2D R2C transform.
592    #[deprecated(note = "use FftMsg::Exec with FftRequest<f32> { kind: R2C, rank=2, .. }")]
593    Forward2dR2C {
594        nx: i32,
595        ny: i32,
596        src: GpuRef<f32>,
597        dst: GpuRef<cufft_sys::float2>,
598        reply: oneshot::Sender<Result<(), GpuError>>,
599    },
600}
601
602pub struct FftActor {
603    inner: FftInner,
604}
605
606/// `CudaFft` is `Send + Sync` (it has explicit `unsafe impl`s).
607/// The plan LRU must serialize access from the actor task; we use a
608/// parking_lot mutex which is fast and uncontended on the
609/// dispatcher thread.
610struct PlanCache {
611    cache: LruCache<PlanKey, Arc<CudaFft>>,
612}
613
614impl PlanCache {
615    fn new(cap: NonZeroUsize) -> Self {
616        Self {
617            cache: LruCache::new(cap),
618        }
619    }
620}
621
622enum FftInner {
623    Real {
624        stream: Arc<cudarc::driver::CudaStream>,
625        completion: Arc<dyn CompletionStrategy>,
626        plans: Mutex<PlanCache>,
627        #[allow(dead_code)]
628        state: Arc<DeviceState>,
629    },
630    Mock,
631}
632
633impl FftActor {
634    pub fn props(
635        stream: Arc<cudarc::driver::CudaStream>,
636        _allocator: Arc<dyn StreamAllocator>,
637        completion: Arc<dyn CompletionStrategy>,
638        state: Arc<DeviceState>,
639        _ctx: Arc<cudarc::driver::CudaContext>,
640    ) -> Props<Self> {
641        Props::create(move || FftActor {
642            inner: FftInner::Real {
643                stream: stream.clone(),
644                completion: completion.clone(),
645                plans: Mutex::new(PlanCache::new(
646                    NonZeroUsize::new(DEFAULT_CACHE_SIZE).unwrap(),
647                )),
648                state: state.clone(),
649            },
650        })
651    }
652
653    pub fn mock_props() -> Props<Self> {
654        Props::create(|| FftActor {
655            inner: FftInner::Mock,
656        })
657    }
658}
659
660impl FftActor {
661    /// Resolve a [`PlanKey`] through the LRU cache, building the
662    /// underlying [`CudaFft`] on miss. Used both by the canonical
663    /// `Exec` path and by the legacy compat variants.
664    pub fn ensure_plan(&self, key: PlanKey) -> Result<FftPlan, GpuError> {
665        let arc = self.get_or_create_plan(key)?;
666        Ok(FftPlan { key, inner: arc })
667    }
668
669    /// Resolve a [`FftPlanMany`] description through the LRU cache.
670    pub fn ensure_plan_many(&self, builder: &FftPlanMany) -> Result<FftPlan, GpuError> {
671        let key = builder.key();
672        let FftInner::Real { stream, plans, .. } = &self.inner else {
673            return Err(GpuError::Unrecoverable("fft mock".into()));
674        };
675        {
676            let mut g = plans.lock();
677            if let Some(plan) = g.cache.get(&key) {
678                return Ok(FftPlan {
679                    key,
680                    inner: plan.clone(),
681                });
682            }
683        }
684        let plan = build_plan_many(builder, stream).map_err(|e| GpuError::LibraryError {
685            lib: LIB,
686            msg: format!("plan_many {key:?}: {e}"),
687        })?;
688        let plan = Arc::new(plan);
689        {
690            let mut g = plans.lock();
691            g.cache.put(key, plan.clone());
692        }
693        Ok(FftPlan { key, inner: plan })
694    }
695
696    fn get_or_create_plan(&self, key: PlanKey) -> Result<Arc<CudaFft>, GpuError> {
697        let FftInner::Real { stream, plans, .. } = &self.inner else {
698            return Err(GpuError::Unrecoverable("fft mock".into()));
699        };
700        {
701            let mut g = plans.lock();
702            if let Some(plan) = g.cache.get(&key) {
703                return Ok(plan.clone());
704            }
705        }
706        let plan = build_simple_plan(&key, stream).map_err(|e| GpuError::LibraryError {
707            lib: LIB,
708            msg: format!("plan {key:?}: {e}"),
709        })?;
710        let plan = Arc::new(plan);
711        {
712            let mut g = plans.lock();
713            g.cache.put(key, plan.clone());
714        }
715        Ok(plan)
716    }
717}
718
719fn build_simple_plan(
720    key: &PlanKey,
721    stream: &Arc<cudarc::driver::CudaStream>,
722) -> Result<CudaFft, cudarc::cufft::result::CufftError> {
723    match key.rank {
724        1 => CudaFft::plan_1d(
725            key.dims[0],
726            key.kind.cufft_type(),
727            key.batch,
728            stream.clone(),
729        ),
730        2 => CudaFft::plan_2d(
731            key.dims[0],
732            key.dims[1],
733            key.kind.cufft_type(),
734            stream.clone(),
735        ),
736        3 => CudaFft::plan_3d(
737            key.dims[0],
738            key.dims[1],
739            key.dims[2],
740            key.kind.cufft_type(),
741            stream.clone(),
742        ),
743        // Defensive: unknown rank — fall through to a fake invalid plan.
744        // The PlanKey constructors only emit 1/2/3, but a future
745        // open-extension might land non-rank values.
746        _ => CudaFft::plan_1d(1, key.kind.cufft_type(), 1, stream.clone()),
747    }
748}
749
750fn build_plan_many(
751    b: &FftPlanMany,
752    stream: &Arc<cudarc::driver::CudaStream>,
753) -> Result<CudaFft, cudarc::cufft::result::CufftError> {
754    let n: &[i32] = &b.dims[..b.rank as usize];
755    let in_embed = b.in_embed;
756    let out_embed = b.out_embed;
757    let inembed: Option<&[i32]> = in_embed.as_ref().map(|e| &e[..b.rank as usize]);
758    let onembed: Option<&[i32]> = out_embed.as_ref().map(|e| &e[..b.rank as usize]);
759    CudaFft::plan_many(
760        n,
761        inembed,
762        b.in_stride,
763        b.in_dist,
764        onembed,
765        b.out_stride,
766        b.out_dist,
767        b.kind.cufft_type(),
768        b.batch,
769        stream.clone(),
770    )
771}
772
773// ---------------------------------------------------------------------
774// Actor handler
775// ---------------------------------------------------------------------
776
777#[allow(deprecated)]
778#[async_trait]
779impl Actor for FftActor {
780    type Msg = FftMsg;
781
782    async fn handle(&mut self, _ctx: &mut Context<Self>, msg: FftMsg) {
783        let (stream, completion) = match &self.inner {
784            FftInner::Mock => {
785                reply_mock(msg);
786                return;
787            }
788            FftInner::Real {
789                stream, completion, ..
790            } => (stream.clone(), completion.clone()),
791        };
792
793        match msg {
794            FftMsg::Exec(req) => {
795                // Resolve the plan from the request's plan key
796                // (which the trait surfaces via `plan_key()`), then
797                // hand the resolved Arc<CudaFft> to the dispatch impl
798                // via `FftDispatchCtx`. The dispatch impl downcasts
799                // back to `Arc<CudaFft>`.
800                let key = req.plan_key();
801                let plan_arc = match self.get_or_create_plan(key) {
802                    Ok(p) => p,
803                    Err(_e) => {
804                        // The request owns the reply channel — drop
805                        // it; the requester sees `RecvError`. We
806                        // can't extract the reply from a `Box<dyn
807                        // FftDispatch>` here without an extra trait
808                        // method, and the typed dispatch impl will
809                        // also surface the error if it tries again.
810                        // Take the simpler path: still call dispatch
811                        // with a sentinel plan, which fails the
812                        // downcast and replies with Unrecoverable.
813                        let dummy: Arc<dyn Any + Send + Sync> = Arc::new(());
814                        let dispatch_ctx = FftDispatchCtx {
815                            stream: &stream,
816                            completion: &completion,
817                            plan: dummy,
818                        };
819                        req.dispatch(&dispatch_ctx);
820                        return;
821                    }
822                };
823                let plan_any: Arc<dyn Any + Send + Sync> = plan_arc;
824                let dispatch_ctx = FftDispatchCtx {
825                    stream: &stream,
826                    completion: &completion,
827                    plan: plan_any,
828                };
829                req.dispatch(&dispatch_ctx);
830            }
831            FftMsg::Forward1dR2C {
832                n,
833                batch,
834                src,
835                dst,
836                reply,
837            } => {
838                let plan = match self.get_or_create_plan(PlanKey::plan_1d(n, FftKind::R2C, batch)) {
839                    Ok(p) => p,
840                    Err(e) => {
841                        let _ = reply.send(Err(e));
842                        return;
843                    }
844                };
845                let (src_slice, dst_slice) = match envelope::access_all_2(&src, &dst) {
846                    Ok(t) => t,
847                    Err(e) => {
848                        let _ = reply.send(Err(e));
849                        return;
850                    }
851                };
852                let mut dst_owned = match Arc::try_unwrap(dst_slice) {
853                    Ok(s) => s,
854                    Err(_) => {
855                        let _ = reply.send(Err(GpuError::Unrecoverable(
856                            "FFT dst has multiple live references".into(),
857                        )));
858                        return;
859                    }
860                };
861                dst.record_write(&stream);
862                envelope::run_kernel(LIB, &stream, &completion, (), reply, move || {
863                    plan.exec_r2c(&*src_slice, &mut dst_owned)
864                        .map(|_| (src_slice, dst_owned, plan))
865                        .map_err(|e| GpuError::LibraryError {
866                            lib: LIB,
867                            msg: format!("exec_r2c: {e}"),
868                        })
869                });
870            }
871            FftMsg::Inverse1dC2R {
872                n,
873                batch,
874                src,
875                dst,
876                reply,
877            } => {
878                let plan = match self.get_or_create_plan(PlanKey::plan_1d(n, FftKind::C2R, batch)) {
879                    Ok(p) => p,
880                    Err(e) => {
881                        let _ = reply.send(Err(e));
882                        return;
883                    }
884                };
885                let (src_slice, dst_slice) = match envelope::access_all_2(&src, &dst) {
886                    Ok(t) => t,
887                    Err(e) => {
888                        let _ = reply.send(Err(e));
889                        return;
890                    }
891                };
892                let mut src_owned = match Arc::try_unwrap(src_slice) {
893                    Ok(s) => s,
894                    Err(_) => {
895                        let _ = reply.send(Err(GpuError::Unrecoverable(
896                            "FFT C2R src has multiple live references".into(),
897                        )));
898                        return;
899                    }
900                };
901                let mut dst_owned = match Arc::try_unwrap(dst_slice) {
902                    Ok(s) => s,
903                    Err(_) => {
904                        let _ = reply.send(Err(GpuError::Unrecoverable(
905                            "FFT C2R dst has multiple live references".into(),
906                        )));
907                        return;
908                    }
909                };
910                dst.record_write(&stream);
911                envelope::run_kernel(LIB, &stream, &completion, (), reply, move || {
912                    plan.exec_c2r(&mut src_owned, &mut dst_owned)
913                        .map(|_| (src_owned, dst_owned, plan))
914                        .map_err(|e| GpuError::LibraryError {
915                            lib: LIB,
916                            msg: format!("exec_c2r: {e}"),
917                        })
918                });
919            }
920            FftMsg::Exec1dC2C {
921                n,
922                batch,
923                direction,
924                src,
925                dst,
926                reply,
927            } => {
928                let plan = match self.get_or_create_plan(PlanKey::plan_1d(n, FftKind::C2C, batch)) {
929                    Ok(p) => p,
930                    Err(e) => {
931                        let _ = reply.send(Err(e));
932                        return;
933                    }
934                };
935                let (src_slice, dst_slice) = match envelope::access_all_2(&src, &dst) {
936                    Ok(t) => t,
937                    Err(e) => {
938                        let _ = reply.send(Err(e));
939                        return;
940                    }
941                };
942                let mut src_owned = match Arc::try_unwrap(src_slice) {
943                    Ok(s) => s,
944                    Err(_) => {
945                        let _ = reply.send(Err(GpuError::Unrecoverable(
946                            "FFT C2C src has multiple live references".into(),
947                        )));
948                        return;
949                    }
950                };
951                let mut dst_owned = match Arc::try_unwrap(dst_slice) {
952                    Ok(s) => s,
953                    Err(_) => {
954                        let _ = reply.send(Err(GpuError::Unrecoverable(
955                            "FFT C2C dst has multiple live references".into(),
956                        )));
957                        return;
958                    }
959                };
960                dst.record_write(&stream);
961                envelope::run_kernel(LIB, &stream, &completion, (), reply, move || {
962                    plan.exec_c2c(&mut src_owned, &mut dst_owned, direction)
963                        .map(|_| (src_owned, dst_owned, plan))
964                        .map_err(|e| GpuError::LibraryError {
965                            lib: LIB,
966                            msg: format!("exec_c2c: {e}"),
967                        })
968                });
969            }
970            FftMsg::Forward2dR2C {
971                nx,
972                ny,
973                src,
974                dst,
975                reply,
976            } => {
977                let plan = match self.get_or_create_plan(PlanKey::plan_2d(nx, ny, FftKind::R2C)) {
978                    Ok(p) => p,
979                    Err(e) => {
980                        let _ = reply.send(Err(e));
981                        return;
982                    }
983                };
984                let (src_slice, dst_slice) = match envelope::access_all_2(&src, &dst) {
985                    Ok(t) => t,
986                    Err(e) => {
987                        let _ = reply.send(Err(e));
988                        return;
989                    }
990                };
991                let mut dst_owned = match Arc::try_unwrap(dst_slice) {
992                    Ok(s) => s,
993                    Err(_) => {
994                        let _ = reply.send(Err(GpuError::Unrecoverable(
995                            "FFT 2D dst has multiple live references".into(),
996                        )));
997                        return;
998                    }
999                };
1000                dst.record_write(&stream);
1001                envelope::run_kernel(LIB, &stream, &completion, (), reply, move || {
1002                    plan.exec_r2c(&*src_slice, &mut dst_owned)
1003                        .map(|_| (src_slice, dst_owned, plan))
1004                        .map_err(|e| GpuError::LibraryError {
1005                            lib: LIB,
1006                            msg: format!("exec_r2c (2d): {e}"),
1007                        })
1008                });
1009            }
1010        }
1011    }
1012}
1013
1014#[allow(deprecated)]
1015fn reply_mock(msg: FftMsg) {
1016    let err = || GpuError::Unrecoverable("FftActor in mock mode".into());
1017    match msg {
1018        FftMsg::Exec(req) => {
1019            // Drop the boxed request. The caller's reply channel
1020            // closes silently, surfacing as `RecvError` — same
1021            // behavior as the legacy variants.
1022            drop(req);
1023        }
1024        FftMsg::Forward1dR2C { reply, .. } => {
1025            let _ = reply.send(Err(err()));
1026        }
1027        FftMsg::Inverse1dC2R { reply, .. } => {
1028            let _ = reply.send(Err(err()));
1029        }
1030        FftMsg::Exec1dC2C { reply, .. } => {
1031            let _ = reply.send(Err(err()));
1032        }
1033        FftMsg::Forward2dR2C { reply, .. } => {
1034            let _ = reply.send(Err(err()));
1035        }
1036    }
1037}
1038
1039// ---------------------------------------------------------------------
1040// Tests (no GPU)
1041// ---------------------------------------------------------------------
1042
#[cfg(test)]
mod tests {
    #![allow(deprecated)]
    use super::*;
    #[cfg(feature = "f16")]
    use crate::dtype::CudaDtype;

    // Tests stay structural: no real `GpuRef` construction (would
    // require a `CudaContext`). The actor end-to-end path is covered
    // by GPU integration tests.

    /// `plan_1d` / `plan_2d` / `plan_3d` record their rank, zero-fill the
    /// unused trailing dims, take the dtype from the transform kind
    /// (R2C -> F32, Z2Z -> F64), and carry no `plan_many` layout
    /// discriminator.
    #[test]
    fn plan_key_for_simple_plans_zeroes_unused_dims() {
        let k1 = PlanKey::plan_1d(1024, FftKind::R2C, 1);
        assert_eq!(k1.rank, 1);
        assert_eq!(k1.dims, [1024, 0, 0]);
        assert_eq!(k1.dtype, DType::F32);
        assert!(k1.many_layout.is_none());

        let k2 = PlanKey::plan_2d(64, 64, FftKind::R2C);
        assert_eq!(k2.rank, 2);
        assert_eq!(k2.dims, [64, 64, 0]);
        assert_eq!(k2.dtype, DType::F32);
        assert!(k2.many_layout.is_none());

        let k3 = PlanKey::plan_3d(32, 32, 32, FftKind::Z2Z);
        assert_eq!(k3.rank, 3);
        assert_eq!(k3.dims, [32, 32, 32]);
        assert_eq!(k3.dtype, DType::F64);
        assert!(k3.many_layout.is_none());
    }

    /// 3D plan key includes rank=3 (one of the verification-required
    /// tests).
    #[test]
    fn fft_3d_plan_dim_handling() {
        let k = PlanKey::plan_3d(8, 16, 32, FftKind::C2C);
        assert_eq!(k.rank, 3);
        assert_eq!(k.dims[0], 8);
        assert_eq!(k.dims[1], 16);
        assert_eq!(k.dims[2], 32);
        assert_eq!(k.kind, FftKind::C2C);
    }

    /// `FftPlanMany::key` carries rank/dims/kind/dtype/batch and a layout
    /// discriminator that changes when the layout changes.
    #[test]
    fn plan_many_descriptor_correct() {
        let many = FftPlanMany {
            rank: 2,
            dims: [4, 8, 0],
            in_embed: Some([4, 8, 0]),
            in_stride: 1,
            in_dist: 32,
            out_embed: Some([4, 5, 0]),
            out_stride: 1,
            out_dist: 20,
            kind: FftKind::R2C,
            batch: 2,
        };
        let key = many.key();
        assert_eq!(key.rank, 2);
        assert_eq!(key.dims, [4, 8, 0]);
        assert_eq!(key.kind, FftKind::R2C);
        assert_eq!(key.dtype, DType::F32);
        assert_eq!(key.batch, 2);
        assert!(
            key.many_layout.is_some(),
            "plan_many keys must carry a layout discriminator"
        );

        // Two plan_many's with different layouts must hash differently.
        let mut other = many.clone();
        other.in_dist = 64;
        let key2 = other.key();
        assert_ne!(
            key.many_layout, key2.many_layout,
            "different in_dist must produce different layout seeds"
        );
        assert_ne!(key, key2);
    }

    /// Smoke-tests the LRU cache structure keyed by `PlanKey`. The actor's
    /// `get_or_create_plan` needs a CudaStream, which is unavailable here.
    #[test]
    fn plan_cache_hit_miss() {
        let cap = NonZeroUsize::new(2).unwrap();
        let mut cache: LruCache<PlanKey, ()> = LruCache::new(cap);

        let k1 = PlanKey::plan_1d(1024, FftKind::R2C, 1);
        let k2 = PlanKey::plan_2d(64, 64, FftKind::C2C);
        let k3 = PlanKey::plan_3d(8, 8, 8, FftKind::Z2Z);

        // Each key misses before it is inserted and hits afterwards.
        // A miss does not touch the recency order.
        assert!(cache.get(&k1).is_none(), "k1 miss before insert");
        cache.put(k1, ());
        assert!(cache.get(&k1).is_some(), "k1 hit after insert");

        assert!(cache.get(&k2).is_none(), "k2 miss before insert");
        cache.put(k2, ());
        assert!(cache.get(&k2).is_some(), "k2 hit after insert");

        // Inserting k3 evicts the LRU (k1, since k2 was just touched).
        assert!(cache.get(&k3).is_none(), "k3 miss before insert");
        cache.put(k3, ());
        assert!(cache.get(&k3).is_some(), "k3 hit after insert");
        assert!(cache.get(&k1).is_none(), "k1 should have been LRU-evicted");
        assert!(cache.get(&k2).is_some());
    }

    #[test]
    fn deprecated_r2c1d_still_constructs() {
        // The legacy F2 variants must keep compiling for one cycle
        // post-Phase-1. We can't *construct* a `GpuRef<T>` without a
        // CudaContext. Instead, the nested `handle` fn below names every
        // variant in an exhaustive match. It is never called, but the
        // compiler still type-checks it, so removing or renaming a legacy
        // variant breaks this test at compile time.
        fn _shape_check() {
            // Instantiate the legacy reply channel type.
            let (tx, _rx) = oneshot::channel::<Result<(), GpuError>>();
            fn handle(msg: FftMsg) {
                match msg {
                    FftMsg::Forward1dR2C { .. }
                    | FftMsg::Inverse1dC2R { .. }
                    | FftMsg::Exec1dC2C { .. }
                    | FftMsg::Forward2dR2C { .. } => {}
                    FftMsg::Exec(_) => {}
                }
            }
            drop(tx);
            let _ = handle;
        }
        _shape_check();
    }

    /// The dtype-generic surface: for each `FftSupported` dtype, `T::KIND`
    /// matches the expected `DType`, and a plan key built for a transform
    /// of that precision carries the same dtype and kind. No `FftRequest`
    /// or `GpuRef` is constructed, so this stays GPU-free.
    #[test]
    fn request_round_trip_f32_f64_f16() {
        fn check<T: FftSupported>(scalar_kind: DType, transform: FftKind) {
            assert_eq!(T::KIND, scalar_kind);
            // `plan_1d` takes the dtype from the transform kind, so the key
            // must agree with the scalar's marker dtype.
            let key = PlanKey::plan_1d(8, transform, 1);
            assert_eq!(key.dtype, scalar_kind);
            assert_eq!(key.kind, transform);
        }

        check::<f32>(DType::F32, FftKind::R2C);
        check::<f32>(DType::F32, FftKind::C2C);
        check::<f64>(DType::F64, FftKind::D2Z);
        check::<f64>(DType::F64, FftKind::Z2Z);
        #[cfg(feature = "f16")]
        {
            // f16 fft path uses `cufftXtMakePlanMany`; we exercise
            // the marker trait gating only.
            assert_eq!(<half::f16 as atomr_accel::AccelDtype>::KIND, DType::F16);
        }
    }

    /// Compile-time check that `FftRequest<T>` is `FftDispatch` for
    /// every supported dtype.
    #[test]
    fn fft_request_implements_fft_dispatch_for_all_dtypes() {
        fn assert_dispatch<U: FftDispatch>() {}
        assert_dispatch::<FftRequest<f32>>();
        assert_dispatch::<FftRequest<f64>>();
        #[cfg(feature = "f16")]
        assert_dispatch::<FftRequest<half::f16>>();
    }
}