Skip to main content

atomr_accel_cuda/kernel/
record.rs

1//! Capture-mode contract.
2//!
3//! Library actors that participate in [`crate::pipeline`] or
4//! [`crate::graph`] implement [`RecordMode`] so they can be driven
5//! without registering host-callback completion. The capture caller
6//! (a `PipelineStage` adapter or `GraphActor`) feeds an operation,
7//! the library actor enqueues it onto the supplied stream, and the
8//! caller manages cross-stream synchronization itself via
9//! `CudaEvent`s (pipeline) or graph instantiation (graph).
10//!
11//! **Capture-safe vs. unsafe:** anything calling host functions
12//! (`HostFnCompletion`'s `cuLaunchHostFunc`, `tokio::spawn`,
13//! synchronous memcpy) is *not* capture-safe. RecordMode impls must
14//! enqueue purely on the stream and return synchronously.
15
16use std::sync::Arc;
17
18use cudarc::cublas::sys::cublasOperation_t;
19use cudarc::cublas::{CudaBlas, Gemm, GemmConfig};
20
21use crate::error::GpuError;
22use crate::gpu_ref::GpuRef;
23
24#[cfg(feature = "curand")]
25use cudarc::curand::CudaRng;
26
27#[cfg(feature = "cufft")]
28use cudarc::cufft::CudaFft;
29
30/// Library-actor opt-in to capture-mode enqueue.
31///
32/// Implementations enqueue `op` onto `stream` synchronously and return
33/// without awaiting completion. The caller is responsible for any
34/// downstream `CudaEvent` recording / wait.
35pub trait RecordMode {
36    /// The operation type, typically a thin enum mirroring the actor's
37    /// public `Msg` enum but stripped of `oneshot::Sender` reply
38    /// channels.
39    type Op;
40
41    fn enqueue_record(
42        &mut self,
43        stream: &Arc<cudarc::driver::CudaStream>,
44        op: Self::Op,
45    ) -> Result<(), GpuError>;
46}
47
48/// Record-mode op for cuBLAS SGEMM. Mirrors `BlasMsg::Sgemm`'s
49/// payload minus the reply channel.
50pub struct BlasSgemmOp {
51    pub a: GpuRef<f32>,
52    pub b: GpuRef<f32>,
53    pub c: GpuRef<f32>,
54    pub m: i32,
55    pub n: i32,
56    pub k: i32,
57    pub alpha: f32,
58    pub beta: f32,
59}
60
61/// Memcpy op (host-side `cudaMemcpyAsync` device-to-device on the
62/// captured stream). Capture-safe.
63pub struct MemcpyOp {
64    pub src: GpuRef<f32>,
65    pub dst: GpuRef<f32>,
66}
67
/// Capture-mode request to overwrite `dst` in place with uniform random
/// samples.
///
/// cuRAND generation is only capture-safe while the generator does not
/// consult a host-side counter.
#[cfg(feature = "curand")]
pub struct RngFillUniformOp {
    /// Buffer that receives the samples.
    pub dst: GpuRef<f32>,
}
75
76/// Capture-mode wrapper around a `CudaBlas` handle. Used by
77/// [`crate::graph::GraphActor`] when it needs to record an SGEMM
78/// inside a stream-capture region.
79pub struct BlasRecorder<'a> {
80    pub handle: &'a CudaBlas,
81}
82
/// Stateless recorder that enqueues device-to-device copies on the
/// capturing stream.
pub struct MemcpyRecorder;
85
86impl RecordMode for MemcpyRecorder {
87    type Op = MemcpyOp;
88    fn enqueue_record(
89        &mut self,
90        stream: &Arc<cudarc::driver::CudaStream>,
91        op: Self::Op,
92    ) -> Result<(), GpuError> {
93        let MemcpyOp { src, dst } = op;
94        let src_slice = src.access()?.clone();
95        let dst_slice = dst.access()?.clone();
96        let mut dst_owned = Arc::try_unwrap(dst_slice)
97            .map_err(|_| GpuError::Unrecoverable("MemcpyRecorder: dst has multiple refs".into()))?;
98        stream
99            .memcpy_dtod(&*src_slice, &mut dst_owned)
100            .map_err(|e| GpuError::LibraryError {
101                lib: "driver",
102                msg: format!("record memcpy_dtod: {e}"),
103            })?;
104        dst.record_write(stream);
105        let _ = (src_slice, dst_owned);
106        Ok(())
107    }
108}
109
/// Borrowed cuRAND generator that records uniform fills in capture mode.
#[cfg(feature = "curand")]
pub struct RngRecorder<'a> {
    /// Generator that produces the samples.
    pub rng: &'a CudaRng,
}
114
/// Capture-mode request for a real-to-complex FFT of `src` into `dst`.
#[cfg(feature = "cufft")]
pub struct FftR2COp {
    /// Real-valued input signal.
    pub src: GpuRef<f32>,
    /// Complex output, stored as cuFFT `float2` values.
    pub dst: GpuRef<cudarc::cufft::sys::float2>,
}
120
/// Borrowed cuFFT plan that records real-to-complex transforms in capture
/// mode.
#[cfg(feature = "cufft")]
pub struct FftRecorder<'a> {
    /// Plan that executes the transform.
    pub plan: &'a CudaFft,
}
125
#[cfg(feature = "cufft")]
impl<'a> RecordMode for FftRecorder<'a> {
    type Op = FftR2COp;

