atomr_accel_cuda/graph/record/
rng_fill_uniform.rs1use crate::error::GpuError;
8use crate::gpu_ref::GpuRef;
9use crate::graph::{GraphOp, GraphRecordCtx};
10use crate::kernel::record::{RecordMode, RngFillUniformOp as InnerRngFillUniformOp, RngRecorder};
11
12pub struct RngFillUniformOp {
21 inner: Option<InnerRngFillUniformOp>,
22}
23
24impl RngFillUniformOp {
25 pub fn new(dst: GpuRef<f32>) -> Self {
26 Self {
27 inner: Some(InnerRngFillUniformOp { dst }),
28 }
29 }
30}
31
32impl GraphOp for RngFillUniformOp {
33 fn record(&mut self, ctx: &mut GraphRecordCtx<'_>) -> Result<(), GpuError> {
34 let stream = ctx.require_stream()?;
35 let rng = ctx.rng.ok_or_else(|| {
36 GpuError::Unrecoverable(
37 "RngFillUniformOp::record: cuRAND handle not available in ctx".into(),
38 )
39 })?;
40 let op = self.inner.take().ok_or_else(|| {
41 GpuError::Unrecoverable("RngFillUniformOp::record: already consumed".into())
42 })?;
43 let mut recorder = RngRecorder { rng };
44 recorder.enqueue_record(stream, op)
45 }
46
47 fn op_name(&self) -> &'static str {
48 "graph::rng_fill_uniform"
49 }
50}