atomr_accel_cuda/kernel/
record.rs1use std::sync::Arc;
17
18use cudarc::cublas::sys::cublasOperation_t;
19use cudarc::cublas::{CudaBlas, Gemm, GemmConfig};
20
21use crate::error::GpuError;
22use crate::gpu_ref::GpuRef;
23
24#[cfg(feature = "curand")]
25use cudarc::curand::CudaRng;
26
27#[cfg(feature = "cufft")]
28use cudarc::cufft::CudaFft;
29
/// Common interface of the recorders defined in this file.
///
/// An implementor names the operation type it understands (`Op`) and enqueues one such
/// operation per call. Every implementation in this file follows the same shape: it accesses
/// the operands' device buffers, issues a single library call, and then calls `record_write`
/// on the destination with the supplied stream.
pub trait RecordMode {
    /// The operation description consumed (by value) by [`RecordMode::enqueue_record`].
    type Op;

    /// Enqueues `op`.
    ///
    /// * `stream` - CUDA stream handed to the implementation; the implementations in this
    ///   file pass it to `record_write` on the destination buffer.
    /// * `op` - the operation to enqueue; ownership is taken.
    ///
    /// # Errors
    ///
    /// Returns a [`GpuError`] when the operation could not be enqueued.
    fn enqueue_record(
        &mut self,
        stream: &Arc<cudarc::driver::CudaStream>,
        op: Self::Op,
    ) -> Result<(), GpuError>;
}
47
/// Operands and scalars for one single-precision GEMM consumed by `BlasRecorder`.
///
/// `BlasRecorder::enqueue_record` forwards these fields into a `GemmConfig::<f32>` with both
/// operands un-transposed (`CUBLAS_OP_N`) and leading dimensions `lda = m`, `ldb = k`,
/// `ldc = m`. Per the cuBLAS documentation, GEMM computes `C = alpha * A * B + beta * C`.
pub struct BlasSgemmOp {
    /// Input buffer passed to `gemm` as `A`.
    pub a: GpuRef<f32>,
    /// Input buffer passed to `gemm` as `B`.
    pub b: GpuRef<f32>,
    /// Buffer passed mutably to `gemm` as `C`; `record_write` is called on it afterwards.
    pub c: GpuRef<f32>,
    /// Forwarded as `m`; also used as `lda` and `ldc`.
    pub m: i32,
    /// Forwarded as `n`.
    pub n: i32,
    /// Forwarded as `k`; also used as `ldb`.
    pub k: i32,
    /// Forwarded as the `alpha` scalar.
    pub alpha: f32,
    /// Forwarded as the `beta` scalar.
    pub beta: f32,
}
60
/// A device-to-device copy request consumed by `MemcpyRecorder`, which issues it through
/// `memcpy_dtod` on the stream it is given.
pub struct MemcpyOp {
    /// Buffer the data is copied from.
    pub src: GpuRef<f32>,
    /// Buffer the data is copied into; `record_write` is called on it after the copy.
    pub dst: GpuRef<f32>,
}
67
/// A request to fill a buffer with random values, consumed by `RngRecorder`, which passes
/// the buffer to `CudaRng::fill_with_uniform`.
///
/// Only compiled when the `curand` feature is enabled.
#[cfg(feature = "curand")]
pub struct RngFillUniformOp {
    /// Buffer to fill; `record_write` is called on it after the fill.
    pub dst: GpuRef<f32>,
}
75
/// Recorder for `BlasSgemmOp`; borrows the cuBLAS handle through which `gemm` is issued.
pub struct BlasRecorder<'a> {
    /// Borrowed cuBLAS handle used for the `gemm` call.
    pub handle: &'a CudaBlas,
}
82
/// Recorder for `MemcpyOp`. Holds no state: the copy is issued on the stream passed to
/// `enqueue_record`.
pub struct MemcpyRecorder;
85
86impl RecordMode for MemcpyRecorder {
87 type Op = MemcpyOp;
88 fn enqueue_record(
89 &mut self,
90 stream: &Arc<cudarc::driver::CudaStream>,
91 op: Self::Op,
92 ) -> Result<(), GpuError> {
93 let MemcpyOp { src, dst } = op;
94 let src_slice = src.access()?.clone();
95 let dst_slice = dst.access()?.clone();
96 let mut dst_owned = Arc::try_unwrap(dst_slice)
97 .map_err(|_| GpuError::Unrecoverable("MemcpyRecorder: dst has multiple refs".into()))?;
98 stream
99 .memcpy_dtod(&*src_slice, &mut dst_owned)
100 .map_err(|e| GpuError::LibraryError {
101 lib: "driver",
102 msg: format!("record memcpy_dtod: {e}"),
103 })?;
104 dst.record_write(stream);
105 let _ = (src_slice, dst_owned);
106 Ok(())
107 }
108}
109
/// Recorder for `RngFillUniformOp`; borrows the cuRAND generator used for the fill.
///
/// Only compiled when the `curand` feature is enabled.
#[cfg(feature = "curand")]
pub struct RngRecorder<'a> {
    /// Borrowed generator on which `fill_with_uniform` is called.
    pub rng: &'a CudaRng,
}
114
/// A real-to-complex FFT request consumed by `FftRecorder`, which executes it with
/// `CudaFft::exec_r2c`.
///
/// Only compiled when the `cufft` feature is enabled.
#[cfg(feature = "cufft")]
pub struct FftR2COp {
    /// Real-valued input buffer.
    pub src: GpuRef<f32>,
    /// Complex-valued (`float2`) output buffer; `record_write` is called on it afterwards.
    pub dst: GpuRef<cudarc::cufft::sys::float2>,
}
120
/// Recorder for `FftR2COp`; borrows the cuFFT plan the transform is executed with.
///
/// Only compiled when the `cufft` feature is enabled.
#[cfg(feature = "cufft")]
pub struct FftRecorder<'a> {
    /// Borrowed plan on which `exec_r2c` is called.
    pub plan: &'a CudaFft,
}
125
/// Enqueues real-to-complex transforms described by [`FftR2COp`].
#[cfg(feature = "cufft")]
impl<'a> RecordMode for FftRecorder<'a> {
    type Op = FftR2COp;

