1use std::any::Any;
42use std::ffi::c_void;
43use std::num::NonZeroUsize;
44use std::sync::Arc;
45
46use async_trait::async_trait;
47use atomr_core::actor::{Actor, Context, Props};
48use cudarc::cufft::sys as cufft_sys;
49use cudarc::cufft::{CudaFft, FftDirection as CudarcFftDirection};
50use lru::LruCache;
51use parking_lot::Mutex;
52use tokio::sync::oneshot;
53
54use crate::completion::CompletionStrategy;
55use crate::device::DeviceState;
56use crate::dtype::{DType, FftSupported};
57use crate::error::GpuError;
58use crate::gpu_ref::GpuRef;
59use crate::kernel::dispatch::{FftDispatch, FftDispatchCtx};
60use crate::kernel::envelope;
61use crate::stream::StreamAllocator;
62use crate::sys::cufft as sys_cufft;
63
64const LIB: &str = "cufft";
65const DEFAULT_CACHE_SIZE: usize = 64;
66
/// Direction of an FFT.
///
/// Converted to cudarc's direction enum with `FftDirection::cudarc`. Only the
/// complex-to-complex exec paths (C2C and Z2Z in `exec_kernel`) forward it to cuFFT;
/// the real-input / real-output exec calls take no direction argument.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FftDirection {
    /// Forward transform.
    Forward,
    /// Inverse transform.
    Inverse,
}
79
80impl FftDirection {
81 pub(crate) fn cudarc(self) -> CudarcFftDirection {
82 match self {
83 FftDirection::Forward => CudarcFftDirection::Forward,
84 FftDirection::Inverse => CudarcFftDirection::Inverse,
85 }
86 }
87}
88
/// cuFFT transform type.
///
/// R2C, C2R and C2C are single precision (`FftKind::scalar_dtype` reports
/// `DType::F32`). D2Z, Z2D and Z2Z are double precision (`DType::F64`). Each variant
/// maps 1:1 onto a `cufftType` via `FftKind::cufft_type`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FftKind {
    /// Real to complex, single precision (`CUFFT_R2C`).
    R2C,
    /// Complex to real, single precision (`CUFFT_C2R`).
    C2R,
    /// Complex to complex, single precision (`CUFFT_C2C`).
    C2C,
    /// Real to complex, double precision (`CUFFT_D2Z`).
    D2Z,
    /// Complex to real, double precision (`CUFFT_Z2D`).
    Z2D,
    /// Complex to complex, double precision (`CUFFT_Z2Z`).
    Z2Z,
}
105
106impl FftKind {
107 #[allow(non_upper_case_globals)]
109 pub const R2cF32: FftKind = FftKind::R2C;
110 #[allow(non_upper_case_globals)]
111 pub const C2rF32: FftKind = FftKind::C2R;
112 #[allow(non_upper_case_globals)]
113 pub const C2cF32: FftKind = FftKind::C2C;
114
115 pub fn cufft_type(self) -> cufft_sys::cufftType {
116 match self {
117 FftKind::R2C => cufft_sys::cufftType::CUFFT_R2C,
118 FftKind::C2R => cufft_sys::cufftType::CUFFT_C2R,
119 FftKind::C2C => cufft_sys::cufftType::CUFFT_C2C,
120 FftKind::D2Z => cufft_sys::cufftType::CUFFT_D2Z,
121 FftKind::Z2D => cufft_sys::cufftType::CUFFT_Z2D,
122 FftKind::Z2Z => cufft_sys::cufftType::CUFFT_Z2Z,
123 }
124 }
125
126 pub fn scalar_dtype(self) -> DType {
129 match self {
130 FftKind::R2C | FftKind::C2R | FftKind::C2C => DType::F32,
131 FftKind::D2Z | FftKind::Z2D | FftKind::Z2Z => DType::F64,
132 }
133 }
134}
135
/// Cache key identifying a cuFFT plan.
///
/// Built with `PlanKey::plan_1d` / `plan_2d` / `plan_3d` for contiguous plans, or with
/// `FftPlanMany::key` for advanced-layout plans.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PlanKey {
    /// Number of transform dimensions (the constructors in this file use 1, 2 or 3).
    pub rank: u32,
    /// Transform sizes; entries past `rank` are zero for the simple-plan constructors.
    pub dims: [i32; 3],
    /// cuFFT transform type.
    pub kind: FftKind,
    /// Scalar dtype; the constructors in this file set it to `kind.scalar_dtype()`.
    pub dtype: DType,
    /// Number of batched transforms. `plan_2d` / `plan_3d` set it to 1, and
    /// `build_simple_plan` does not pass it for rank 2 or 3.
    pub batch: i32,
    /// `None` for contiguous plans. For advanced layouts it is
    /// `Some(FftPlanMany::layout_seed())`, so descriptors differing only in
    /// strides/embeds get distinct keys.
    pub many_layout: Option<u64>,
}
153
154impl PlanKey {
155 pub fn plan_1d(n: i32, kind: FftKind, batch: i32) -> Self {
157 Self {
158 rank: 1,
159 dims: [n, 0, 0],
160 kind,
161 dtype: kind.scalar_dtype(),
162 batch,
163 many_layout: None,
164 }
165 }
166
167 pub fn plan_2d(nx: i32, ny: i32, kind: FftKind) -> Self {
169 Self {
170 rank: 2,
171 dims: [nx, ny, 0],
172 kind,
173 dtype: kind.scalar_dtype(),
174 batch: 1,
175 many_layout: None,
176 }
177 }
178
179 pub fn plan_3d(nx: i32, ny: i32, nz: i32, kind: FftKind) -> Self {
181 Self {
182 rank: 3,
183 dims: [nx, ny, nz],
184 kind,
185 dtype: kind.scalar_dtype(),
186 batch: 1,
187 many_layout: None,
188 }
189 }
190}
191
/// Descriptor for an advanced-layout ("plan many") cuFFT plan, consumed by
/// `FftActor::ensure_plan_many`.
///
/// `build_plan_many` forwards only the first `rank` entries of `dims`, `in_embed` and
/// `out_embed` to `CudaFft::plan_many`. A `None` embed is forwarded as `None`.
#[derive(Debug, Clone)]
pub struct FftPlanMany {
    /// Number of transform dimensions.
    pub rank: u32,
    /// Transform sizes (first `rank` entries used).
    pub dims: [i32; 3],
    /// Input embedding shape (first `rank` entries used), or `None`.
    pub in_embed: Option<[i32; 3]>,
    /// Input element stride, forwarded to `CudaFft::plan_many`.
    pub in_stride: i32,
    /// Input distance between batches, forwarded to `CudaFft::plan_many`.
    pub in_dist: i32,
    /// Output embedding shape (first `rank` entries used), or `None`.
    pub out_embed: Option<[i32; 3]>,
    /// Output element stride, forwarded to `CudaFft::plan_many`.
    pub out_stride: i32,
    /// Output distance between batches, forwarded to `CudaFft::plan_many`.
    pub out_dist: i32,
    /// cuFFT transform type.
    pub kind: FftKind,
    /// Number of batched transforms.
    pub batch: i32,
}
211
212impl FftPlanMany {
213 pub fn layout_seed(&self) -> u64 {
216 use std::collections::hash_map::DefaultHasher;
217 use std::hash::{Hash, Hasher};
218 let mut h = DefaultHasher::new();
219 self.in_embed.hash(&mut h);
220 self.in_stride.hash(&mut h);
221 self.in_dist.hash(&mut h);
222 self.out_embed.hash(&mut h);
223 self.out_stride.hash(&mut h);
224 self.out_dist.hash(&mut h);
225 h.finish()
226 }
227
228 pub fn key(&self) -> PlanKey {
230 PlanKey {
231 rank: self.rank,
232 dims: self.dims,
233 kind: self.kind,
234 dtype: self.kind.scalar_dtype(),
235 batch: self.batch,
236 many_layout: Some(self.layout_seed()),
237 }
238 }
239}
240
/// Kind of cuFFT callback installed by [`FftPlan::with_callback`].
///
/// Maps 1:1 onto `sys_cufft::CufftXtCallbackType`. The `Load*` kinds are meant for
/// `FftPlan::with_load_callback` and the `Store*` kinds for
/// `FftPlan::with_store_callback`. The `Complex`/`Real` and `Double` suffixes name the
/// element type the callback handles.
#[derive(Debug, Clone, Copy)]
pub enum FftCallbackKind {
    LoadComplex,
    LoadComplexDouble,
    LoadReal,
    LoadRealDouble,
    StoreComplex,
    StoreComplexDouble,
    StoreReal,
    StoreRealDouble,
}
256
257impl FftCallbackKind {
258 fn sys(self) -> sys_cufft::CufftXtCallbackType {
259 use sys_cufft::CufftXtCallbackType as T;
260 match self {
261 FftCallbackKind::LoadComplex => T::LoadComplex,
262 FftCallbackKind::LoadComplexDouble => T::LoadComplexDouble,
263 FftCallbackKind::LoadReal => T::LoadReal,
264 FftCallbackKind::LoadRealDouble => T::LoadRealDouble,
265 FftCallbackKind::StoreComplex => T::StoreComplex,
266 FftCallbackKind::StoreComplexDouble => T::StoreComplexDouble,
267 FftCallbackKind::StoreReal => T::StoreReal,
268 FftCallbackKind::StoreRealDouble => T::StoreRealDouble,
269 }
270 }
271}
272
/// Handle to a cuFFT plan.
///
/// Cloning is cheap: clones share one `CudaFft` through an `Arc`. Plans returned by
/// `FftActor::ensure_plan` / `ensure_plan_many` share that `Arc` with the actor's plan
/// cache.
#[derive(Clone)]
pub struct FftPlan {
    /// Key the plan was built (and cached) under.
    pub key: PlanKey,
    inner: Arc<CudaFft>,
}
281
282impl std::fmt::Debug for FftPlan {
283 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
284 f.debug_struct("FftPlan").field("key", &self.key).finish()
285 }
286}
287
288impl FftPlan {
289 pub fn key(&self) -> PlanKey {
290 self.key
291 }
292
293 pub unsafe fn with_callback(
302 &self,
303 kind: FftCallbackKind,
304 cb: *mut c_void,
305 caller_info: *mut c_void,
306 ) -> Result<(), GpuError> {
307 let res = sys_cufft::xt_set_callback(self.inner.handle(), cb, kind.sys(), caller_info);
308 match res.result() {
309 Ok(()) => Ok(()),
310 Err(e) => Err(GpuError::LibraryError {
311 lib: LIB,
312 msg: format!("cufftXtSetCallback({kind:?}): {e:?}"),
313 }),
314 }
315 }
316
317 pub unsafe fn with_load_callback(
322 &self,
323 kind: FftCallbackKind,
324 cb: *mut c_void,
325 caller_info: *mut c_void,
326 ) -> Result<(), GpuError> {
327 debug_assert!(matches!(
328 kind,
329 FftCallbackKind::LoadComplex
330 | FftCallbackKind::LoadComplexDouble
331 | FftCallbackKind::LoadReal
332 | FftCallbackKind::LoadRealDouble
333 ));
334 self.with_callback(kind, cb, caller_info)
335 }
336
337 pub unsafe fn with_store_callback(
342 &self,
343 kind: FftCallbackKind,
344 cb: *mut c_void,
345 caller_info: *mut c_void,
346 ) -> Result<(), GpuError> {
347 debug_assert!(matches!(
348 kind,
349 FftCallbackKind::StoreComplex
350 | FftCallbackKind::StoreComplexDouble
351 | FftCallbackKind::StoreReal
352 | FftCallbackKind::StoreRealDouble
353 ));
354 self.with_callback(kind, cb, caller_info)
355 }
356}
357
/// Typed FFT execution request, sent boxed inside `FftMsg::Exec`.
///
/// `T` is the scalar type (`f32`, `f64`, or `half::f16` with the `f16` feature). It is
/// reported through `FftDispatch::dtype_kind`.
///
/// `I` / `O` are the element types of the input and output buffers (default `u8`).
/// When the transform runs, the buffers are reinterpreted according to
/// `plan_key.kind`.
pub struct FftRequest<T: FftSupported, I = u8, O = u8> {
    /// Key of the plan to execute; the actor resolves it through its plan cache.
    pub plan_key: PlanKey,
    /// Transform direction; only forwarded to cuFFT for C2C / Z2Z transforms.
    pub direction: FftDirection,
    /// Source buffer.
    pub input: GpuRef<I>,
    /// Destination buffer; recorded as written on the actor's stream before launch.
    pub output: GpuRef<O>,
    /// Receives the outcome. Failures before launch are sent directly by `dispatch`;
    /// otherwise the sender is handed to `envelope::run_kernel`.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
    _scalar: std::marker::PhantomData<T>,
}
389
390impl<T: FftSupported, I, O> FftRequest<T, I, O> {
391 pub fn new(
395 plan_key: PlanKey,
396 direction: FftDirection,
397 input: GpuRef<I>,
398 output: GpuRef<O>,
399 reply: oneshot::Sender<Result<(), GpuError>>,
400 ) -> Self {
401 Self {
402 plan_key,
403 direction,
404 input,
405 output,
406 reply,
407 _scalar: std::marker::PhantomData,
408 }
409 }
410}
411
412impl<T, I, O> FftDispatch for FftRequest<T, I, O>
413where
414 T: FftSupported,
415 I: Send + Sync + 'static,
416 O: Send + Sync + 'static,
417{
418 fn dtype_kind(&self) -> DType {
419 T::KIND
420 }
421
422 fn plan_key(&self) -> PlanKey {
423 self.plan_key
424 }
425
426 fn dispatch(self: Box<Self>, ctx: &FftDispatchCtx<'_>) {
427 let plan = match ctx.plan.clone().downcast::<CudaFft>() {
431 Ok(p) => p,
432 Err(_) => {
433 let _ = self.reply.send(Err(GpuError::Unrecoverable(
434 "FftDispatchCtx.plan downcast to CudaFft failed".into(),
435 )));
436 return;
437 }
438 };
439
440 let stream = ctx.stream.clone();
441 let stream_for_exec = stream.clone();
442 let completion = ctx.completion.clone();
443 let kind = self.plan_key.kind;
444 let direction = self.direction;
445
446 let (src_arc, dst_arc) = match envelope::access_all_2(&self.input, &self.output) {
451 Ok(t) => t,
452 Err(e) => {
453 let _ = self.reply.send(Err(e));
454 return;
455 }
456 };
457
458 self.output.record_write(&stream);
461 let reply = self.reply;
462
463 envelope::run_kernel(LIB, &stream, &completion, (), reply, move || {
464 let res = unsafe {
475 exec_kernel(&plan, &src_arc, &dst_arc, kind, direction, &stream_for_exec)
476 };
477 res.map(|_| (src_arc, dst_arc, plan))
478 .map_err(|e| GpuError::LibraryError {
479 lib: LIB,
480 msg: format!("exec_{:?}: {:?}", kind, e),
481 })
482 });
483 }
484}
485
/// Enqueues the cuFFT exec call that matches `kind`, using `plan` on `stream`.
///
/// The device addresses of `src` and `dst` are reinterpreted as the cuFFT element
/// types for `kind`, whatever `I` / `O` are:
/// - `cufftReal` / `cufftComplex` for R2C, C2R and C2C;
/// - `cufftDoubleReal` / `cufftDoubleComplex` for D2Z, Z2D and Z2Z.
///
/// `direction` is passed only to the C2C and Z2Z exec calls.
///
/// # Safety
/// Nothing here checks that `plan` was built for `kind`, or that the buffers have
/// the element type and length the plan expects. The caller must guarantee both.
///
/// NOTE(review): for C2R / Z2D, cuFFT is documented to be allowed to overwrite the
/// input buffer; here `src` is only obtained through the read-only `DevicePtr` path.
///
/// # Errors
/// Returns the `CufftError` reported by the exec call.
unsafe fn exec_kernel<I, O>(
    plan: &Arc<CudaFft>,
    src: &Arc<cudarc::driver::CudaSlice<I>>,
    dst: &Arc<cudarc::driver::CudaSlice<O>>,
    kind: FftKind,
    direction: FftDirection,
    stream: &Arc<cudarc::driver::CudaStream>,
) -> Result<(), cudarc::cufft::result::CufftError> {
    use cudarc::driver::DevicePtr;

    // The second tuple element returned by `device_ptr` is bound to a named variable
    // (not `_`) so it lives until this function returns instead of being dropped at
    // the end of the `let`. NOTE(review): presumably it records stream usage for
    // cudarc's synchronization; confirm. `dst` goes through the read-only
    // `device_ptr` because it sits behind an `Arc`; the caller tracks the write
    // separately via `record_write`.
    let (src_ptr, _src_rec) = src.device_ptr(stream);
    let (dst_ptr, _dst_rec) = dst.device_ptr(stream);
    let src_ptr = src_ptr as *mut c_void;
    let dst_ptr = dst_ptr as *mut c_void;
    let h = plan.handle();
    use cudarc::cufft::sys as s;

    let r = match kind {
        FftKind::R2C => s::cufftExecR2C(
            h,
            src_ptr as *mut s::cufftReal,
            dst_ptr as *mut s::cufftComplex,
        ),
        FftKind::C2R => s::cufftExecC2R(
            h,
            src_ptr as *mut s::cufftComplex,
            dst_ptr as *mut s::cufftReal,
        ),
        FftKind::C2C => s::cufftExecC2C(
            h,
            src_ptr as *mut s::cufftComplex,
            dst_ptr as *mut s::cufftComplex,
            direction.cudarc() as i32,
        ),
        FftKind::D2Z => s::cufftExecD2Z(
            h,
            src_ptr as *mut s::cufftDoubleReal,
            dst_ptr as *mut s::cufftDoubleComplex,
        ),
        FftKind::Z2D => s::cufftExecZ2D(
            h,
            src_ptr as *mut s::cufftDoubleComplex,
            dst_ptr as *mut s::cufftDoubleReal,
        ),
        FftKind::Z2Z => s::cufftExecZ2Z(
            h,
            src_ptr as *mut s::cufftDoubleComplex,
            dst_ptr as *mut s::cufftDoubleComplex,
            direction.cudarc() as i32,
        ),
    };
    r.result()
}
552
/// Messages accepted by [`FftActor`].
///
/// `Exec` is the generic path. The other variants are deprecated fixed-shape
/// requests kept for existing callers. Each one is served with a cached simple plan
/// built from its parameters.
#[allow(deprecated)]
pub enum FftMsg {
    /// Run a boxed request (normally an `FftRequest`) with the plan for its `plan_key`.
    Exec(Box<dyn FftDispatch>),

