1use std::sync::Arc;
26
27use async_trait::async_trait;
28use atomr_core::actor::{Actor, Context, Props};
29use parking_lot::Mutex;
30use tokio::sync::oneshot;
31
32use crate::completion::CompletionStrategy;
33use crate::device::DeviceState;
34use crate::dtype::RngFloatSupported;
35use crate::error::GpuError;
36use crate::gpu_ref::GpuRef;
37use crate::kernel::dispatch::RngDispatch;
38use crate::stream::StreamAllocator;
39use crate::sys::curand as csys;
40
// Distribution / fill-request types; `Distribution` and `FillRequest` are
// re-exported from here a few lines below.
pub mod dist;
// Only compiled when the `curand-host` cargo feature is enabled.
#[cfg(feature = "curand-host")]
pub mod host;
// Only compiled when the `curand-quasirandom` cargo feature is enabled.
#[cfg(feature = "curand-quasirandom")]
pub mod sobol;

// Public re-exports so callers can name these types through this module.
pub use crate::sys::curand::RngGeneratorKind;
pub use dist::{Distribution, FillRequest};

/// Library tag used in this module for `GpuError::LibraryError::lib` and for
/// the `lib` field of `tracing` events.
pub(crate) const LIB: &str = "curand";
51
/// Messages understood by [`RngActor`].
///
/// [`RngMsg::Fill`], [`RngMsg::SetSeed`] and [`RngMsg::SetGenerator`] are the
/// current API. The remaining variants are deprecated and are kept so
/// existing callers keep compiling; the actor still handles every one of them.
#[non_exhaustive]
pub enum RngMsg {
    /// Type-erased fill request. The actor calls `fill` on the boxed request
    /// with the current generator handle, the stream and the completion
    /// strategy. This variant has no reply channel of its own; a mock actor
    /// drops the request without calling `fill`.
    Fill(Box<dyn RngDispatch>),
    /// Re-seed the active generator. When the active kind reports
    /// `is_quasi()`, no seed is applied and `Ok(())` is sent.
    SetSeed {
        /// Seed passed to the generator.
        seed: u64,
        /// Receives the outcome of the operation.
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
    /// Replace the actor's generator with a freshly constructed one of
    /// `kind`. The replacement is constructed with seed `0`.
    SetGenerator {
        /// Kind of generator to switch to.
        kind: RngGeneratorKind,
        /// Receives the outcome of the operation.
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
    /// Legacy: handled as a `FillRequest::<f32>` with
    /// `Distribution::Uniform { lo: 0.0, hi: 1.0 }`.
    #[deprecated(note = "use RngMsg::Fill(Box::new(FillRequest { ... })) instead")]
    FillUniformF32 {
        /// Destination buffer.
        dst: GpuRef<f32>,
        /// Reply channel handed to the fill request.
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
    /// Legacy: handled as a `FillRequest::<f64>` with
    /// `Distribution::Uniform { lo: 0.0, hi: 1.0 }`.
    #[deprecated(note = "use RngMsg::Fill(Box::new(FillRequest { ... })) instead")]
    FillUniformF64 {
        /// Destination buffer.
        dst: GpuRef<f64>,
        /// Reply channel handed to the fill request.
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
    /// Legacy: handled by `dist::fill_uniform_u32`.
    #[deprecated(note = "use RngMsg::Fill(Box::new(FillRequest { ... })) instead")]
    FillUniformU32 {
        /// Destination buffer.
        dst: GpuRef<u32>,
        /// Reply channel handed to `dist::fill_uniform_u32`.
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
    /// Legacy: handled as a `FillRequest::<f32>` with
    /// `Distribution::Normal { mean, std }`.
    #[deprecated(note = "use RngMsg::Fill(Box::new(FillRequest { ... })) instead")]
    FillNormalF32 {
        /// Destination buffer.
        dst: GpuRef<f32>,
        /// `mean` parameter of the distribution.
        mean: f32,
        /// `std` parameter of the distribution.
        std: f32,
        /// Reply channel handed to the fill request.
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
    /// Legacy: handled as a `FillRequest::<f64>` with
    /// `Distribution::Normal { mean, std }`.
    #[deprecated(note = "use RngMsg::Fill(Box::new(FillRequest { ... })) instead")]
    FillNormalF64 {
        /// Destination buffer.
        dst: GpuRef<f64>,
        /// `mean` parameter of the distribution.
        mean: f64,
        /// `std` parameter of the distribution.
        std: f64,
        /// Reply channel handed to the fill request.
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
    /// Legacy: handled as a `FillRequest::<f32>` with
    /// `Distribution::LogNormal { mean, std }`.
    #[deprecated(note = "use RngMsg::Fill(Box::new(FillRequest { ... })) instead")]
    FillLogNormalF32 {
        /// Destination buffer.
        dst: GpuRef<f32>,
        /// `mean` parameter of the distribution.
        mean: f32,
        /// `std` parameter of the distribution.
        std: f32,
        /// Reply channel handed to the fill request.
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
    /// Legacy alias: handled by the same code path as [`RngMsg::SetSeed`].
    #[deprecated(note = "use RngMsg::SetSeed { seed, reply } instead")]
    Reseed {
        /// Seed passed to the generator.
        seed: u64,
        /// Receives the outcome of the operation.
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
}
121
/// Newtype around the raw cuRAND generator handle so the handle can be
/// stored inside the actor state.
pub(crate) struct SendGen(pub(crate) cudarc::curand::sys::curandGenerator_t);

// NOTE(review): the handle appears to be a raw pointer (this file calls
// `is_null()` on it), which is why `Send`/`Sync` are not derived
// automatically. Within this file the handle is only stored behind a
// `Mutex` owned by the actor; whether cuRAND tolerates a generator being
// used from different threads over time should be confirmed against the
// cuRAND documentation before relying on these impls elsewhere.
unsafe impl Send for SendGen {}
// NOTE(review): `Sync` additionally allows shared references to cross
// threads; no use of that is visible in this file — confirm it is required.
unsafe impl Sync for SendGen {}
134
/// Actor that serves [`RngMsg`] requests, either against a real cuRAND
/// generator or in mock mode.
///
/// Build it through [`RngActor::props`], [`RngActor::props_with_kind`] or
/// [`RngActor::mock_props`].
pub struct RngActor {
    // Real-vs-mock state; selected once by the props factory that built the actor.
    inner: RngInner,
}
138
/// Backing state of an [`RngActor`].
pub(crate) enum RngInner {
    /// Backed by a live cuRAND generator.
    Real {
        // Current generator handle; replaced by `RngMsg::SetGenerator` and
        // destroyed when the actor is dropped.
        gen: Mutex<SendGen>,
        // Kind of the generator currently stored in `gen`; consulted when
        // seeding to skip quasi-random kinds.
        kind: Mutex<RngGeneratorKind>,
        // Stream every generator is bound to at construction and that is
        // passed along with each fill.
        stream: Arc<cudarc::driver::CudaStream>,
        // Completion strategy passed along with each fill.
        completion: Arc<dyn CompletionStrategy>,
        // Not read anywhere in this file.
        // NOTE(review): presumably held to keep the device state alive for
        // the actor's lifetime — confirm.
        #[allow(dead_code)]
        state: Arc<DeviceState>,
    },
    /// No GPU resources; requests that carry a reply channel are answered
    /// with an error by `reply_mock`.
    Mock,
}
150
151impl RngActor {
152 pub fn props(
153 stream: Arc<cudarc::driver::CudaStream>,
154 _allocator: Arc<dyn StreamAllocator>,
155 completion: Arc<dyn CompletionStrategy>,
156 state: Arc<DeviceState>,
157 seed: u64,
158 ) -> Props<Self> {
159 Self::props_with_kind(
160 stream,
161 _allocator,
162 completion,
163 state,
164 seed,
165 RngGeneratorKind::default(),
166 )
167 }
168
169 pub fn props_with_kind(
172 stream: Arc<cudarc::driver::CudaStream>,
173 _allocator: Arc<dyn StreamAllocator>,
174 completion: Arc<dyn CompletionStrategy>,
175 state: Arc<DeviceState>,
176 seed: u64,
177 kind: RngGeneratorKind,
178 ) -> Props<Self> {
179 Props::create(move || {
180 let g = unsafe {
181 construct_generator(kind, &stream, seed).unwrap_or_else(|e| {
182 panic!("ContextPoisoned: cuRAND generator init failed ({kind:?}): {e}")
183 })
184 };
185 RngActor {
186 inner: RngInner::Real {
187 gen: Mutex::new(SendGen(g)),
188 kind: Mutex::new(kind),
189 stream: stream.clone(),
190 completion: completion.clone(),
191 state: state.clone(),
192 },
193 }
194 })
195 }
196
197 pub fn mock_props() -> Props<Self> {
198 Props::create(|| RngActor {
199 inner: RngInner::Mock,
200 })
201 }
202}
203
204#[async_trait]
205impl Actor for RngActor {
206 type Msg = RngMsg;
207
208 async fn handle(&mut self, _ctx: &mut Context<Self>, msg: RngMsg) {
209 let (gen_lock, kind_lock, stream, completion) = match &self.inner {
210 RngInner::Mock => {
211 reply_mock(msg);
212 return;
213 }
214 RngInner::Real {
215 gen,
216 kind,
217 stream,
218 completion,
219 ..
220 } => (gen, kind, stream, completion),
221 };
222
223 #[allow(deprecated)]
224 match msg {
225 RngMsg::Fill(req) => {
226 let gen_handle = gen_lock.lock().0;
227 if let Err(e) = req.fill(gen_handle, stream, completion) {
228 tracing::warn!(lib = LIB, error = %e, "RngActor::Fill pre-launch error");
234 }
235 }
236 RngMsg::SetSeed { seed, reply } | RngMsg::Reseed { seed, reply } => {
237 let g = gen_lock.lock();
238 let active = *kind_lock.lock();
239 let r = if active.is_quasi() {
240 Ok(())
245 } else {
246 unsafe { csys::set_seed(g.0, seed) }.map_err(|e| GpuError::LibraryError {
247 lib: LIB,
248 msg: format!("set_seed: {e}"),
249 })
250 };
251 let _ = reply.send(r);
252 }
253 RngMsg::SetGenerator { kind, reply } => {
254 let mut g = gen_lock.lock();
255 let mut active = kind_lock.lock();
256 let r = unsafe {
257 let _ = csys::destroy_generator(g.0);
258 match construct_generator(kind, stream, 0) {
259 Ok(new_g) => {
260 g.0 = new_g;
261 *active = kind;
262 Ok(())
263 }
264 Err(e) => Err(GpuError::LibraryError {
265 lib: LIB,
266 msg: format!("set_generator({kind:?}): {e}"),
267 }),
268 }
269 };
270 let _ = reply.send(r);
271 }
272 RngMsg::FillUniformF32 { dst, reply } => {
274 let req = FillRequest::<f32> {
275 buf: dst,
276 dist: Distribution::Uniform { lo: 0.0, hi: 1.0 },
277 reply,
278 };
279 let gen_handle = gen_lock.lock().0;
280 let _ = Box::new(req).fill(gen_handle, stream, completion);
281 }
282 RngMsg::FillUniformF64 { dst, reply } => {
283 let req = FillRequest::<f64> {
284 buf: dst,
285 dist: Distribution::Uniform { lo: 0.0, hi: 1.0 },
286 reply,
287 };
288 let gen_handle = gen_lock.lock().0;
289 let _ = Box::new(req).fill(gen_handle, stream, completion);
290 }
291 RngMsg::FillUniformU32 { dst, reply } => {
292 let gen_handle = gen_lock.lock().0;
293 dist::fill_uniform_u32(gen_handle, stream, completion, dst, reply);
294 }
295 RngMsg::FillNormalF32 {
296 dst,
297 mean,
298 std,
299 reply,
300 } => {
301 let req = FillRequest::<f32> {
302 buf: dst,
303 dist: Distribution::Normal { mean, std },
304 reply,
305 };
306 let gen_handle = gen_lock.lock().0;
307 let _ = Box::new(req).fill(gen_handle, stream, completion);
308 }
309 RngMsg::FillNormalF64 {
310 dst,
311 mean,
312 std,
313 reply,
314 } => {
315 let req = FillRequest::<f64> {
316 buf: dst,
317 dist: Distribution::Normal { mean, std },
318 reply,
319 };
320 let gen_handle = gen_lock.lock().0;
321 let _ = Box::new(req).fill(gen_handle, stream, completion);
322 }
323 RngMsg::FillLogNormalF32 {
324 dst,
325 mean,
326 std,
327 reply,
328 } => {
329 let req = FillRequest::<f32> {
330 buf: dst,
331 dist: Distribution::LogNormal { mean, std },
332 reply,
333 };
334 let gen_handle = gen_lock.lock().0;
335 let _ = Box::new(req).fill(gen_handle, stream, completion);
336 }
337 }
338 }
339}
340
341pub(crate) unsafe fn construct_generator(
349 kind: RngGeneratorKind,
350 stream: &Arc<cudarc::driver::CudaStream>,
351 seed: u64,
352) -> Result<cudarc::curand::sys::curandGenerator_t, cudarc::curand::result::CurandError> {
353 let g = csys::create_generator(kind)?;
354 csys::set_stream(g, stream.cu_stream() as _)?;
355 if !kind.is_quasi() {
356 csys::set_seed(g, seed)?;
357 }
358 Ok(g)
359}
360
361impl Drop for RngActor {
362 fn drop(&mut self) {
363 if let RngInner::Real { gen, .. } = &self.inner {
364 let g = gen.lock();
365 if !g.0.is_null() {
366 let _ = unsafe { csys::destroy_generator(g.0) };
367 }
368 }
369 }
370}
371
372#[allow(deprecated)]
373fn reply_mock(msg: RngMsg) {
374 let err = || GpuError::Unrecoverable("RngActor in mock mode".into());
375 match msg {
376 RngMsg::Fill(req) => {
377 drop(req);
382 }
383 RngMsg::SetSeed { reply, .. }
384 | RngMsg::SetGenerator { reply, .. }
385 | RngMsg::Reseed { reply, .. } => {
386 let _ = reply.send(Err(err()));
387 }
388 RngMsg::FillUniformF32 { reply, .. }
389 | RngMsg::FillNormalF32 { reply, .. }
390 | RngMsg::FillLogNormalF32 { reply, .. } => {
391 let _ = reply.send(Err(err()));
392 }
393 RngMsg::FillUniformF64 { reply, .. } | RngMsg::FillNormalF64 { reply, .. } => {
394 let _ = reply.send(Err(err()));
395 }
396 RngMsg::FillUniformU32 { reply, .. } => {
397 let _ = reply.send(Err(err()));
398 }
399 }
400}
401
/// Convenience module that re-exports the actor, its message type, the
/// generator kind and the fill-request types from a single path.
pub mod props {
    // Request-side types.
    pub use super::dist::{Distribution, FillRequest};
    // Actor-side types.
    pub use super::{RngActor, RngGeneratorKind, RngMsg};
}
409
410pub fn _capability_marker_compile_fail_doc<T: RngFloatSupported>(_: T::Scalar) {}
435
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time guard: the inner function only type-checks while every
    /// legacy `RngMsg` variant still exists. It is never called.
    #[test]
    fn rng_msg_legacy_variants_present() {
        #[allow(deprecated)]
        fn _accept(m: RngMsg) {
            match m {
                RngMsg::Fill(_) | RngMsg::SetSeed { .. } | RngMsg::SetGenerator { .. } => {}
                RngMsg::FillUniformF32 { .. }
                | RngMsg::FillUniformF64 { .. }
                | RngMsg::FillUniformU32 { .. }
                | RngMsg::FillNormalF32 { .. }
                | RngMsg::FillNormalF64 { .. }
                | RngMsg::FillLogNormalF32 { .. }
                | RngMsg::Reseed { .. } => {}
            }
        }
    }

    /// Each listed kind maps to a distinct sys value, and `is_quasi()` agrees
    /// with the sys value being at least 200.
    #[test]
    fn set_generator_kind_round_trip() {
        use std::collections::HashSet;

        let all = [
            RngGeneratorKind::PseudoDefault,
            RngGeneratorKind::Philox4_32_10,
            RngGeneratorKind::XorWow,
            RngGeneratorKind::Mrg32K3A,
            RngGeneratorKind::Mtgp32,
            RngGeneratorKind::Sobol32,
            RngGeneratorKind::ScrambledSobol32,
            RngGeneratorKind::Sobol64,
            RngGeneratorKind::ScrambledSobol64,
        ];

        let mut seen = HashSet::new();
        for k in all.iter().copied() {
            let v = k.to_sys() as u32;
            let first_time = seen.insert(v);
            assert!(first_time, "duplicate sys value for {k:?}");
            let expect_quasi = (v as i32) >= 200;
            assert_eq!(k.is_quasi(), expect_quasi);
        }
    }
}