    /// Runs `exec_r2c` on the borrowed plan from `op.src` into `op.dst`, then calls
    /// `record_write(stream)` on the destination.
    ///
    /// # Errors
    ///
    /// * Any error returned by `access()` on either buffer is propagated.
    /// * `GpuError::Unrecoverable` when the destination `Arc` cannot be unwrapped into a
    ///   uniquely owned value.
    /// * `GpuError::LibraryError { lib: "cufft", .. }` when `exec_r2c` fails.
    ///
    /// NOTE(review): `stream` is only passed to `record_write`; the transform itself is
    /// issued through `self.plan`. Presumably the plan was bound to the same stream
    /// beforehand — verify at the call site.
    ///
    /// NOTE(review): the destination `Arc` is cloned and then immediately handed to
    /// `Arc::try_unwrap`. If `access()` returns a reference to an `Arc` that `dst` itself
    /// keeps alive, the strong count is at least two at that point and this method would
    /// always return `Unrecoverable` — confirm against `GpuRef::access`.
    fn enqueue_record(
        &mut self,
        stream: &Arc<cudarc::driver::CudaStream>,
        op: Self::Op,
    ) -> Result<(), GpuError> {
        let FftR2COp { src, dst } = op;

        let input = src.access()?.clone();
        let shared_output = dst.access()?.clone();

        // `exec_r2c` needs `&mut` access to the output, so it must be uniquely owned.
        let mut output = match Arc::try_unwrap(shared_output) {
            Ok(unique) => unique,
            Err(_) => {
                return Err(GpuError::Unrecoverable(
                    "FftRecorder: dst has multiple refs".into(),
                ));
            }
        };

        if let Err(e) = self.plan.exec_r2c(&*input, &mut output) {
            return Err(GpuError::LibraryError {
                lib: "cufft",
                msg: format!("record exec_r2c: {e}"),
            });
        }

        dst.record_write(stream);

        // Release the slice handles only after the write has been recorded.
        drop(input);
        drop(output);
        Ok(())
    }
}
150
/// Enqueues uniform random fills described by [`RngFillUniformOp`].
#[cfg(feature = "curand")]
impl<'a> RecordMode for RngRecorder<'a> {
    type Op = RngFillUniformOp;