    /// Enqueues a real-to-complex transform of `op.src` into `op.dst` through
    /// the borrowed plan, then records `stream` as the last writer of
    /// `op.dst`.
    ///
    /// # Errors
    ///
    /// - whatever `GpuRef::access` returns for either buffer;
    /// - `GpuError::Unrecoverable` when the output slice cannot be moved out
    ///   of its `Arc`;
    /// - `GpuError::LibraryError { lib: "cufft", .. }` when `exec_r2c` fails.
    fn enqueue_record(
        &mut self,
        stream: &Arc<cudarc::driver::CudaStream>,
        op: Self::Op,
    ) -> Result<(), GpuError> {
        let FftR2COp {
            src: input,
            dst: output,
        } = op;

        let input_buf = input.access()?.clone();
        let shared_output = output.access()?.clone();
        // `exec_r2c` writes through `&mut`, so the output slice has to be
        // moved out of its `Arc`; this only works with a single strong ref.
        let mut output_buf = Arc::try_unwrap(shared_output)
            .map_err(|_| GpuError::Unrecoverable("FftRecorder: dst has multiple refs".into()))?;

        // NOTE(review): `stream` is never passed to cuFFT; the transform runs
        // on whatever stream the plan is bound to. Confirm that is the
        // capturing stream, since `record_write` below assumes it.
        if let Err(e) = self.plan.exec_r2c(&*input_buf, &mut output_buf) {
            return Err(GpuError::LibraryError {
                lib: "cufft",
                msg: format!("record exec_r2c: {e}"),
            });
        }
        output.record_write(stream);

        // Release the local handles now rather than at function exit.
        drop((input_buf, output_buf));
        Ok(())
    }
}
150
#[cfg(feature = "curand")]
impl<'a> RecordMode for RngRecorder<'a> {
    type Op = RngFillUniformOp;

    /// Fills `op.dst` in place with uniform samples from the borrowed
    /// generator, then records `stream` as the last writer of `op.dst`.
    ///
    /// # Errors
    ///
    /// - whatever `GpuRef::access` returns for the buffer;
    /// - `GpuError::Unrecoverable` when the slice cannot be moved out of its
    ///   `Arc`;
    /// - `GpuError::LibraryError { lib: "curand", .. }` when the fill fails.
    fn enqueue_record(
        &mut self,
        stream: &Arc<cudarc::driver::CudaStream>,
        op: Self::Op,
    ) -> Result<(), GpuError> {
        let RngFillUniformOp { dst: target } = op;

        let shared = target.access()?.clone();
        // The fill writes through `&mut`, so the slice has to be moved out of
        // its `Arc`; this only works with a single strong ref.
        let mut samples = Arc::try_unwrap(shared)
            .map_err(|_| GpuError::Unrecoverable("RngRecorder: dst has multiple refs".into()))?;

        // NOTE(review): `stream` is never passed to cuRAND; the fill runs on
        // whatever stream the generator is bound to — confirm it is the
        // capturing stream.
        if let Err(e) = self.rng.fill_with_uniform(&mut samples) {
            return Err(GpuError::LibraryError {
                lib: "curand",
                msg: format!("record fill_uniform: {e:?}"),
            });
        }
        target.record_write(stream);

        // `samples` goes out of scope (and is dropped) when this returns.
        Ok(())
    }
}
174
175impl<'a> RecordMode for BlasRecorder<'a> {
176    type Op = BlasSgemmOp;
177
178    fn enqueue_record(
179        &mut self,
180        stream: &Arc<cudarc::driver::CudaStream>,
181        op: Self::Op,
182    ) -> Result<(), GpuError> {
183        let BlasSgemmOp {
184            a,
185            b,
186            c,
187            m,
188            n,
189            k,
190            alpha,
191            beta,
192        } = op;
193        let a_slice = a.access()?.clone();
194        let b_slice = b.access()?.clone();
195        let c_slice = c.access()?.clone();
196        let mut c_owned = Arc::try_unwrap(c_slice).map_err(|_| {
197            GpuError::Unrecoverable("BlasRecorder: C has multiple live references".into())
198        })?;
199
200        let cfg = GemmConfig::<f32> {
201            transa: cublasOperation_t::CUBLAS_OP_N,
202            transb: cublasOperation_t::CUBLAS_OP_N,
203            m,
204            n,
205            k,
206            alpha,
207            lda: m,
208            ldb: k,
209            beta,
210            ldc: m,
211        };
212        // SAFETY: m/n/k validity is the caller's contract.
213        unsafe {
214            self.handle
215                .gemm(cfg, &*a_slice, &*b_slice, &mut c_owned)
216                .map_err(|e| GpuError::LibraryError {
217                    lib: "cublas",
218                    msg: format!("record gemm: {e}"),
219                })?;
220        }
221        c.record_write(stream);
222        // Slices must outlive the stream operation; in capture mode
223        // the graph-exec object holds them. We leak ownership back
224        // into the GpuRef by re-Arcing — only safe because the
225        // graph machinery keeps the buffers alive until the graph
226        // is destroyed.
227        let _ = (a_slice, b_slice, c_owned);
228        Ok(())
229    }
230}