1use std::sync::Arc;
34
35use cudarc::curand::result::{LogNormalFill, NormalFill, UniformFill};
36use cudarc::curand::sys;
37use tokio::sync::oneshot;
38
39use crate::completion::CompletionStrategy;
40use crate::dtype::RngFloatSupported;
41use crate::error::GpuError;
42use crate::gpu_ref::GpuRef;
43use crate::kernel::dispatch::RngDispatch;
44use crate::kernel::envelope;
45
46use super::LIB;
47
/// Probability distribution to sample when filling a GPU buffer of element type `T`.
///
/// Parameters are expressed in `T::Scalar`, except `Poisson::lambda`, which is
/// always an `f64`. For float buffers, `fill_float` in this module currently
/// dispatches only `Uniform` (and only with `lo == 0`, `hi == 1`), `Normal` and
/// `LogNormal`; every other variant is answered with a `GpuError::LibraryError`
/// sent on the request's reply channel.
pub enum Distribution<T: RngFloatSupported> {
    /// Uniform samples bounded by `lo` and `hi`.
    /// Only `lo == 0`, `hi == 1` is wired for float buffers today.
    Uniform {
        lo: T::Scalar,
        hi: T::Scalar,
    },
    /// Normal samples with mean `mean` and standard deviation `std`.
    Normal {
        mean: T::Scalar,
        std: T::Scalar,
    },
    /// Log-normal samples. `mean`/`std` are forwarded unchanged to cuRAND's
    /// log-normal fill (cuRAND treats them as the parameters of the underlying
    /// normal distribution).
    LogNormal {
        mean: T::Scalar,
        std: T::Scalar,
    },
    /// Poisson samples with rate `lambda`; the rate is always an `f64`,
    /// independent of `T`. Not wired for float buffers.
    Poisson {
        lambda: f64,
    },
    /// Exponential samples with rate `lambda`. Not wired yet.
    Exponential {
        lambda: T::Scalar,
    },
    /// Beta samples with shape parameters `alpha` and `beta`. Not wired yet.
    Beta {
        alpha: T::Scalar,
        beta: T::Scalar,
    },
    /// Cauchy samples with location `loc` and scale `scale`. Not wired yet.
    Cauchy {
        loc: T::Scalar,
        scale: T::Scalar,
    },
    /// Gamma samples with shape `shape` and scale `scale`. Not wired yet.
    Gamma {
        shape: T::Scalar,
        scale: T::Scalar,
    },
    /// Discrete samples driven by a device-resident `f32` weight buffer.
    /// NOTE(review): presumably one (possibly unnormalised) weight per category —
    /// confirm. Not wired yet.
    Discrete {
        weights: GpuRef<f32>,
    },
}
89
/// A request to overwrite `buf` with samples drawn from `dist`.
///
/// The outcome is delivered on `reply`. The dispatch functions in this module
/// return `Ok(())` even when a request is rejected and report the real success
/// or failure through this channel instead.
pub struct FillRequest<T: RngFloatSupported> {
    /// Destination buffer that receives the samples.
    pub buf: GpuRef<T>,
    /// Distribution to sample from.
    pub dist: Distribution<T>,
    /// Receives the result of the fill. Once a kernel is enqueued, the sender is
    /// handed to `envelope::run_kernel`, which decides when it fires.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
}
97
98impl RngDispatch for FillRequest<f32> {
105 fn fill(
106 self: Box<Self>,
107 gen: sys::curandGenerator_t,
108 stream: &Arc<cudarc::driver::CudaStream>,
109 completion: &Arc<dyn CompletionStrategy>,
110 ) -> Result<(), GpuError> {
111 fill_float::<f32>(*self, gen, stream, completion)
112 }
113}
114
115impl RngDispatch for FillRequest<f64> {
116 fn fill(
117 self: Box<Self>,
118 gen: sys::curandGenerator_t,
119 stream: &Arc<cudarc::driver::CudaStream>,
120 completion: &Arc<dyn CompletionStrategy>,
121 ) -> Result<(), GpuError> {
122 fill_float::<f64>(*self, gen, stream, completion)
123 }
124}
125
126fn fill_float<T>(
127 req: FillRequest<T>,
128 gen: sys::curandGenerator_t,
129 stream: &Arc<cudarc::driver::CudaStream>,
130 completion: &Arc<dyn CompletionStrategy>,
131) -> Result<(), GpuError>
132where
133 T: RngFloatSupported,
134 sys::curandGenerator_t: UniformFill<T> + NormalFill<T> + LogNormalFill<T>,
135 T::Scalar: Into<f64> + Copy,
136 T: NormalParam<T::Scalar>,
137{
138 let FillRequest { buf, dist, reply } = req;
139
140 match dist {
141 Distribution::Uniform { lo, hi } => {
142 enqueue_uniform::<T>(gen, stream, completion, buf, reply, lo, hi)
149 }
150 Distribution::Normal { mean, std } => {
151 enqueue_normal::<T>(gen, stream, completion, buf, mean, std, reply)
152 }
153 Distribution::LogNormal { mean, std } => {
154 enqueue_log_normal::<T>(gen, stream, completion, buf, mean, std, reply)
155 }
156 Distribution::Poisson { lambda } => {
157 let _ = (gen, stream, completion, buf);
161 let _ = lambda;
162 let _ = reply.send(Err(GpuError::LibraryError {
163 lib: LIB,
164 msg:
165 "Poisson<T> not yet wired for floats (Phase 1: use FillRequest<u32> + Poisson)"
166 .into(),
167 }));
168 Ok(())
169 }
170 Distribution::Exponential { .. }
171 | Distribution::Beta { .. }
172 | Distribution::Cauchy { .. }
173 | Distribution::Gamma { .. }
174 | Distribution::Discrete { .. } => {
175 let _ = (gen, stream, completion, buf);
176 let _ = reply.send(Err(GpuError::LibraryError {
177 lib: LIB,
178 msg: "distribution not yet wired (Phase 1: needs custom kernel / NVRTC)".into(),
179 }));
180 Ok(())
181 }
182 }
183}
184
185fn enqueue_uniform<T>(
189 gen: sys::curandGenerator_t,
190 stream: &Arc<cudarc::driver::CudaStream>,
191 completion: &Arc<dyn CompletionStrategy>,
192 dst: GpuRef<T>,
193 reply: oneshot::Sender<Result<(), GpuError>>,
194 lo: T::Scalar,
195 hi: T::Scalar,
196) -> Result<(), GpuError>
197where
198 T: RngFloatSupported,
199 T::Scalar: Into<f64> + Copy,
200 sys::curandGenerator_t: UniformFill<T>,
201{
202 let lo_f: f64 = lo.into();
203 let hi_f: f64 = hi.into();
204 let trivial = lo_f == 0.0 && hi_f == 1.0;
205
206 let dst_arc = match dst.access() {
207 Ok(s) => s.clone(),
208 Err(e) => {
209 let _ = reply.send(Err(e));
210 return Ok(());
211 }
212 };
213 let mut owned = match Arc::try_unwrap(dst_arc) {
214 Ok(s) => s,
215 Err(_) => {
216 let _ = reply.send(Err(GpuError::Unrecoverable(
217 "RNG dst has multiple live references".into(),
218 )));
219 return Ok(());
220 }
221 };
222 if !trivial {
223 let _ = reply.send(Err(GpuError::LibraryError {
224 lib: LIB,
225 msg: format!(
226 "Uniform({lo_f},{hi_f}): non-(0,1] bounds need an affine transform kernel (Phase 1: not wired)"
227 ),
228 }));
229 return Ok(());
230 }
231
232 dst.record_write(stream);
233 envelope::run_kernel(LIB, stream, completion, (), reply, move || {
234 let n = owned.len();
237 let res = unsafe {
238 let (ptr, _rec) = cudarc::driver::DevicePtrMut::device_ptr_mut(&mut owned, stream);
239 UniformFill::fill(gen, ptr as *mut T, n)
240 };
241 res.map(|_| (owned,)).map_err(|e| GpuError::LibraryError {
242 lib: LIB,
243 msg: format!("fill_uniform: {e}"),
244 })
245 });
246 Ok(())
247}
248
249fn enqueue_normal<T>(
250 gen: sys::curandGenerator_t,
251 stream: &Arc<cudarc::driver::CudaStream>,
252 completion: &Arc<dyn CompletionStrategy>,
253 dst: GpuRef<T>,
254 mean: T::Scalar,
255 std: T::Scalar,
256 reply: oneshot::Sender<Result<(), GpuError>>,
257) -> Result<(), GpuError>
258where
259 T: RngFloatSupported + NormalParam<T::Scalar>,
260 sys::curandGenerator_t: NormalFill<T>,
261{
262 let dst_arc = match dst.access() {
263 Ok(s) => s.clone(),
264 Err(e) => {
265 let _ = reply.send(Err(e));
266 return Ok(());
267 }
268 };
269 let mut owned = match Arc::try_unwrap(dst_arc) {
270 Ok(s) => s,
271 Err(_) => {
272 let _ = reply.send(Err(GpuError::Unrecoverable(
273 "RNG dst has multiple live references".into(),
274 )));
275 return Ok(());
276 }
277 };
278 let mean_t = T::from_scalar(mean);
279 let std_t = T::from_scalar(std);
280 dst.record_write(stream);
281 envelope::run_kernel(LIB, stream, completion, (), reply, move || {
282 let n = owned.len();
283 let res = unsafe {
284 let (ptr, _rec) = cudarc::driver::DevicePtrMut::device_ptr_mut(&mut owned, stream);
285 NormalFill::fill(gen, ptr as *mut T, n, mean_t, std_t)
286 };
287 res.map(|_| (owned,)).map_err(|e| GpuError::LibraryError {
288 lib: LIB,
289 msg: format!("fill_normal: {e}"),
290 })
291 });
292 Ok(())
293}
294
295fn enqueue_log_normal<T>(
296 gen: sys::curandGenerator_t,
297 stream: &Arc<cudarc::driver::CudaStream>,
298 completion: &Arc<dyn CompletionStrategy>,
299 dst: GpuRef<T>,
300 mean: T::Scalar,
301 std: T::Scalar,
302 reply: oneshot::Sender<Result<(), GpuError>>,
303) -> Result<(), GpuError>
304where
305 T: RngFloatSupported + NormalParam<T::Scalar>,
306 sys::curandGenerator_t: LogNormalFill<T>,
307{
308 let dst_arc = match dst.access() {
309 Ok(s) => s.clone(),
310 Err(e) => {
311 let _ = reply.send(Err(e));
312 return Ok(());
313 }
314 };
315 let mut owned = match Arc::try_unwrap(dst_arc) {
316 Ok(s) => s,
317 Err(_) => {
318 let _ = reply.send(Err(GpuError::Unrecoverable(
319 "RNG dst has multiple live references".into(),
320 )));
321 return Ok(());
322 }
323 };
324 let mean_t = T::from_scalar(mean);
325 let std_t = T::from_scalar(std);
326 dst.record_write(stream);
327 envelope::run_kernel(LIB, stream, completion, (), reply, move || {
328 let n = owned.len();
329 let res = unsafe {
330 let (ptr, _rec) = cudarc::driver::DevicePtrMut::device_ptr_mut(&mut owned, stream);
331 LogNormalFill::fill(gen, ptr as *mut T, n, mean_t, std_t)
332 };
333 res.map(|_| (owned,)).map_err(|e| GpuError::LibraryError {
334 lib: LIB,
335 msg: format!("fill_log_normal: {e}"),
336 })
337 });
338 Ok(())
339}
340
/// Converts a distribution parameter from its scalar type `S` into the element
/// type handed to cuRAND.
///
/// For the float element types implemented here, the scalar type is the
/// element type itself, so the conversion is the identity.
pub trait NormalParam<S>: Sized {
    /// Converts `s` into `Self`.
    fn from_scalar(s: S) -> Self;
}

