atomr_accel_cuda/graph/record/
fft_r2c.rs1use crate::error::GpuError;
8use crate::gpu_ref::GpuRef;
9use crate::graph::{GraphOp, GraphRecordCtx};
10use crate::kernel::record::{FftR2COp as InnerFftR2COp, FftRecorder, RecordMode};
11
12pub struct FftR2COp {
19 inner: Option<InnerFftR2COp>,
20}
21
22impl FftR2COp {
23 pub fn new(src: GpuRef<f32>, dst: GpuRef<cudarc::cufft::sys::float2>) -> Self {
24 Self {
25 inner: Some(InnerFftR2COp { src, dst }),
26 }
27 }
28}
29
30impl GraphOp for FftR2COp {
31 fn record(&mut self, ctx: &mut GraphRecordCtx<'_>) -> Result<(), GpuError> {
32 let stream = ctx.require_stream()?;
33 let plan = ctx.fft.ok_or_else(|| {
34 GpuError::Unrecoverable(
35 "FftR2COp::record: no cuFFT plan installed; call GraphMsg::SetFftPlan first".into(),
36 )
37 })?;
38 let op = self
39 .inner
40 .take()
41 .ok_or_else(|| GpuError::Unrecoverable("FftR2COp::record: already consumed".into()))?;
42 let mut recorder = FftRecorder { plan };
43 recorder.enqueue_record(stream, op)
44 }
45
46 fn op_name(&self) -> &'static str {
47 "graph::fft_r2c"
48 }
49}