Skip to main content

atomr_accel_cuda/kernel/rng/
mod.rs

1//! `RngActor` — wraps a cuRAND `curandGenerator_t` handle and fills
2//! device buffers with the full distribution matrix.
3//!
4//! Phase 1 cuRAND surface (vs. F2):
5//!
6//! * **Explicit generator selection** via [`RngGeneratorKind`]
7//!   (Philox4_32_10, XORWOW, MTGP32, MRG32K3A, plus all four Sobol
8//!   variants). [`RngMsg::SetGenerator`] reconstructs the handle in
9//!   place so callers can switch families at runtime.
10//! * **Distribution matrix** (`Uniform`, `Normal`, `LogNormal`,
11//!   `Poisson`, `Exponential`, `Beta`, `Cauchy`, `Gamma`,
12//!   `Discrete`) routed through [`Distribution<T>`] →
13//!   [`FillRequest<T>`] → `RngDispatch::fill`.
14//! * **Quasi-random Sobol** parallel to pseudo-random: see
15//!   [`sobol`] for dimension configuration.
16//! * **Host API parallel to device API** under the `curand-host`
17//!   feature: see [`host`].
18//!
19//! Reseed model: `SetSeed { seed }` calls
20//! [`crate::sys::curand::set_seed`] in place — no panic-restart.
21//! Reseed is a control-plane op; restart-on-reseed would tear down
22//! all in-flight work. The seed is journaled by `ReplayHarness` (F5)
23//! so deterministic replay still works.
24
25use std::sync::Arc;
26
27use async_trait::async_trait;
28use atomr_core::actor::{Actor, Context, Props};
29use parking_lot::Mutex;
30use tokio::sync::oneshot;
31
32use crate::completion::CompletionStrategy;
33use crate::device::DeviceState;
34use crate::dtype::RngFloatSupported;
35use crate::error::GpuError;
36use crate::gpu_ref::GpuRef;
37use crate::kernel::dispatch::RngDispatch;
38use crate::stream::StreamAllocator;
39use crate::sys::curand as csys;
40
41pub mod dist;
42#[cfg(feature = "curand-host")]
43pub mod host;
44#[cfg(feature = "curand-quasirandom")]
45pub mod sobol;
46
47pub use crate::sys::curand::RngGeneratorKind;
48pub use dist::{Distribution, FillRequest};
49
50pub(crate) const LIB: &str = "curand";
51
/// Public messages for [`RngActor`].
///
/// Two-track API:
///
/// * **Modern** — [`RngMsg::Fill`], [`RngMsg::SetSeed`],
///   [`RngMsg::SetGenerator`]. Callers build a [`FillRequest<T>`],
///   wrap it in `Box<dyn RngDispatch>`, and send it as
///   `RngMsg::Fill(Box::new(req))`.
/// * **Legacy** — `Fill{Uniform,Normal,LogNormal}*` plus `Reseed`,
///   preserved for F2 callers. Marked `#[deprecated]`.
///
/// Every variant other than [`RngMsg::Fill`] carries its own `reply`
/// sender. When the actor runs in mock mode those senders receive
/// `Err(GpuError::Unrecoverable(..))`; a mock-mode `Fill` is dropped
/// without a reply, so the caller only observes the channel closing.
#[non_exhaustive]
pub enum RngMsg {
    /// Type-erased dispatch: see [`RngDispatch`].
    Fill(Box<dyn RngDispatch>),
    /// Re-seed the **active** generator (no-op for quasi generators).
    ///
    /// For a quasi generator the actor replies `Ok(())` without
    /// touching cuRAND.
    SetSeed {
        seed: u64,
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
    /// Tear down the current generator and reconstruct it as `kind`.
    /// Pseudo→quasi (or vice-versa) is supported. Quasi generators
    /// take effect with the default 1-dimensional Sobol; use
    /// [`sobol::SetDimensions`] to widen.
    ///
    /// The replacement pseudo-random generator is seeded with `0`, not
    /// with the seed the actor was created with; follow up with
    /// [`RngMsg::SetSeed`] if a specific seed is required.
    SetGenerator {
        kind: RngGeneratorKind,
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
    /// Legacy: fill `dst` with `f32` samples from
    /// `Distribution::Uniform { lo: 0.0, hi: 1.0 }`.
    #[deprecated(note = "use RngMsg::Fill(Box::new(FillRequest { ... })) instead")]
    FillUniformF32 {
        dst: GpuRef<f32>,
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
    /// Legacy: fill `dst` with `f64` samples from
    /// `Distribution::Uniform { lo: 0.0, hi: 1.0 }`.
    #[deprecated(note = "use RngMsg::Fill(Box::new(FillRequest { ... })) instead")]
    FillUniformF64 {
        dst: GpuRef<f64>,
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
    /// Legacy: fill `dst` with `u32` values. Handled by
    /// `dist::fill_uniform_u32` directly rather than through a
    /// [`FillRequest`].
    #[deprecated(note = "use RngMsg::Fill(Box::new(FillRequest { ... })) instead")]
    FillUniformU32 {
        dst: GpuRef<u32>,
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
    /// Legacy: fill `dst` with `f32` samples from
    /// `Distribution::Normal { mean, std }`.
    #[deprecated(note = "use RngMsg::Fill(Box::new(FillRequest { ... })) instead")]
    FillNormalF32 {
        dst: GpuRef<f32>,
        mean: f32,
        std: f32,
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
    /// Legacy: fill `dst` with `f64` samples from
    /// `Distribution::Normal { mean, std }`.
    #[deprecated(note = "use RngMsg::Fill(Box::new(FillRequest { ... })) instead")]
    FillNormalF64 {
        dst: GpuRef<f64>,
        mean: f64,
        std: f64,
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
    /// Legacy: fill `dst` with `f32` samples from
    /// `Distribution::LogNormal { mean, std }`.
    #[deprecated(note = "use RngMsg::Fill(Box::new(FillRequest { ... })) instead")]
    FillLogNormalF32 {
        dst: GpuRef<f32>,
        mean: f32,
        std: f32,
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
    /// Legacy alias of [`RngMsg::SetSeed`]; the actor handles both in
    /// the same match arm.
    #[deprecated(note = "use RngMsg::SetSeed { seed, reply } instead")]
    Reseed {
        seed: u64,
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
}
121
/// Newtype around the raw cuRAND generator handle so it can live
/// inside an actor.
///
/// `curandGenerator_t` is a raw `*mut curandGenerator_st` and so is
/// `!Send + !Sync`. The actor runs exclusively on
/// [`crate::dispatcher::GpuDispatcher`]'s pinned thread; we assert
/// `Send + Sync` via this newtype so atomr's `Actor: Send + 'static`
/// bound is satisfied.
pub(crate) struct SendGen(pub(crate) cudarc::curand::sys::curandGenerator_t);

// SAFETY: the generator is only ever touched from the GpuDispatcher's
// pinned OS thread; the outer parking_lot::Mutex enforces exclusion
// against any actor handler running there.
unsafe impl Send for SendGen {}
// SAFETY: same argument as for `Send` above — access is confined to the
// dispatcher thread and serialised by the surrounding mutex.
unsafe impl Sync for SendGen {}
134
/// Actor that services [`RngMsg`] requests.
///
/// Build it through [`RngActor::props`] / [`RngActor::props_with_kind`]
/// for a real cuRAND-backed instance, or [`RngActor::mock_props`] for a
/// mock that owns no generator.
pub struct RngActor {
    // Either the live cuRAND state or the mock marker; see `RngInner`.
    inner: RngInner,
}
138
/// Backing state of an [`RngActor`].
pub(crate) enum RngInner {
    /// Live cuRAND-backed state.
    Real {
        // Handle of the currently active generator. Replaced in place by
        // `RngMsg::SetGenerator` and destroyed when the actor is dropped.
        gen: Mutex<SendGen>,
        // Family of the active generator; consulted by `SetSeed` to skip
        // seeding for quasi-random generators.
        kind: Mutex<RngGeneratorKind>,
        // Stream the generator is bound to at construction; also handed to
        // every fill request.
        stream: Arc<cudarc::driver::CudaStream>,
        // Completion strategy handed to every fill request.
        completion: Arc<dyn CompletionStrategy>,
        // Held by the actor but not read anywhere in this module.
        #[allow(dead_code)]
        state: Arc<DeviceState>,
    },
    /// No generator at all; every message is answered by `reply_mock`.
    Mock,
}
150
151impl RngActor {
152    pub fn props(
153        stream: Arc<cudarc::driver::CudaStream>,
154        _allocator: Arc<dyn StreamAllocator>,
155        completion: Arc<dyn CompletionStrategy>,
156        state: Arc<DeviceState>,
157        seed: u64,
158    ) -> Props<Self> {
159        Self::props_with_kind(
160            stream,
161            _allocator,
162            completion,
163            state,
164            seed,
165            RngGeneratorKind::default(),
166        )
167    }
168
169    /// Same as [`Self::props`] but lets the caller pick the cuRAND
170    /// generator family upfront.
171    pub fn props_with_kind(
172        stream: Arc<cudarc::driver::CudaStream>,
173        _allocator: Arc<dyn StreamAllocator>,
174        completion: Arc<dyn CompletionStrategy>,
175        state: Arc<DeviceState>,
176        seed: u64,
177        kind: RngGeneratorKind,
178    ) -> Props<Self> {
179        Props::create(move || {
180            let g = unsafe {
181                construct_generator(kind, &stream, seed).unwrap_or_else(|e| {
182                    panic!("ContextPoisoned: cuRAND generator init failed ({kind:?}): {e}")
183                })
184            };
185            RngActor {
186                inner: RngInner::Real {
187                    gen: Mutex::new(SendGen(g)),
188                    kind: Mutex::new(kind),
189                    stream: stream.clone(),
190                    completion: completion.clone(),
191                    state: state.clone(),
192                },
193            }
194        })
195    }
196
197    pub fn mock_props() -> Props<Self> {
198        Props::create(|| RngActor {
199            inner: RngInner::Mock,
200        })
201    }
202}
203
204#[async_trait]
205impl Actor for RngActor {
206    type Msg = RngMsg;
207
208    async fn handle(&mut self, _ctx: &mut Context<Self>, msg: RngMsg) {
209        let (gen_lock, kind_lock, stream, completion) = match &self.inner {
210            RngInner::Mock => {
211                reply_mock(msg);
212                return;
213            }
214            RngInner::Real {
215                gen,
216                kind,
217                stream,
218                completion,
219                ..
220            } => (gen, kind, stream, completion),
221        };
222
223        #[allow(deprecated)]
224        match msg {
225            RngMsg::Fill(req) => {
226                let gen_handle = gen_lock.lock().0;
227                if let Err(e) = req.fill(gen_handle, stream, completion) {
228                    // RngDispatch::fill is responsible for sending its
229                    // own reply on success; on Err the reply is also
230                    // expected to have been sent by the impl. The
231                    // returned error is therefore advisory — log it
232                    // for tracing parity with other actors.
233                    tracing::warn!(lib = LIB, error = %e, "RngActor::Fill pre-launch error");
234                }
235            }
236            RngMsg::SetSeed { seed, reply } | RngMsg::Reseed { seed, reply } => {
237                let g = gen_lock.lock();
238                let active = *kind_lock.lock();
239                let r = if active.is_quasi() {
240                    // Quasi generators don't accept a pseudo seed —
241                    // cuRAND returns CURAND_STATUS_TYPE_ERROR. Treat
242                    // SetSeed on a quasi RNG as a no-op so callers
243                    // can journal a single seed regardless of family.
244                    Ok(())
245                } else {
246                    unsafe { csys::set_seed(g.0, seed) }.map_err(|e| GpuError::LibraryError {
247                        lib: LIB,
248                        msg: format!("set_seed: {e}"),
249                    })
250                };
251                let _ = reply.send(r);
252            }
253            RngMsg::SetGenerator { kind, reply } => {
254                let mut g = gen_lock.lock();
255                let mut active = kind_lock.lock();
256                let r = unsafe {
257                    let _ = csys::destroy_generator(g.0);
258                    match construct_generator(kind, stream, 0) {
259                        Ok(new_g) => {
260                            g.0 = new_g;
261                            *active = kind;
262                            Ok(())
263                        }
264                        Err(e) => Err(GpuError::LibraryError {
265                            lib: LIB,
266                            msg: format!("set_generator({kind:?}): {e}"),
267                        }),
268                    }
269                };
270                let _ = reply.send(r);
271            }
272            // Legacy variants — translate into the modern path.
273            RngMsg::FillUniformF32 { dst, reply } => {
274                let req = FillRequest::<f32> {
275                    buf: dst,
276                    dist: Distribution::Uniform { lo: 0.0, hi: 1.0 },
277                    reply,
278                };
279                let gen_handle = gen_lock.lock().0;
280                let _ = Box::new(req).fill(gen_handle, stream, completion);
281            }
282            RngMsg::FillUniformF64 { dst, reply } => {
283                let req = FillRequest::<f64> {
284                    buf: dst,
285                    dist: Distribution::Uniform { lo: 0.0, hi: 1.0 },
286                    reply,
287                };
288                let gen_handle = gen_lock.lock().0;
289                let _ = Box::new(req).fill(gen_handle, stream, completion);
290            }
291            RngMsg::FillUniformU32 { dst, reply } => {
292                let gen_handle = gen_lock.lock().0;
293                dist::fill_uniform_u32(gen_handle, stream, completion, dst, reply);
294            }
295            RngMsg::FillNormalF32 {
296                dst,
297                mean,
298                std,
299                reply,
300            } => {
301                let req = FillRequest::<f32> {
302                    buf: dst,
303                    dist: Distribution::Normal { mean, std },
304                    reply,
305                };
306                let gen_handle = gen_lock.lock().0;
307                let _ = Box::new(req).fill(gen_handle, stream, completion);
308            }
309            RngMsg::FillNormalF64 {
310                dst,
311                mean,
312                std,
313                reply,
314            } => {
315                let req = FillRequest::<f64> {
316                    buf: dst,
317                    dist: Distribution::Normal { mean, std },
318                    reply,
319                };
320                let gen_handle = gen_lock.lock().0;
321                let _ = Box::new(req).fill(gen_handle, stream, completion);
322            }
323            RngMsg::FillLogNormalF32 {
324                dst,
325                mean,
326                std,
327                reply,
328            } => {
329                let req = FillRequest::<f32> {
330                    buf: dst,
331                    dist: Distribution::LogNormal { mean, std },
332                    reply,
333                };
334                let gen_handle = gen_lock.lock().0;
335                let _ = Box::new(req).fill(gen_handle, stream, completion);
336            }
337        }
338    }
339}
340
341/// Build a fresh cuRAND generator of `kind`, bind it to `stream`, and
342/// (for pseudo families) seed it. Used by both [`RngActor::props`] and
343/// [`RngMsg::SetGenerator`].
344///
345/// # Safety
346/// The returned handle is owned by the caller. It must be released
347/// through [`csys::destroy_generator`].
348pub(crate) unsafe fn construct_generator(
349    kind: RngGeneratorKind,
350    stream: &Arc<cudarc::driver::CudaStream>,
351    seed: u64,
352) -> Result<cudarc::curand::sys::curandGenerator_t, cudarc::curand::result::CurandError> {
353    let g = csys::create_generator(kind)?;
354    csys::set_stream(g, stream.cu_stream() as _)?;
355    if !kind.is_quasi() {
356        csys::set_seed(g, seed)?;
357    }
358    Ok(g)
359}
360
361impl Drop for RngActor {
362    fn drop(&mut self) {
363        if let RngInner::Real { gen, .. } = &self.inner {
364            let g = gen.lock();
365            if !g.0.is_null() {
366                let _ = unsafe { csys::destroy_generator(g.0) };
367            }
368        }
369    }
370}
371
372#[allow(deprecated)]
373fn reply_mock(msg: RngMsg) {
374    let err = || GpuError::Unrecoverable("RngActor in mock mode".into());
375    match msg {
376        RngMsg::Fill(req) => {
377            // Drop the boxed dispatch. We can't fish out the reply
378            // sender generically, so the caller observes the channel
379            // close as a "Cancelled" / send error — same as the F2
380            // mock semantics for any unsupported variant.
381            drop(req);
382        }
383        RngMsg::SetSeed { reply, .. }
384        | RngMsg::SetGenerator { reply, .. }
385        | RngMsg::Reseed { reply, .. } => {
386            let _ = reply.send(Err(err()));
387        }
388        RngMsg::FillUniformF32 { reply, .. }
389        | RngMsg::FillNormalF32 { reply, .. }
390        | RngMsg::FillLogNormalF32 { reply, .. } => {
391            let _ = reply.send(Err(err()));
392        }
393        RngMsg::FillUniformF64 { reply, .. } | RngMsg::FillNormalF64 { reply, .. } => {
394            let _ = reply.send(Err(err()));
395        }
396        RngMsg::FillUniformU32 { reply, .. } => {
397            let _ = reply.send(Err(err()));
398        }
399    }
400}
401
/// Convenience re-export of this module's public surface.
///
/// Callers can `use atomr_accel_cuda::kernel::rng::props::*` to pull
/// in the actor, its message type, the generator-kind enum, and the
/// distribution / fill-request types without remembering which
/// module defines what.
pub mod props {
    pub use super::dist::{Distribution, FillRequest};
    pub use super::{RngActor, RngGeneratorKind, RngMsg};
}
409
// ---------------------------------------------------------------------
// Capability-marker compile-fail check.
//
// `FillRequest<T>` is parameterised by `T: RngFloatSupported`, which
// is implemented for `f32` and `f64` only. A call site that tries to
// instantiate `FillRequest<u32>` must fail to compile. We can't run
// `compile_fail` doctests on the test target (no `pub use` of GpuRef
// inside this module path), so the check lives in the docstring of
// the public function below, which reaches [`FillRequest`] through
// its public re-export.
// ---------------------------------------------------------------------

/// Compile-fail proof that [`FillRequest`] rejects non-float dtypes.
///
/// The function has an empty body; it exists only to carry the
/// doctest below.
///
/// NOTE(review): a `compile_fail` doctest passes on *any* compile
/// error, so it only proves the dtype restriction if the `use` line
/// itself resolves. This module re-exports both types as
/// `kernel::rng::{Distribution, FillRequest}`; confirm that the
/// shorter `kernel::{..}` path used below is also exported.
///
/// ```compile_fail
/// use atomr_accel_cuda::kernel::{Distribution, FillRequest};
/// fn _bad(b: atomr_accel_cuda::gpu_ref::GpuRef<u32>) {
///     let (tx, _rx) = tokio::sync::oneshot::channel();
///     let _r: FillRequest<u32> = FillRequest {
///         buf: b,
///         dist: Distribution::Uniform { lo: 0u32, hi: 1u32 },
///         reply: tx,
///     };
/// }
/// ```
pub fn _capability_marker_compile_fail_doc<T: RngFloatSupported>(_: T::Scalar) {}
435
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time guard: the deprecated variant names must stay part
    /// of `RngMsg` so older callers keep compiling (with a deprecation
    /// warning only). The inner function is never called; it just has
    /// to type-check.
    #[test]
    fn rng_msg_legacy_variants_present() {
        #[allow(deprecated)]
        fn _accept(m: RngMsg) {
            match m {
                RngMsg::Fill(_) | RngMsg::SetSeed { .. } | RngMsg::SetGenerator { .. } => {}
                RngMsg::FillUniformF32 { .. }
                | RngMsg::FillUniformF64 { .. }
                | RngMsg::FillUniformU32 { .. } => {}
                RngMsg::FillNormalF32 { .. } | RngMsg::FillNormalF64 { .. } => {}
                RngMsg::FillLogNormalF32 { .. } => {}
                RngMsg::Reseed { .. } => {}
            }
        }
    }

    /// Sends every generator kind through `to_sys` and checks that the
    /// numeric values are pairwise distinct and that `is_quasi` agrees
    /// with the `>= 200` range of the raw value. (Real handle creation
    /// requires a CUDA context; that's covered by the GPU-runtime e2e
    /// suite.)
    #[test]
    fn set_generator_kind_round_trip() {
        let pseudo = [
            RngGeneratorKind::PseudoDefault,
            RngGeneratorKind::Philox4_32_10,
            RngGeneratorKind::XorWow,
            RngGeneratorKind::Mrg32K3A,
            RngGeneratorKind::Mtgp32,
        ];
        let quasi = [
            RngGeneratorKind::Sobol32,
            RngGeneratorKind::ScrambledSobol32,
            RngGeneratorKind::Sobol64,
            RngGeneratorKind::ScrambledSobol64,
        ];

        let mut seen = std::collections::HashSet::new();
        for &k in pseudo.iter().chain(quasi.iter()) {
            let v = k.to_sys() as u32;
            let first_time = seen.insert(v);
            assert!(first_time, "duplicate sys value for {k:?}");
            let in_quasi_range = (v as i32) >= 200;
            assert_eq!(k.is_quasi(), in_quasi_range);
        }
    }
}