    /// Calls `fill_with_uniform` on the borrowed generator for `op.dst`, then calls
    /// `record_write(stream)` on that buffer.
    ///
    /// # Errors
    ///
    /// * Any error returned by `access()` on the buffer is propagated.
    /// * `GpuError::Unrecoverable` when the destination `Arc` cannot be unwrapped into a
    ///   uniquely owned value.
    /// * `GpuError::LibraryError { lib: "curand", .. }` when `fill_with_uniform` fails
    ///   (the underlying error is rendered with its `Debug` form).
    ///
    /// NOTE(review): `stream` is only passed to `record_write`; the fill itself is issued
    /// through `self.rng`. Presumably the generator was bound to the same stream
    /// beforehand — verify at the call site.
    ///
    /// NOTE(review): the destination `Arc` is cloned and then immediately handed to
    /// `Arc::try_unwrap`. If `access()` returns a reference to an `Arc` that `dst` itself
    /// keeps alive, the strong count is at least two at that point and this method would
    /// always return `Unrecoverable` — confirm against `GpuRef::access`.
    fn enqueue_record(
        &mut self,
        stream: &Arc<cudarc::driver::CudaStream>,
        op: Self::Op,
    ) -> Result<(), GpuError> {
        let RngFillUniformOp { dst } = op;

        let shared_target = dst.access()?.clone();

        // `fill_with_uniform` needs `&mut` access to the buffer, so it must be uniquely owned.
        let mut target = match Arc::try_unwrap(shared_target) {
            Ok(unique) => unique,
            Err(_) => {
                return Err(GpuError::Unrecoverable(
                    "RngRecorder: dst has multiple refs".into(),
                ));
            }
        };

        if let Err(e) = self.rng.fill_with_uniform(&mut target) {
            return Err(GpuError::LibraryError {
                lib: "curand",
                msg: format!("record fill_uniform: {e:?}"),
            });
        }

        dst.record_write(stream);

        // Release the slice handle only after the write has been recorded.
        drop(target);
        Ok(())
    }
}
174
175impl<'a> RecordMode for BlasRecorder<'a> {
176 type Op = BlasSgemmOp;
177
178 fn enqueue_record(
179 &mut self,
180 stream: &Arc<cudarc::driver::CudaStream>,
181 op: Self::Op,
182 ) -> Result<(), GpuError> {
183 let BlasSgemmOp {
184 a,
185 b,
186 c,
187 m,
188 n,
189 k,
190 alpha,
191 beta,
192 } = op;
193 let a_slice = a.access()?.clone();
194 let b_slice = b.access()?.clone();
195 let c_slice = c.access()?.clone();
196 let mut c_owned = Arc::try_unwrap(c_slice).map_err(|_| {
197 GpuError::Unrecoverable("BlasRecorder: C has multiple live references".into())
198 })?;
199
200 let cfg = GemmConfig::<f32> {
201 transa: cublasOperation_t::CUBLAS_OP_N,
202 transb: cublasOperation_t::CUBLAS_OP_N,
203 m,
204 n,
205 k,
206 alpha,
207 lda: m,
208 ldb: k,
209 beta,
210 ldc: m,
211 };
212 unsafe {
214 self.handle
215 .gemm(cfg, &*a_slice, &*b_slice, &mut c_owned)
216 .map_err(|e| GpuError::LibraryError {
217 lib: "cublas",
218 msg: format!("record gemm: {e}"),
219 })?;
220 }
221 c.record_write(stream);
222 let _ = (a_slice, b_slice, c_owned);
228 Ok(())
229 }
230}