atomr_accel_cuda/sys/curand.rs
1//! Thin Rust-level wrappers over [`cudarc::curand::sys`] for the
2//! cuRAND host-API entry points that aren't surfaced by the safe
3//! [`cudarc::curand::CudaRng`] handle:
4//!
5//! * generator creation by **explicit** `curandRngType_t`
6//! (`Philox4_32_10`, `XORWOW`, `MTGP32`, `MRG32K3A`, plus all four
7//! Sobol variants);
//! * the **host-API** generator constructor (`curandCreateGeneratorHost` —
//!   generation runs on the host CPU and writes directly into *host* buffers);
10//! * `curandSetQuasiRandomGeneratorDimensions`,
11//! `curandSetGeneratorOrdering`, `curandSetGeneratorOffset`;
12//! * the Poisson and bit-generator families
13//! (`curandGeneratePoisson`, `curandGenerate`,
14//! `curandGenerateLongLong`).
15//!
//! These functions are **unsafe**: they take the raw
//! `curandGenerator_t` and require the caller to keep the generator
//! alive and valid and, for device generators, bound to a stream in the
//! same CUDA context as the destination pointer. The actor in
//! `kernel/rng/*` is the only intended caller.
20
21use std::mem::MaybeUninit;
22
23use cudarc::curand::result::CurandError;
24use cudarc::curand::sys;
25
/// Public mirror of [`sys::curandRngType_t`] so callers don't have to
/// name a `cudarc::curand::sys::*` symbol in their public API.
///
/// The discriminants are pinned to the `curandRngType` values from
/// `curand.h`, so `kind as u32` yields the same number cuRAND uses.
/// Use [`RngGeneratorKind::to_sys`] to get the typed `sys` value.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Default)]
#[repr(u32)]
pub enum RngGeneratorKind {
    /// `CURAND_RNG_PSEUDO_DEFAULT` — cuRAND's default pseudo-random
    /// generator (XORWOW per the cuRAND documentation).
    #[default]
    PseudoDefault = 100,
    /// `CURAND_RNG_PSEUDO_PHILOX4_32_10`. Counter-based pseudo-RNG,
    /// friendly to SIMD/SIMT execution.
    Philox4_32_10 = 161,
    /// `CURAND_RNG_PSEUDO_XORWOW`.
    XorWow = 101,
    /// `CURAND_RNG_PSEUDO_MRG32K3A`. L'Ecuyer's combined MRG.
    Mrg32K3A = 121,
    /// `CURAND_RNG_PSEUDO_MTGP32`. Mersenne Twister (MTGP), 32-bit.
    Mtgp32 = 141,
    /// `CURAND_RNG_QUASI_SOBOL32`. Quasi-random 32-bit Sobol.
    Sobol32 = 201,
    /// `CURAND_RNG_QUASI_SCRAMBLED_SOBOL32`.
    ScrambledSobol32 = 202,
    /// `CURAND_RNG_QUASI_SOBOL64`.
    Sobol64 = 203,
    /// `CURAND_RNG_QUASI_SCRAMBLED_SOBOL64`.
    ScrambledSobol64 = 204,
}
52
53impl RngGeneratorKind {
54 /// Whether this kind is a quasi-random (Sobol) generator. Quasi
55 /// generators must be configured with
56 /// [`set_quasi_random_dimensions`] before any fill, and they do
57 /// not accept a pseudo-random seed.
58 pub fn is_quasi(self) -> bool {
59 matches!(
60 self,
61 Self::Sobol32 | Self::ScrambledSobol32 | Self::Sobol64 | Self::ScrambledSobol64
62 )
63 }
64
65 /// 64-bit Sobol vs. 32-bit Sobol. Kept separate from [`Self::is_quasi`]
66 /// so callers can pick the matching `curandSetQuasiRandomGeneratorDimensions`
67 /// argument width without re-matching.
68 pub fn is_quasi_64(self) -> bool {
69 matches!(self, Self::Sobol64 | Self::ScrambledSobol64)
70 }
71
72 pub fn to_sys(self) -> sys::curandRngType_t {
73 match self {
74 Self::PseudoDefault => sys::curandRngType_t::CURAND_RNG_PSEUDO_DEFAULT,
75 Self::Philox4_32_10 => sys::curandRngType_t::CURAND_RNG_PSEUDO_PHILOX4_32_10,
76 Self::XorWow => sys::curandRngType_t::CURAND_RNG_PSEUDO_XORWOW,
77 Self::Mrg32K3A => sys::curandRngType_t::CURAND_RNG_PSEUDO_MRG32K3A,
78 Self::Mtgp32 => sys::curandRngType_t::CURAND_RNG_PSEUDO_MTGP32,
79 Self::Sobol32 => sys::curandRngType_t::CURAND_RNG_QUASI_SOBOL32,
80 Self::ScrambledSobol32 => sys::curandRngType_t::CURAND_RNG_QUASI_SCRAMBLED_SOBOL32,
81 Self::Sobol64 => sys::curandRngType_t::CURAND_RNG_QUASI_SOBOL64,
82 Self::ScrambledSobol64 => sys::curandRngType_t::CURAND_RNG_QUASI_SCRAMBLED_SOBOL64,
83 }
84 }
85}
86
87/// Create a device-side generator of the given `kind`
88/// (`curandCreateGenerator`).
89///
90/// # Safety
91/// The returned handle must be released with
92/// [`destroy_generator`] before drop, and a stream must be bound via
93/// [`set_stream`] before any fill is enqueued.
94pub unsafe fn create_generator(
95 kind: RngGeneratorKind,
96) -> Result<sys::curandGenerator_t, CurandError> {
97 let mut g = MaybeUninit::uninit();
98 sys::curandCreateGenerator(g.as_mut_ptr(), kind.to_sys()).result()?;
99 Ok(g.assume_init())
100}
101
/// Create a host-API cuRAND generator of the given `kind`
/// (`curandCreateGeneratorHost`). Every `curandGenerate*` call made with
/// the returned handle writes into **host-resident** memory.
///
/// # Errors
/// Any non-success cuRAND status is returned as a [`CurandError`].
///
/// # Safety
/// The lifecycle rules of [`create_generator`] apply. In addition, every
/// destination pointer later handed to a generate call must be a host
/// pointer, *never* a device pointer.
#[cfg(feature = "curand-host")]
pub unsafe fn create_generator_host(
    kind: RngGeneratorKind,
) -> Result<sys::curandGenerator_t, CurandError> {
    let rng_type = kind.to_sys();
    let mut slot: MaybeUninit<sys::curandGenerator_t> = MaybeUninit::uninit();
    // SAFETY: `slot` provides writable storage for exactly one handle.
    unsafe { sys::curandCreateGeneratorHost(slot.as_mut_ptr(), rng_type) }.result()?;
    // SAFETY: the call above succeeded, so `slot` now holds the handle.
    Ok(unsafe { slot.assume_init() })
}
118
119/// Bind `gen` to `stream` (`curandSetStream`).
120///
121/// # Safety
122/// `gen` must be live; `stream` must be the raw cuStream pointer from
123/// a `cudarc::driver::CudaStream` whose context owns `gen`.
124pub unsafe fn set_stream(
125 gen: sys::curandGenerator_t,
126 stream: sys::cudaStream_t,
127) -> Result<(), CurandError> {
128 sys::curandSetStream(gen, stream).result()
129}
130
131/// `curandSetPseudoRandomGeneratorSeed`.
132///
133/// # Safety
134/// `gen` must be live and pseudo-random.
135pub unsafe fn set_seed(gen: sys::curandGenerator_t, seed: u64) -> Result<(), CurandError> {
136 sys::curandSetPseudoRandomGeneratorSeed(gen, seed).result()
137}
138
139/// `curandSetGeneratorOffset`.
140///
141/// # Safety
142/// `gen` must be live.
143pub unsafe fn set_offset(gen: sys::curandGenerator_t, offset: u64) -> Result<(), CurandError> {
144 sys::curandSetGeneratorOffset(gen, offset).result()
145}
146
/// Set the number of dimensions a quasi-random generator produces
/// (`curandSetQuasiRandomGeneratorDimensions`).
///
/// # Errors
/// cuRAND rejects a `dimensions` value outside its legal range
/// (1 to 20,000) with `CURAND_STATUS_OUT_OF_RANGE`. That status, like
/// any other non-success status, is returned as a [`CurandError`].
///
/// # Safety
/// `gen` must be a live quasi-random generator (see
/// [`RngGeneratorKind::is_quasi`]). After this call, cuRAND requires the
/// sample count `n` passed to the `generate_*` functions to be a multiple
/// of `dimensions`.
#[cfg(feature = "curand-quasirandom")]
pub unsafe fn set_quasi_random_dimensions(
    gen: sys::curandGenerator_t,
    dimensions: u32,
) -> Result<(), CurandError> {
    sys::curandSetQuasiRandomGeneratorDimensions(gen, dimensions).result()
}
160
161/// `curandDestroyGenerator`.
162///
163/// # Safety
164/// `gen` must not have been destroyed already. After this call the
165/// pointer is dangling.
166pub unsafe fn destroy_generator(gen: sys::curandGenerator_t) -> Result<(), CurandError> {
167 sys::curandDestroyGenerator(gen).result()
168}
169
170/// `curandGeneratePoisson` — fill `out` with `n` u32 values drawn
171/// from a Poisson distribution parameterised by `lambda` (f64).
172///
173/// # Safety
174/// `gen` must be live and bound to the same stream/context as the
175/// device pointer `out`. `out` must point to at least `n` u32 slots.
176pub unsafe fn generate_poisson_u32(
177 gen: sys::curandGenerator_t,
178 out: *mut u32,
179 n: usize,
180 lambda: f64,
181) -> Result<(), CurandError> {
182 sys::curandGeneratePoisson(gen, out, n, lambda).result()
183}
184
185/// `curandGenerate` — raw u32 bit fill (used as the building block
186/// for any custom transform).
187///
188/// # Safety
189/// Same invariants as [`generate_poisson_u32`].
190pub unsafe fn generate_u32(
191 gen: sys::curandGenerator_t,
192 out: *mut u32,
193 n: usize,
194) -> Result<(), CurandError> {
195 sys::curandGenerate(gen, out, n).result()
196}
197
198/// `curandGenerateLongLong` — raw u64 bit fill.
199///
200/// # Safety
201/// `gen` must be a 64-bit quasi-random or pseudo-random generator
202/// supporting long-long output.
203pub unsafe fn generate_u64(
204 gen: sys::curandGenerator_t,
205 out: *mut u64,
206 n: usize,
207) -> Result<(), CurandError> {
208 sys::curandGenerateLongLong(gen, out as *mut std::os::raw::c_ulonglong, n).result()
209}