Skip to main content

atomr_accel_cuda/graph/record/
memcpy.rs

1//! Device-to-device memcpy op for [`super::super::GraphOp`].
2//!
3//! Wraps [`crate::kernel::record::MemcpyOp`] /
4//! [`crate::kernel::record::MemcpyRecorder`] into a single
5//! `GraphOp`-implementing type.
6
7use crate::error::GpuError;
8use crate::gpu_ref::GpuRef;
9use crate::graph::{GraphOp, GraphRecordCtx};
10use crate::kernel::record::{MemcpyOp as InnerMemcpyOp, MemcpyRecorder, RecordMode};
11
12/// Device-to-device memcpy op for the captured stream. Capture-safe.
13///
14/// `record` consumes the held `GpuRef`s on first invocation; a
15/// second call returns [`GpuError::Unrecoverable`].
16pub struct MemcpyOp {
17    inner: Option<InnerMemcpyOp>,
18}
19
20impl MemcpyOp {
21    pub fn new(src: GpuRef<f32>, dst: GpuRef<f32>) -> Self {
22        Self {
23            inner: Some(InnerMemcpyOp { src, dst }),
24        }
25    }
26}
27
28impl GraphOp for MemcpyOp {
29    fn record(&mut self, ctx: &mut GraphRecordCtx<'_>) -> Result<(), GpuError> {
30        let stream = ctx.require_stream()?;
31        let op = self
32            .inner
33            .take()
34            .ok_or_else(|| GpuError::Unrecoverable("MemcpyOp::record: already consumed".into()))?;
35        let mut recorder = MemcpyRecorder;
36        recorder.enqueue_record(stream, op)
37    }
38
39    fn op_name(&self) -> &'static str {
40        "graph::memcpy"
41    }
42}