    /// Batched 1-D real-to-complex forward transform of length `n`
    /// (plan key `PlanKey::plan_1d(n, FftKind::R2C, batch)`).
    #[deprecated(note = "use FftMsg::Exec with FftRequest<f32> { kind: R2C, .. }")]
    Forward1dR2C {
        n: i32,
        batch: i32,
        src: GpuRef<f32>,
        dst: GpuRef<cufft_sys::float2>,
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
    /// Batched 1-D complex-to-real inverse transform of length `n`
    /// (plan key `PlanKey::plan_1d(n, FftKind::C2R, batch)`).
    #[deprecated(note = "use FftMsg::Exec with FftRequest<f32> { kind: C2R, .. }")]
    Inverse1dC2R {
        n: i32,
        batch: i32,
        src: GpuRef<cufft_sys::float2>,
        dst: GpuRef<f32>,
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
    /// Batched 1-D complex-to-complex transform of length `n` in `direction`
    /// (plan key `PlanKey::plan_1d(n, FftKind::C2C, batch)`).
    #[deprecated(note = "use FftMsg::Exec with FftRequest<f32> { kind: C2C, .. }")]
    Exec1dC2C {
        n: i32,
        batch: i32,
        direction: CudarcFftDirection,
        src: GpuRef<cufft_sys::float2>,
        dst: GpuRef<cufft_sys::float2>,
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
    /// Single `nx` x `ny` real-to-complex forward transform
    /// (plan key `PlanKey::plan_2d(nx, ny, FftKind::R2C)`).
    #[deprecated(note = "use FftMsg::Exec with FftRequest<f32> { kind: R2C, rank=2, .. }")]
    Forward2dR2C {
        nx: i32,
        ny: i32,
        src: GpuRef<f32>,
        dst: GpuRef<cufft_sys::float2>,
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
}
601
/// Actor that owns a cuFFT plan cache and enqueues FFTs on one CUDA stream.
///
/// In mock mode it answers every request with an error instead (see `FftInner::Mock`).
pub struct FftActor {
    inner: FftInner,
}
605
/// Least-recently-used cache of cuFFT plans keyed by [`PlanKey`].
///
/// Values are `Arc`s, so a plan already handed to a caller or captured by an
/// in-flight launch stays alive after it is evicted.
struct PlanCache {
    cache: LruCache<PlanKey, Arc<CudaFft>>,
}
613
614impl PlanCache {
615 fn new(cap: NonZeroUsize) -> Self {
616 Self {
617 cache: LruCache::new(cap),
618 }
619 }
620}
621
/// Backing state of an [`FftActor`].
enum FftInner {
    /// GPU-backed actor.
    Real {
        /// Stream that plans are built on and transforms are enqueued on.
        stream: Arc<cudarc::driver::CudaStream>,
        /// Completion strategy passed to `envelope::run_kernel` for every launch.
        completion: Arc<dyn CompletionStrategy>,
        /// Plan cache. The mutex is taken only around lookups and inserts, not while
        /// a plan is being built.
        plans: Mutex<PlanCache>,
        // Never read in this file, hence `dead_code`. NOTE(review): presumably held
        // so the device state lives as long as the actor; confirm.
        #[allow(dead_code)]
        state: Arc<DeviceState>,
    },
    /// Test double: `reply_mock` answers every message without touching a GPU.
    Mock,
}
632
633impl FftActor {
634 pub fn props(
635 stream: Arc<cudarc::driver::CudaStream>,
636 _allocator: Arc<dyn StreamAllocator>,
637 completion: Arc<dyn CompletionStrategy>,
638 state: Arc<DeviceState>,
639 _ctx: Arc<cudarc::driver::CudaContext>,
640 ) -> Props<Self> {
641 Props::create(move || FftActor {
642 inner: FftInner::Real {
643 stream: stream.clone(),
644 completion: completion.clone(),
645 plans: Mutex::new(PlanCache::new(
646 NonZeroUsize::new(DEFAULT_CACHE_SIZE).unwrap(),
647 )),
648 state: state.clone(),
649 },
650 })
651 }
652
653 pub fn mock_props() -> Props<Self> {
654 Props::create(|| FftActor {
655 inner: FftInner::Mock,
656 })
657 }
658}
659
660impl FftActor {
661 pub fn ensure_plan(&self, key: PlanKey) -> Result<FftPlan, GpuError> {
665 let arc = self.get_or_create_plan(key)?;
666 Ok(FftPlan { key, inner: arc })
667 }
668
669 pub fn ensure_plan_many(&self, builder: &FftPlanMany) -> Result<FftPlan, GpuError> {
671 let key = builder.key();
672 let FftInner::Real { stream, plans, .. } = &self.inner else {
673 return Err(GpuError::Unrecoverable("fft mock".into()));
674 };
675 {
676 let mut g = plans.lock();
677 if let Some(plan) = g.cache.get(&key) {
678 return Ok(FftPlan {
679 key,
680 inner: plan.clone(),
681 });
682 }
683 }
684 let plan = build_plan_many(builder, stream).map_err(|e| GpuError::LibraryError {
685 lib: LIB,
686 msg: format!("plan_many {key:?}: {e}"),
687 })?;
688 let plan = Arc::new(plan);
689 {
690 let mut g = plans.lock();
691 g.cache.put(key, plan.clone());
692 }
693 Ok(FftPlan { key, inner: plan })
694 }
695
696 fn get_or_create_plan(&self, key: PlanKey) -> Result<Arc<CudaFft>, GpuError> {
697 let FftInner::Real { stream, plans, .. } = &self.inner else {
698 return Err(GpuError::Unrecoverable("fft mock".into()));
699 };
700 {
701 let mut g = plans.lock();
702 if let Some(plan) = g.cache.get(&key) {
703 return Ok(plan.clone());
704 }
705 }
706 let plan = build_simple_plan(&key, stream).map_err(|e| GpuError::LibraryError {
707 lib: LIB,
708 msg: format!("plan {key:?}: {e}"),
709 })?;
710 let plan = Arc::new(plan);
711 {
712 let mut g = plans.lock();
713 g.cache.put(key, plan.clone());
714 }
715 Ok(plan)
716 }
717}
718
719fn build_simple_plan(
720 key: &PlanKey,
721 stream: &Arc<cudarc::driver::CudaStream>,
722) -> Result<CudaFft, cudarc::cufft::result::CufftError> {
723 match key.rank {
724 1 => CudaFft::plan_1d(
725 key.dims[0],
726 key.kind.cufft_type(),
727 key.batch,
728 stream.clone(),
729 ),
730 2 => CudaFft::plan_2d(
731 key.dims[0],
732 key.dims[1],
733 key.kind.cufft_type(),
734 stream.clone(),
735 ),
736 3 => CudaFft::plan_3d(
737 key.dims[0],
738 key.dims[1],
739 key.dims[2],
740 key.kind.cufft_type(),
741 stream.clone(),
742 ),
743 _ => CudaFft::plan_1d(1, key.kind.cufft_type(), 1, stream.clone()),
747 }
748}
749
750fn build_plan_many(
751 b: &FftPlanMany,
752 stream: &Arc<cudarc::driver::CudaStream>,
753) -> Result<CudaFft, cudarc::cufft::result::CufftError> {
754 let n: &[i32] = &b.dims[..b.rank as usize];
755 let in_embed = b.in_embed;
756 let out_embed = b.out_embed;
757 let inembed: Option<&[i32]> = in_embed.as_ref().map(|e| &e[..b.rank as usize]);
758 let onembed: Option<&[i32]> = out_embed.as_ref().map(|e| &e[..b.rank as usize]);
759 CudaFft::plan_many(
760 n,
761 inembed,
762 b.in_stride,
763 b.in_dist,
764 onembed,
765 b.out_stride,
766 b.out_dist,
767 b.kind.cufft_type(),
768 b.batch,
769 stream.clone(),
770 )
771}
772
// `allow(deprecated)` is needed because the handler matches the deprecated `FftMsg`
// variants.
#[allow(deprecated)]
#[async_trait]
impl Actor for FftActor {
    type Msg = FftMsg;

    /// Handles one [`FftMsg`].
    ///
    /// In mock mode the message is forwarded to `reply_mock`. In real mode the
    /// handler:
    /// 1. resolves the cached plan for the message (building it on a miss);
    /// 2. resolves the device buffers;
    /// 3. records the write on the destination;
    /// 4. hands the launch, together with the reply sender, to
    ///    `envelope::run_kernel`.
    ///
    /// Failures before the launch are sent on the message's reply channel.
    async fn handle(&mut self, _ctx: &mut Context<Self>, msg: FftMsg) {
        // Owned clones of the stream and completion strategy are used for the rest
        // of the handler.
        let (stream, completion) = match &self.inner {
            FftInner::Mock => {
                reply_mock(msg);
                return;
            }
            FftInner::Real {
                stream, completion, ..
            } => (stream.clone(), completion.clone()),
        };

        match msg {
            FftMsg::Exec(req) => {
                let key = req.plan_key();
                let plan_arc = match self.get_or_create_plan(key) {
                    Ok(p) => p,
                    Err(_e) => {
                        // The boxed request owns its reply sender, and `FftDispatch`
                        // offers no way to reach it. So the failure is routed through
                        // `dispatch` with a placeholder plan. For `FftRequest`, the
                        // downcast to `CudaFft` then fails and it replies with an
                        // `Unrecoverable` error.
                        // NOTE(review): `_e`, the real plan-creation error, is dropped
                        // here, so callers only see the generic "downcast to CudaFft
                        // failed" message. Forwarding it needs a change to the
                        // `FftDispatch` contract.
                        let dummy: Arc<dyn Any + Send + Sync> = Arc::new(());
                        let dispatch_ctx = FftDispatchCtx {
                            stream: &stream,
                            completion: &completion,
                            plan: dummy,
                        };
                        req.dispatch(&dispatch_ctx);
                        return;
                    }
                };
                // The plan travels type-erased; the request downcasts it back.
                let plan_any: Arc<dyn Any + Send + Sync> = plan_arc;
                let dispatch_ctx = FftDispatchCtx {
                    stream: &stream,
                    completion: &completion,
                    plan: plan_any,
                };
                req.dispatch(&dispatch_ctx);
            }
            FftMsg::Forward1dR2C {
                n,
                batch,
                src,
                dst,
                reply,
            } => {
                let plan = match self.get_or_create_plan(PlanKey::plan_1d(n, FftKind::R2C, batch)) {
                    Ok(p) => p,
                    Err(e) => {
                        let _ = reply.send(Err(e));
                        return;
                    }
                };
                let (src_slice, dst_slice) = match envelope::access_all_2(&src, &dst) {
                    Ok(t) => t,
                    Err(e) => {
                        let _ = reply.send(Err(e));
                        return;
                    }
                };
                // cudarc's safe `exec_r2c` needs `&mut` access to the destination,
                // so the Arc must be the only strong reference. Otherwise the
                // request is rejected.
                // NOTE(review): if the unwrap succeeds, this handler holds the only
                // strong reference to the destination slice, and it is dropped after
                // the launch. Confirm the `GpuRef` still observes the written data.
                let mut dst_owned = match Arc::try_unwrap(dst_slice) {
                    Ok(s) => s,
                    Err(_) => {
                        let _ = reply.send(Err(GpuError::Unrecoverable(
                            "FFT dst has multiple live references".into(),
                        )));
                        return;
                    }
                };
                dst.record_write(&stream);
                // The closure returns the buffers and plan, presumably so `run_kernel`
                // keeps them alive until the launch completes; confirm against
                // `envelope`.
                envelope::run_kernel(LIB, &stream, &completion, (), reply, move || {
                    plan.exec_r2c(&*src_slice, &mut dst_owned)
                        .map(|_| (src_slice, dst_owned, plan))
                        .map_err(|e| GpuError::LibraryError {
                            lib: LIB,
                            msg: format!("exec_r2c: {e}"),
                        })
                });
            }
            FftMsg::Inverse1dC2R {
                n,
                batch,
                src,
                dst,
                reply,
            } => {
                let plan = match self.get_or_create_plan(PlanKey::plan_1d(n, FftKind::C2R, batch)) {
                    Ok(p) => p,
                    Err(e) => {
                        let _ = reply.send(Err(e));
                        return;
                    }
                };
                let (src_slice, dst_slice) = match envelope::access_all_2(&src, &dst) {
                    Ok(t) => t,
                    Err(e) => {
                        let _ = reply.send(Err(e));
                        return;
                    }
                };
                // `exec_c2r` takes both the source and the destination by `&mut`, so
                // both Arcs must be uniquely owned. The source is checked first.
                let mut src_owned = match Arc::try_unwrap(src_slice) {
                    Ok(s) => s,
                    Err(_) => {
                        let _ = reply.send(Err(GpuError::Unrecoverable(
                            "FFT C2R src has multiple live references".into(),
                        )));
                        return;
                    }
                };
                let mut dst_owned = match Arc::try_unwrap(dst_slice) {
                    Ok(s) => s,
                    Err(_) => {
                        let _ = reply.send(Err(GpuError::Unrecoverable(
                            "FFT C2R dst has multiple live references".into(),
                        )));
                        return;
                    }
                };
                dst.record_write(&stream);
                envelope::run_kernel(LIB, &stream, &completion, (), reply, move || {
                    plan.exec_c2r(&mut src_owned, &mut dst_owned)
                        .map(|_| (src_owned, dst_owned, plan))
                        .map_err(|e| GpuError::LibraryError {
                            lib: LIB,
                            msg: format!("exec_c2r: {e}"),
                        })
                });
            }
            FftMsg::Exec1dC2C {
                n,
                batch,
                direction,
                src,
                dst,
                reply,
            } => {
                let plan = match self.get_or_create_plan(PlanKey::plan_1d(n, FftKind::C2C, batch)) {
                    Ok(p) => p,
                    Err(e) => {
                        let _ = reply.send(Err(e));
                        return;
                    }
                };
                let (src_slice, dst_slice) = match envelope::access_all_2(&src, &dst) {
                    Ok(t) => t,
                    Err(e) => {
                        let _ = reply.send(Err(e));
                        return;
                    }
                };
                // `exec_c2c` takes both buffers by `&mut`; see the C2R arm.
                let mut src_owned = match Arc::try_unwrap(src_slice) {
                    Ok(s) => s,
                    Err(_) => {
                        let _ = reply.send(Err(GpuError::Unrecoverable(
                            "FFT C2C src has multiple live references".into(),
                        )));
                        return;
                    }
                };
                let mut dst_owned = match Arc::try_unwrap(dst_slice) {
                    Ok(s) => s,
                    Err(_) => {
                        let _ = reply.send(Err(GpuError::Unrecoverable(
                            "FFT C2C dst has multiple live references".into(),
                        )));
                        return;
                    }
                };
                dst.record_write(&stream);
                envelope::run_kernel(LIB, &stream, &completion, (), reply, move || {
                    plan.exec_c2c(&mut src_owned, &mut dst_owned, direction)
                        .map(|_| (src_owned, dst_owned, plan))
                        .map_err(|e| GpuError::LibraryError {
                            lib: LIB,
                            msg: format!("exec_c2c: {e}"),
                        })
                });
            }
            FftMsg::Forward2dR2C {
                nx,
                ny,
                src,
                dst,
                reply,
            } => {
                let plan = match self.get_or_create_plan(PlanKey::plan_2d(nx, ny, FftKind::R2C)) {
                    Ok(p) => p,
                    Err(e) => {
                        let _ = reply.send(Err(e));
                        return;
                    }
                };
                let (src_slice, dst_slice) = match envelope::access_all_2(&src, &dst) {
                    Ok(t) => t,
                    Err(e) => {
                        let _ = reply.send(Err(e));
                        return;
                    }
                };
                // Same ownership requirement as the 1-D R2C arm: only the destination
                // needs `&mut`.
                let mut dst_owned = match Arc::try_unwrap(dst_slice) {
                    Ok(s) => s,
                    Err(_) => {
                        let _ = reply.send(Err(GpuError::Unrecoverable(
                            "FFT 2D dst has multiple live references".into(),
                        )));
                        return;
                    }
                };
                dst.record_write(&stream);
                envelope::run_kernel(LIB, &stream, &completion, (), reply, move || {
                    plan.exec_r2c(&*src_slice, &mut dst_owned)
                        .map(|_| (src_slice, dst_owned, plan))
                        .map_err(|e| GpuError::LibraryError {
                            lib: LIB,
                            msg: format!("exec_r2c (2d): {e}"),
                        })
                });
            }
        }
    }
}
1013
1014#[allow(deprecated)]
1015fn reply_mock(msg: FftMsg) {
1016 let err = || GpuError::Unrecoverable("FftActor in mock mode".into());
1017 match msg {
1018 FftMsg::Exec(req) => {
1019 drop(req);
1023 }
1024 FftMsg::Forward1dR2C { reply, .. } => {
1025 let _ = reply.send(Err(err()));
1026 }
1027 FftMsg::Inverse1dC2R { reply, .. } => {
1028 let _ = reply.send(Err(err()));
1029 }
1030 FftMsg::Exec1dC2C { reply, .. } => {
1031 let _ = reply.send(Err(err()));
1032 }
1033 FftMsg::Forward2dR2C { reply, .. } => {
1034 let _ = reply.send(Err(err()));
1035 }
1036 }
1037}
1038
#[cfg(test)]
mod tests {
    #![allow(deprecated)]
    use super::*;
    #[cfg(feature = "f16")]
    use crate::dtype::CudaDtype;

    /// Simple-plan constructors fill only the leading `rank` dims, derive the dtype
    /// from the kind, and carry no layout seed.
    #[test]
    fn plan_key_for_simple_plans_zeroes_unused_dims() {
        let one_d = PlanKey::plan_1d(1024, FftKind::R2C, 1);
        assert_eq!(
            (one_d.rank, one_d.dims, one_d.dtype),
            (1u32, [1024, 0, 0], DType::F32)
        );
        assert!(one_d.many_layout.is_none());

        let two_d = PlanKey::plan_2d(64, 64, FftKind::R2C);
        assert_eq!(
            (two_d.rank, two_d.dims, two_d.dtype),
            (2u32, [64, 64, 0], DType::F32)
        );

        let three_d = PlanKey::plan_3d(32, 32, 32, FftKind::Z2Z);
        assert_eq!(
            (three_d.rank, three_d.dims, three_d.dtype),
            (3u32, [32, 32, 32], DType::F64)
        );
    }

    /// A 3-D key keeps each axis in its own slot.
    #[test]
    fn fft_3d_plan_dim_handling() {
        let key = PlanKey::plan_3d(8, 16, 32, FftKind::C2C);
        let [nx, ny, nz] = key.dims;
        assert_eq!(key.rank, 3);
        assert_eq!(nx, 8);
        assert_eq!(ny, 16);
        assert_eq!(nz, 32);
        assert_eq!(key.kind, FftKind::C2C);
    }

    /// `FftPlanMany::key` copies the shape and tags the key with a layout seed that
    /// reacts to layout changes.
    #[test]
    fn plan_many_descriptor_correct() {
        let many = FftPlanMany {
            rank: 2,
            dims: [4, 8, 0],
            in_embed: Some([4, 8, 0]),
            in_stride: 1,
            in_dist: 32,
            out_embed: Some([4, 5, 0]),
            out_stride: 1,
            out_dist: 20,
            kind: FftKind::R2C,
            batch: 2,
        };
        let key = many.key();
        assert_eq!(
            (key.rank, key.dims, key.kind, key.dtype, key.batch),
            (2u32, [4, 8, 0], FftKind::R2C, DType::F32, 2)
        );
        assert!(
            key.many_layout.is_some(),
            "plan_many keys must carry a layout discriminator"
        );

        // Only the input batch distance changes; both the seed and the key must differ.
        let relaid = FftPlanMany { in_dist: 64, ..many };
        let relaid_key = relaid.key();
        assert_ne!(
            key.many_layout, relaid_key.many_layout,
            "different in_dist must produce different layout seeds"
        );
        assert_ne!(key, relaid_key);
    }

    /// Exercises the LRU behavior `PlanCache` relies on, with `()` standing in for
    /// the cuFFT handle.
    #[test]
    fn plan_cache_hit_miss() {
        let mut cache: LruCache<PlanKey, ()> = LruCache::new(NonZeroUsize::new(2).unwrap());

        let k1 = PlanKey::plan_1d(1024, FftKind::R2C, 1);
        let k2 = PlanKey::plan_2d(64, 64, FftKind::C2C);
        let k3 = PlanKey::plan_3d(8, 8, 8, FftKind::Z2Z);

        assert!(cache.get(&k1).is_none());
        cache.put(k1, ());
        assert!(cache.get(&k1).is_some(), "k1 hit after insert");

        cache.put(k2, ());
        assert!(cache.get(&k2).is_some());

        // k1 is now the least recently used entry, so inserting k3 evicts it.
        cache.put(k3, ());
        assert!(cache.get(&k3).is_some());
        assert!(cache.get(&k1).is_none(), "k1 should have been LRU-evicted");
        assert!(cache.get(&k2).is_some());
    }

    /// Compile-time check that the deprecated variants still exist next to `Exec`.
    #[test]
    fn deprecated_r2c1d_still_constructs() {
        fn _shape_check() {
            let (tx, _rx) = oneshot::channel::<Result<(), GpuError>>();
            fn handle(msg: FftMsg) {
                match msg {
                    FftMsg::Exec(_) => {}
                    FftMsg::Forward1dR2C { .. }
                    | FftMsg::Inverse1dC2R { .. }
                    | FftMsg::Exec1dC2C { .. }
                    | FftMsg::Forward2dR2C { .. } => {}
                }
            }
            drop(tx);
            let _ = handle;
        }
        _shape_check();
    }

    /// Scalar types report the dtype matching the transforms built for them.
    #[test]
    fn request_round_trip_f32_f64_f16() {
        fn check<T: FftSupported>(scalar_kind: DType, transform: FftKind) {
            assert_eq!(T::KIND, scalar_kind);
            let key = PlanKey::plan_1d(8, transform, 1);
            assert_eq!(key.dtype, scalar_kind);
            assert_eq!(key.kind, transform);
        }

        check::<f32>(DType::F32, FftKind::R2C);
        check::<f32>(DType::F32, FftKind::C2C);
        check::<f64>(DType::F64, FftKind::D2Z);
        check::<f64>(DType::F64, FftKind::Z2Z);
        #[cfg(feature = "f16")]
        {
            assert_eq!(<half::f16 as atomr_accel::AccelDtype>::KIND, DType::F16);
        }
    }

    /// `FftRequest` implements `FftDispatch` for every supported scalar type.
    #[test]
    fn fft_request_implements_fft_dispatch_for_all_dtypes() {
        fn assert_dispatch<U: FftDispatch>() {}
        assert_dispatch::<FftRequest<f32>>();
        assert_dispatch::<FftRequest<f64>>();
        #[cfg(feature = "f16")]
        assert_dispatch::<FftRequest<half::f16>>();
    }
}