// Generates the identity conversion `NormalParam<$t> for $t` for each listed type.
macro_rules! identity_normal_param {
    ($($t:ty),* $(,)?) => {
        $(
            impl NormalParam<$t> for $t {
                fn from_scalar(value: $t) -> Self {
                    value
                }
            }
        )*
    };
}

identity_normal_param!(f32, f64);
358
359pub(crate) fn fill_uniform_u32(
362 gen: sys::curandGenerator_t,
363 stream: &Arc<cudarc::driver::CudaStream>,
364 completion: &Arc<dyn CompletionStrategy>,
365 dst: GpuRef<u32>,
366 reply: oneshot::Sender<Result<(), GpuError>>,
367) {
368 let dst_arc = match dst.access() {
369 Ok(s) => s.clone(),
370 Err(e) => {
371 let _ = reply.send(Err(e));
372 return;
373 }
374 };
375 let mut owned = match Arc::try_unwrap(dst_arc) {
376 Ok(s) => s,
377 Err(_) => {
378 let _ = reply.send(Err(GpuError::Unrecoverable(
379 "RNG dst has multiple live references".into(),
380 )));
381 return;
382 }
383 };
384 dst.record_write(stream);
385 envelope::run_kernel(LIB, stream, completion, (), reply, move || {
386 let n = owned.len();
387 let res = unsafe {
388 let (ptr, _rec) = cudarc::driver::DevicePtrMut::device_ptr_mut(&mut owned, stream);
389 UniformFill::fill(gen, ptr as *mut u32, n)
390 };
391 res.map(|_| (owned,)).map_err(|e| GpuError::LibraryError {
392 lib: LIB,
393 msg: format!("fill_uniform_u32: {e}"),
394 })
395 });
396}
397
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: every `Distribution` variant with only scalar parameters can
    /// be constructed for both `f32` and `f64` from float literals. `Discrete`
    /// is skipped because it needs a `GpuRef<f32>`.
    /// NOTE(review): despite the name, nothing is round-tripped; values are
    /// only constructed and immediately dropped.
    #[test]
    fn distribution_round_trip_f32_f64() {
        // f32 element type.
        let _: Distribution<f32> = Distribution::Uniform { lo: 0.0, hi: 1.0 };
        let _: Distribution<f32> = Distribution::Normal {
            mean: 0.0,
            std: 1.0,
        };
        let _: Distribution<f32> = Distribution::LogNormal {
            mean: 0.0,
            std: 1.0,
        };
        let _: Distribution<f32> = Distribution::Poisson { lambda: 1.0 };
        let _: Distribution<f32> = Distribution::Exponential { lambda: 1.0 };
        let _: Distribution<f32> = Distribution::Beta {
            alpha: 1.0,
            beta: 1.0,
        };
        let _: Distribution<f32> = Distribution::Cauchy {
            loc: 0.0,
            scale: 1.0,
        };
        let _: Distribution<f32> = Distribution::Gamma {
            shape: 1.0,
            scale: 1.0,
        };
        // f64 element type.
        let _: Distribution<f64> = Distribution::Uniform { lo: 0.0, hi: 1.0 };
        let _: Distribution<f64> = Distribution::Normal {
            mean: 0.0,
            std: 1.0,
        };
        let _: Distribution<f64> = Distribution::LogNormal {
            mean: 0.0,
            std: 1.0,
        };
        let _: Distribution<f64> = Distribution::Poisson { lambda: 1.0 };
        let _: Distribution<f64> = Distribution::Exponential { lambda: 1.0 };
        let _: Distribution<f64> = Distribution::Beta {
            alpha: 1.0,
            beta: 1.0,
        };
        let _: Distribution<f64> = Distribution::Cauchy {
            loc: 0.0,
            scale: 1.0,
        };
        let _: Distribution<f64> = Distribution::Gamma {
            shape: 1.0,
            scale: 1.0,
        };
    }

    /// Compile-time check: the parent module's `RngMsg::FillUniformF32` variant
    /// can still be built from a `GpuRef<f32>` and a reply sender. Nothing is
    /// sent or executed at runtime.
    #[test]
    // Presumably `FillUniformF32` is marked deprecated in the parent module — confirm.
    #[allow(deprecated)]
    fn deprecated_fill_uniform_f32_still_works() {
        let (tx, _rx) = tokio::sync::oneshot::channel::<Result<(), GpuError>>();
        // Wrapping in `ManuallyDrop` means `tx` is never dropped (it is leaked).
        // NOTE(review): the reason is not visible here — confirm it is still needed.
        let _ = std::mem::ManuallyDrop::new(tx);
        // Accepts any closure that turns (dst, reply) into an `RngMsg`; the check
        // happens entirely at type-check time.
        fn _assert<
            F: FnOnce(GpuRef<f32>, oneshot::Sender<Result<(), GpuError>>) -> super::super::RngMsg,
        >(
            _f: F,
        ) {
        }
        _assert(|dst, reply| super::super::RngMsg::FillUniformF32 { dst, reply });
    }
}