Skip to main content

atomr_accel_cuda/graph/record/
fft_r2c.rs

1//! 1-D R2C FFT op for [`super::super::GraphOp`].
2//!
3//! Wraps [`crate::kernel::record::FftR2COp`] /
4//! [`crate::kernel::record::FftRecorder`] into a single
5//! `GraphOp`-implementing type. Gated on the `cufft` feature.
6
7use crate::error::GpuError;
8use crate::gpu_ref::GpuRef;
9use crate::graph::{GraphOp, GraphRecordCtx};
10use crate::kernel::record::{FftR2COp as InnerFftR2COp, FftRecorder, RecordMode};
11
12/// 1-D R2C FFT op for graph capture. The user installs a pre-built
13/// `cudarc::cufft::CudaFft` plan via `GraphMsg::SetFftPlan` before
14/// recording; the plan is borrowed through `GraphRecordCtx::fft`.
15///
16/// `record` consumes the held `GpuRef`s on first invocation; a
17/// second call returns [`GpuError::Unrecoverable`].
18pub struct FftR2COp {
19    inner: Option<InnerFftR2COp>,
20}
21
22impl FftR2COp {
23    pub fn new(src: GpuRef<f32>, dst: GpuRef<cudarc::cufft::sys::float2>) -> Self {
24        Self {
25            inner: Some(InnerFftR2COp { src, dst }),
26        }
27    }
28}
29
30impl GraphOp for FftR2COp {
31    fn record(&mut self, ctx: &mut GraphRecordCtx<'_>) -> Result<(), GpuError> {
32        let stream = ctx.require_stream()?;
33        let plan = ctx.fft.ok_or_else(|| {
34            GpuError::Unrecoverable(
35                "FftR2COp::record: no cuFFT plan installed; call GraphMsg::SetFftPlan first".into(),
36            )
37        })?;
38        let op = self
39            .inner
40            .take()
41            .ok_or_else(|| GpuError::Unrecoverable("FftR2COp::record: already consumed".into()))?;
42        let mut recorder = FftRecorder { plan };
43        recorder.enqueue_record(stream, op)
44    }
45
46    fn op_name(&self) -> &'static str {
47        "graph::fft_r2c"
48    }
49}