Skip to main content

atomr_accel_cuda/graph/record/
rng_fill_uniform.rs

//! Uniform RNG fill op for [`super::super::GraphOp`].
//!
//! Wraps [`crate::kernel::record::RngFillUniformOp`] /
//! [`crate::kernel::record::RngRecorder`] into a single
//! `GraphOp`-implementing type. Gated on the `curand` feature.
6
7use crate::error::GpuError;
8use crate::gpu_ref::GpuRef;
9use crate::graph::{GraphOp, GraphRecordCtx};
10use crate::kernel::record::{RecordMode, RngFillUniformOp as InnerRngFillUniformOp, RngRecorder};
11
12/// Uniform RNG fill op for graph capture.
13///
14/// The op needs a cuRAND handle on the captured stream, supplied
15/// by `GraphRecordCtx::rng`. If absent the op fails with
16/// [`GpuError::Unrecoverable`].
17///
18/// `record` consumes the held `GpuRef` on first invocation; a
19/// second call returns [`GpuError::Unrecoverable`].
20pub struct RngFillUniformOp {
21    inner: Option<InnerRngFillUniformOp>,
22}
23
24impl RngFillUniformOp {
25    pub fn new(dst: GpuRef<f32>) -> Self {
26        Self {
27            inner: Some(InnerRngFillUniformOp { dst }),
28        }
29    }
30}
31
32impl GraphOp for RngFillUniformOp {
33    fn record(&mut self, ctx: &mut GraphRecordCtx<'_>) -> Result<(), GpuError> {
34        let stream = ctx.require_stream()?;
35        let rng = ctx.rng.ok_or_else(|| {
36            GpuError::Unrecoverable(
37                "RngFillUniformOp::record: cuRAND handle not available in ctx".into(),
38            )
39        })?;
40        let op = self.inner.take().ok_or_else(|| {
41            GpuError::Unrecoverable("RngFillUniformOp::record: already consumed".into())
42        })?;
43        let mut recorder = RngRecorder { rng };
44        recorder.enqueue_record(stream, op)
45    }
46
47    fn op_name(&self) -> &'static str {
48        "graph::rng_fill_uniform"
49    }
50}