1use std::any::{Any, TypeId};
12use std::collections::{HashMap, VecDeque};
13use std::sync::Arc;
14
15use async_trait::async_trait;
16use atomr_core::actor::{Actor, ActorRef, Context, Props};
17use bitflags::bitflags;
18use parking_lot::RwLock;
19use tokio::sync::oneshot;
20use tracing::{debug, warn};
21
22use crate::dtype::CudaDtype;
23use crate::error::GpuError;
24use crate::gpu_ref::GpuRef;
25use crate::kernel::BlasMsg;
26
27use super::alloc_dispatch::{
28 AllocDispatch, AllocReq, CopyFromHostDispatch, CopyFromHostReq, CopyToHostDispatch,
29 CopyToHostReq,
30};
31use super::alloc_msg::{DeviceLoad, HostBuf};
32use super::context_actor::{ContextActor, ContextMsg};
33use super::state::DeviceState;
34
bitflags! {
    /// Bit set selecting which GPU library families a device enables.
    ///
    /// The exact bit positions are asserted by the
    /// `enabled_libraries_bit_values_are_stable` test in this file: add new
    /// flags at the next free bit instead of renumbering existing ones.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct EnabledLibraries: u32 {
        /// BLAS family; the only flag set by `EnabledLibraries::default()`.
        const BLAS = 1 << 0;
        const CUDNN = 1 << 1;
        const CUFFT = 1 << 2;
        const CURAND = 1 << 3;
        const CUSOLVER = 1 << 4;
        const CUBLASLT = 1 << 5;
        const NVRTC = 1 << 6;
        const CUTENSOR = 1 << 7;
        const CUSPARSE = 1 << 8;
        const NCCL = 1 << 9;
        const CUTLASS = 1 << 10;
        const TENSORRT = 1 << 11;
        const FLASHATTN = 1 << 12;
        const CUB_THRUST = 1 << 13;
        const TELEMETRY = 1 << 14;

        /// Union of every flag defined above. A newly added flag must also be
        /// OR-ed in here; the `enabled_libraries_all_contains_every_phase_0_8_bit`
        /// test checks the current set.
        const ALL = Self::BLAS.bits()
            | Self::CUDNN.bits()
            | Self::CUFFT.bits()
            | Self::CURAND.bits()
            | Self::CUSOLVER.bits()
            | Self::CUBLASLT.bits()
            | Self::NVRTC.bits()
            | Self::CUTENSOR.bits()
            | Self::CUSPARSE.bits()
            | Self::NCCL.bits()
            | Self::CUTLASS.bits()
            | Self::TENSORRT.bits()
            | Self::FLASHATTN.bits()
            | Self::CUB_THRUST.bits()
            | Self::TELEMETRY.bits();
    }
}
75
76impl Default for EnabledLibraries {
77 fn default() -> Self {
80 Self::BLAS
81 }
82}
83
/// Construction-time settings for one [`DeviceActor`].
#[derive(Debug, Clone)]
pub struct DeviceConfig {
    /// Device ordinal; passed to `DeviceState::new` and attached to log lines.
    pub device_id: u32,
    /// `true` for configs built with `DeviceConfig::mock`.
    /// NOTE(review): presumably makes `ContextActor` skip real CUDA setup —
    /// confirm in `context_actor`.
    pub mock_mode: bool,
    /// Maximum number of work items parked while the device is not ready;
    /// once reached, further work is rejected instead of queued.
    pub pending_queue_capacity: usize,
    /// Library families to enable (BLAS only by default).
    /// NOTE(review): consumed by `ContextActor`, which receives a clone of
    /// this config in `pre_start` — confirm how it is interpreted there.
    pub enabled_libraries: EnabledLibraries,
}
101
102impl DeviceConfig {
103 pub fn new(device_id: u32) -> Self {
104 Self {
105 device_id,
106 mock_mode: false,
107 pending_queue_capacity: 1024,
108 enabled_libraries: EnabledLibraries::default(),
109 }
110 }
111
112 pub fn mock(device_id: u32) -> Self {
113 Self {
114 device_id,
115 mock_mode: true,
116 pending_queue_capacity: 1024,
117 enabled_libraries: EnabledLibraries::default(),
118 }
119 }
120
121 pub fn with_libraries(mut self, libs: EnabledLibraries) -> Self {
123 self.enabled_libraries = libs;
124 self
125 }
126}
127
/// Messages accepted by [`DeviceActor`].
///
/// The generic `Alloc` / `CopyToHost` / `CopyFromHost` variants carry a
/// type-erased request, normally built with [`DeviceMsg::alloc`],
/// [`DeviceMsg::copy_to_host`] or [`DeviceMsg::copy_from_host`]. The per-dtype
/// variants that follow are deprecated; `DeviceActor::handle` rewrites each
/// into the matching generic variant before dispatching.
pub enum DeviceMsg {
    /// Typed device allocation. Forwarded to the context actor as
    /// `ContextMsg::Alloc` once the device is ready, otherwise queued.
    Alloc(Box<dyn AllocDispatch>),
    /// Device-to-host copy. Forwarded as `ContextMsg::CopyToHost` once the
    /// device is ready, otherwise queued.
    CopyToHost(Box<dyn CopyToHostDispatch>),
    /// Host-to-device copy. Forwarded as `ContextMsg::CopyFromHost` once the
    /// device is ready, otherwise queued.
    CopyFromHost(Box<dyn CopyFromHostDispatch>),

    // --- Legacy per-dtype allocation variants (translated to `Alloc`). ---
    /// Legacy f32 allocation; handled exactly like `AllocateF32`.
    #[deprecated(note = "use DeviceMsg::alloc::<f32>(len, reply)")]
    Allocate {
        len: usize,
        reply: oneshot::Sender<Result<GpuRef<f32>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::alloc::<f32>(len, reply)")]
    AllocateF32 {
        len: usize,
        reply: oneshot::Sender<Result<GpuRef<f32>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::alloc::<f64>(len, reply)")]
    AllocateF64 {
        len: usize,
        reply: oneshot::Sender<Result<GpuRef<f64>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::alloc::<i8>(len, reply)")]
    AllocateI8 {
        len: usize,
        reply: oneshot::Sender<Result<GpuRef<i8>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::alloc::<i32>(len, reply)")]
    AllocateI32 {
        len: usize,
        reply: oneshot::Sender<Result<GpuRef<i32>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::alloc::<i64>(len, reply)")]
    AllocateI64 {
        len: usize,
        reply: oneshot::Sender<Result<GpuRef<i64>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::alloc::<u8>(len, reply)")]
    AllocateU8 {
        len: usize,
        reply: oneshot::Sender<Result<GpuRef<u8>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::alloc::<u32>(len, reply)")]
    AllocateU32 {
        len: usize,
        reply: oneshot::Sender<Result<GpuRef<u32>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::alloc::<u64>(len, reply)")]
    AllocateU64 {
        len: usize,
        reply: oneshot::Sender<Result<GpuRef<u64>, GpuError>>,
    },
    #[cfg(feature = "f16")]
    #[deprecated(note = "use DeviceMsg::alloc::<half::f16>(len, reply)")]
    AllocateF16 {
        len: usize,
        reply: oneshot::Sender<Result<GpuRef<half::f16>, GpuError>>,
    },
    #[cfg(feature = "f16")]
    #[deprecated(note = "use DeviceMsg::alloc::<half::bf16>(len, reply)")]
    AllocateBf16 {
        len: usize,
        reply: oneshot::Sender<Result<GpuRef<half::bf16>, GpuError>>,
    },

    // --- Legacy per-dtype copy variants (translated to `CopyToHost` /
    // `CopyFromHost`). ---
    #[deprecated(note = "use DeviceMsg::copy_to_host::<f32>(src, dst, reply)")]
    CopyToHostF32 {
        src: GpuRef<f32>,
        dst: HostBuf<f32>,
        reply: oneshot::Sender<Result<HostBuf<f32>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::copy_from_host::<f32>(src, dst, reply)")]
    CopyFromHostF32 {
        src: HostBuf<f32>,
        dst: GpuRef<f32>,
        reply: oneshot::Sender<Result<HostBuf<f32>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::copy_to_host::<f64>(src, dst, reply)")]
    CopyToHostF64 {
        src: GpuRef<f64>,
        dst: HostBuf<f64>,
        reply: oneshot::Sender<Result<HostBuf<f64>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::copy_from_host::<f64>(src, dst, reply)")]
    CopyFromHostF64 {
        src: HostBuf<f64>,
        dst: GpuRef<f64>,
        reply: oneshot::Sender<Result<HostBuf<f64>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::copy_to_host::<i32>(src, dst, reply)")]
    CopyToHostI32 {
        src: GpuRef<i32>,
        dst: HostBuf<i32>,
        reply: oneshot::Sender<Result<HostBuf<i32>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::copy_from_host::<i32>(src, dst, reply)")]
    CopyFromHostI32 {
        src: HostBuf<i32>,
        dst: GpuRef<i32>,
        reply: oneshot::Sender<Result<HostBuf<i32>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::copy_to_host::<u32>(src, dst, reply)")]
    CopyToHostU32 {
        src: GpuRef<u32>,
        dst: HostBuf<u32>,
        reply: oneshot::Sender<Result<HostBuf<u32>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::copy_from_host::<u32>(src, dst, reply)")]
    CopyFromHostU32 {
        src: HostBuf<u32>,
        dst: GpuRef<u32>,
        reply: oneshot::Sender<Result<HostBuf<u32>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::copy_to_host::<u8>(src, dst, reply)")]
    CopyToHostU8 {
        src: GpuRef<u8>,
        dst: HostBuf<u8>,
        reply: oneshot::Sender<Result<HostBuf<u8>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::copy_from_host::<u8>(src, dst, reply)")]
    CopyFromHostU8 {
        src: HostBuf<u8>,
        dst: GpuRef<u8>,
        reply: oneshot::Sender<Result<HostBuf<u8>, GpuError>>,
    },

    /// SGEMM request. Sent straight to the BLAS child as `BlasMsg::Sgemm` when
    /// kernel children exist, otherwise queued until `ContextReady`.
    Sgemm(Box<SgemmRequest>),

    /// Replies immediately with the context currently stored in the shared
    /// `DeviceState` (`None` if there is none).
    SnapshotContext {
        reply: oneshot::Sender<Option<Arc<cudarc::driver::CudaContext>>>,
    },

    /// Forwarded to the context actor as `ContextMsg::SnapshotStream`. Replies
    /// `None` directly if the context actor was never spawned.
    SnapshotStream {
        reply: oneshot::Sender<Option<Arc<cudarc::driver::CudaStream>>>,
    },

    /// Replies with a clone of the current kernel children: `None` before
    /// `ContextReady` and after `ContextLost`.
    SnapshotChildren {
        reply: oneshot::Sender<Option<KernelChildren>>,
    },

    /// Replies with a receiver obtained from `DeviceState::generation_watch`.
    WatchGeneration {
        reply: oneshot::Sender<tokio::sync::watch::Receiver<u64>>,
    },

    /// Replies with a `DeviceLoad` snapshot built by `snapshot_load`.
    Stats { reply: oneshot::Sender<DeviceLoad> },

    /// Installs the kernel children and replays all pending work.
    /// NOTE(review): presumably sent by the child `ContextActor`, which is
    /// given this actor's ref in `pre_start` — confirm there.
    ContextReady { children: KernelChildren },
    /// Clears the kernel children; subsequent alloc/copy/SGEMM work is queued
    /// again until the next `ContextReady`.
    ContextLost,
}
331
/// Handles to the kernel-library actors attached to one device context.
///
/// The `extras` registry sits behind an `Arc`, so every clone of a
/// `KernelChildren` shares it: an extra registered through one clone is
/// visible through all of them.
#[derive(Clone)]
pub struct KernelChildren {
    /// BLAS actor; always present.
    pub blas: ActorRef<BlasMsg>,
    /// cuDNN actor, if one was attached (`None` after `KernelChildren::new`).
    #[cfg(feature = "cudnn")]
    pub cudnn: Option<ActorRef<crate::kernel::CudnnMsg>>,
    /// FFT actor, if one was attached (`None` after `KernelChildren::new`).
    #[cfg(feature = "cufft")]
    pub fft: Option<ActorRef<crate::kernel::FftMsg>>,
    /// RNG actor, if one was attached (`None` after `KernelChildren::new`).
    #[cfg(feature = "curand")]
    pub rng: Option<ActorRef<crate::kernel::RngMsg>>,
    /// Solver actor, if one was attached (`None` after `KernelChildren::new`).
    #[cfg(feature = "cusolver")]
    pub solver: Option<ActorRef<crate::kernel::SolverMsg>>,
    /// NVRTC actor, if one was attached (`None` after `KernelChildren::new`).
    #[cfg(feature = "nvrtc")]
    pub nvrtc: Option<ActorRef<crate::kernel::NvrtcMsg>>,
    /// Type-keyed registry of additional handles (at most one value per type);
    /// accessed through `register_extra` / `extra` / `extras_len`.
    extras: Arc<RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
}
362
363impl KernelChildren {
364 pub fn new(blas: ActorRef<BlasMsg>) -> Self {
368 Self {
369 blas,
370 #[cfg(feature = "cudnn")]
371 cudnn: None,
372 #[cfg(feature = "cufft")]
373 fft: None,
374 #[cfg(feature = "curand")]
375 rng: None,
376 #[cfg(feature = "cusolver")]
377 solver: None,
378 #[cfg(feature = "nvrtc")]
379 nvrtc: None,
380 extras: Arc::new(RwLock::new(HashMap::new())),
381 }
382 }
383
384 pub fn register_extra<T: Any + Send + Sync>(&self, value: T) {
393 let mut g = self.extras.write();
394 g.insert(TypeId::of::<T>(), Arc::new(value));
395 }
396
397 pub fn extra<T: Any + Send + Sync + Clone>(&self) -> Option<T> {
401 let g = self.extras.read();
402 g.get(&TypeId::of::<T>())
403 .and_then(|v| v.clone().downcast::<T>().ok())
404 .map(|arc| (*arc).clone())
405 }
406
407 pub fn extras_len(&self) -> usize {
409 self.extras.read().len()
410 }
411}
412
413impl DeviceMsg {
414 pub fn alloc<T: CudaDtype>(
417 len: usize,
418 reply: oneshot::Sender<Result<GpuRef<T>, GpuError>>,
419 ) -> Self {
420 DeviceMsg::Alloc(Box::new(AllocReq::<T> { len, reply }))
421 }
422
423 pub fn copy_to_host<T: CudaDtype>(
425 src: GpuRef<T>,
426 dst: HostBuf<T>,
427 reply: oneshot::Sender<Result<HostBuf<T>, GpuError>>,
428 ) -> Self {
429 DeviceMsg::CopyToHost(Box::new(CopyToHostReq::<T> { src, dst, reply }))
430 }
431
432 pub fn copy_from_host<T: CudaDtype>(
434 src: HostBuf<T>,
435 dst: GpuRef<T>,
436 reply: oneshot::Sender<Result<HostBuf<T>, GpuError>>,
437 ) -> Self {
438 DeviceMsg::CopyFromHost(Box::new(CopyFromHostReq::<T> { src, dst, reply }))
439 }
440}
441
/// One single-precision GEMM request, delivered to the BLAS child as
/// `BlasMsg::Sgemm`.
///
/// NOTE(review): presumably standard SGEMM semantics,
/// `C = alpha * A * B + beta * C` with `A: m x k`, `B: k x n`, `C: m x n`;
/// layout and transposition are decided by the BLAS actor — confirm there.
pub struct SgemmRequest {
    /// Operand `A`.
    pub a: GpuRef<f32>,
    /// Operand `B`.
    pub b: GpuRef<f32>,
    /// Operand / output `C`.
    pub c: GpuRef<f32>,
    pub m: i32,
    pub n: i32,
    pub k: i32,
    pub alpha: f32,
    pub beta: f32,
    /// Completion channel. The device actor itself answers with `Err` when the
    /// pending queue is full or the actor stops while the request is queued.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
}
456
/// Work parked in `DeviceActor::pending` until the device becomes ready.
pub enum WorkRequest {
    /// Deferred action, invoked with the context-actor ref and the BLAS ref
    /// when the queue is drained.
    Boxed(Box<dyn FnOnce(&ActorRef<ContextMsg>, &ActorRef<BlasMsg>) + Send>),
    /// SGEMM forwarded to the BLAS child as `BlasMsg::Sgemm` on drain.
    Sgemm(Box<SgemmRequest>),
    /// Answered with `DeviceState::current_context()` on drain.
    /// NOTE(review): no code in this file enqueues this variant —
    /// `DeviceMsg::SnapshotContext` is answered immediately.
    SnapshotContext {
        reply: oneshot::Sender<Option<Arc<cudarc::driver::CudaContext>>>,
    },
}
473
/// Per-device actor: owns the `ContextActor` child, routes alloc/copy/SGEMM
/// work, and buffers that work while the context is not ready.
pub struct DeviceActor {
    /// Configuration the actor was built from.
    config: DeviceConfig,
    /// Shared device state (current context, generation watch, shutdown flag),
    /// also handed to the `ContextActor` child.
    state: Arc<DeviceState>,
    /// Ref to the `ContextActor` child; set in `pre_start`.
    context_ref: Option<ActorRef<ContextMsg>>,
    /// Kernel-library actors; set on `ContextReady`, cleared on `ContextLost`.
    children: Option<KernelChildren>,
    /// FIFO of parked work, bounded by `config.pending_queue_capacity`.
    pending: VecDeque<WorkRequest>,
}
481
482impl DeviceActor {
483 pub fn new(config: DeviceConfig) -> Self {
484 let state = Arc::new(DeviceState::new(config.device_id));
485 Self {
486 config,
487 state,
488 context_ref: None,
489 children: None,
490 pending: VecDeque::new(),
491 }
492 }
493
494 pub fn props(config: DeviceConfig) -> Props<Self> {
496 let cfg = config.clone();
497 Props::create(move || DeviceActor::new(cfg.clone()))
498 }
499
500 pub fn state(&self) -> &Arc<DeviceState> {
503 &self.state
504 }
505
506 fn enqueue_pending(&mut self, work: WorkRequest) {
507 if self.pending.len() >= self.config.pending_queue_capacity {
508 warn!(
509 device_id = self.config.device_id,
510 cap = self.config.pending_queue_capacity,
511 "dropping work — pending queue full"
512 );
513 match work {
517 WorkRequest::Sgemm(req) => {
518 let _ = req.reply.send(Err(GpuError::Unrecoverable(
519 "device pending queue full".into(),
520 )));
521 }
522 WorkRequest::SnapshotContext { reply } => {
523 let _ = reply.send(None);
524 }
525 WorkRequest::Boxed(_) => { }
526 }
527 return;
528 }
529 self.pending.push_back(work);
530 }
531
532 fn drain_pending(&mut self) {
533 let Some(children) = self.children.clone() else {
534 return;
535 };
536 let Some(ctx) = self.context_ref.clone() else {
537 return;
538 };
539 while let Some(work) = self.pending.pop_front() {
540 match work {
541 WorkRequest::Boxed(f) => f(&ctx, &children.blas),
542 WorkRequest::Sgemm(req) => {
543 children.blas.tell(BlasMsg::Sgemm(req));
544 }
545 WorkRequest::SnapshotContext { reply } => {
546 let _ = reply.send(self.state.current_context());
549 }
550 }
551 }
552 }
553}
554
555#[async_trait]
556impl Actor for DeviceActor {
557 type Msg = DeviceMsg;
558
559 async fn pre_start(&mut self, ctx: &mut Context<Self>) {
560 debug!(device_id = self.config.device_id, "DeviceActor pre_start");
561 let parent_ref = ctx.self_ref().clone();
562 let props = ContextActor::props(self.state.clone(), self.config.clone(), parent_ref);
563 match ctx.spawn::<ContextActor>(props, "ctx") {
564 Ok(r) => {
565 self.context_ref = Some(r);
566 }
567 Err(e) => {
568 panic!("Unrecoverable: failed to spawn ContextActor: {e}");
571 }
572 }
573 }
574
575 async fn handle(&mut self, _ctx: &mut Context<Self>, msg: DeviceMsg) {
576 #[allow(deprecated)]
582 let msg = match msg {
583 DeviceMsg::Allocate { len, reply } | DeviceMsg::AllocateF32 { len, reply } => {
585 DeviceMsg::alloc::<f32>(len, reply)
586 }
587 DeviceMsg::AllocateF64 { len, reply } => DeviceMsg::alloc::<f64>(len, reply),
588 DeviceMsg::AllocateI8 { len, reply } => DeviceMsg::alloc::<i8>(len, reply),
589 DeviceMsg::AllocateI32 { len, reply } => DeviceMsg::alloc::<i32>(len, reply),
590 DeviceMsg::AllocateI64 { len, reply } => DeviceMsg::alloc::<i64>(len, reply),
591 DeviceMsg::AllocateU8 { len, reply } => DeviceMsg::alloc::<u8>(len, reply),
592 DeviceMsg::AllocateU32 { len, reply } => DeviceMsg::alloc::<u32>(len, reply),
593 DeviceMsg::AllocateU64 { len, reply } => DeviceMsg::alloc::<u64>(len, reply),
594 #[cfg(feature = "f16")]
595 DeviceMsg::AllocateF16 { len, reply } => DeviceMsg::alloc::<half::f16>(len, reply),
596 #[cfg(feature = "f16")]
597 DeviceMsg::AllocateBf16 { len, reply } => DeviceMsg::alloc::<half::bf16>(len, reply),
598 DeviceMsg::CopyToHostF32 { src, dst, reply } => {
600 DeviceMsg::copy_to_host::<f32>(src, dst, reply)
601 }
602 DeviceMsg::CopyToHostF64 { src, dst, reply } => {
603 DeviceMsg::copy_to_host::<f64>(src, dst, reply)
604 }
605 DeviceMsg::CopyToHostI32 { src, dst, reply } => {
606 DeviceMsg::copy_to_host::<i32>(src, dst, reply)
607 }
608 DeviceMsg::CopyToHostU32 { src, dst, reply } => {
609 DeviceMsg::copy_to_host::<u32>(src, dst, reply)
610 }
611 DeviceMsg::CopyToHostU8 { src, dst, reply } => {
612 DeviceMsg::copy_to_host::<u8>(src, dst, reply)
613 }
614 DeviceMsg::CopyFromHostF32 { src, dst, reply } => {
616 DeviceMsg::copy_from_host::<f32>(src, dst, reply)
617 }
618 DeviceMsg::CopyFromHostF64 { src, dst, reply } => {
619 DeviceMsg::copy_from_host::<f64>(src, dst, reply)
620 }
621 DeviceMsg::CopyFromHostI32 { src, dst, reply } => {
622 DeviceMsg::copy_from_host::<i32>(src, dst, reply)
623 }
624 DeviceMsg::CopyFromHostU32 { src, dst, reply } => {
625 DeviceMsg::copy_from_host::<u32>(src, dst, reply)
626 }
627 DeviceMsg::CopyFromHostU8 { src, dst, reply } => {
628 DeviceMsg::copy_from_host::<u8>(src, dst, reply)
629 }
630 other => other,
632 };
633
634 let ready = self.context_ref.is_some() && self.children.is_some();
635
636 match msg {
637 DeviceMsg::Alloc(boxed) => {
639 if ready {
640 self.context_ref
641 .as_ref()
642 .unwrap()
643 .tell(ContextMsg::Alloc(boxed));
644 } else {
645 self.enqueue_pending(WorkRequest::Boxed(Box::new(move |c, _b| {
646 c.tell(ContextMsg::Alloc(boxed))
647 })));
648 }
649 }
650 DeviceMsg::CopyToHost(boxed) => {
651 if ready {
652 self.context_ref
653 .as_ref()
654 .unwrap()
655 .tell(ContextMsg::CopyToHost(boxed));
656 } else {
657 self.enqueue_pending(WorkRequest::Boxed(Box::new(move |c, _b| {
658 c.tell(ContextMsg::CopyToHost(boxed))
659 })));
660 }
661 }
662 DeviceMsg::CopyFromHost(boxed) => {
663 if ready {
664 self.context_ref
665 .as_ref()
666 .unwrap()
667 .tell(ContextMsg::CopyFromHost(boxed));
668 } else {
669 self.enqueue_pending(WorkRequest::Boxed(Box::new(move |c, _b| {
670 c.tell(ContextMsg::CopyFromHost(boxed))
671 })));
672 }
673 }
674
675 #[allow(deprecated)]
679 DeviceMsg::Allocate { .. }
680 | DeviceMsg::AllocateF32 { .. }
681 | DeviceMsg::AllocateF64 { .. }
682 | DeviceMsg::AllocateI8 { .. }
683 | DeviceMsg::AllocateI32 { .. }
684 | DeviceMsg::AllocateI64 { .. }
685 | DeviceMsg::AllocateU8 { .. }
686 | DeviceMsg::AllocateU32 { .. }
687 | DeviceMsg::AllocateU64 { .. }
688 | DeviceMsg::CopyToHostF32 { .. }
689 | DeviceMsg::CopyFromHostF32 { .. }
690 | DeviceMsg::CopyToHostF64 { .. }
691 | DeviceMsg::CopyFromHostF64 { .. }
692 | DeviceMsg::CopyToHostI32 { .. }
693 | DeviceMsg::CopyFromHostI32 { .. }
694 | DeviceMsg::CopyToHostU32 { .. }
695 | DeviceMsg::CopyFromHostU32 { .. }
696 | DeviceMsg::CopyToHostU8 { .. }
697 | DeviceMsg::CopyFromHostU8 { .. } => unreachable!(
698 "Phase 0.4 translation collapses all legacy alloc/copy variants \
699 into DeviceMsg::Alloc / CopyToHost / CopyFromHost"
700 ),
701 #[cfg(feature = "f16")]
702 #[allow(deprecated)]
703 DeviceMsg::AllocateF16 { .. } | DeviceMsg::AllocateBf16 { .. } => {
704 unreachable!(
705 "Phase 0.4 translation collapses all legacy alloc/copy variants \
706 into DeviceMsg::Alloc"
707 )
708 }
709
710 DeviceMsg::Sgemm(req) => match &self.children {
711 Some(c) => c.blas.tell(BlasMsg::Sgemm(req)),
712 None => self.enqueue_pending(WorkRequest::Sgemm(req)),
713 },
714
715 DeviceMsg::SnapshotContext { reply } => {
716 let _ = reply.send(self.state.current_context());
717 }
718 DeviceMsg::SnapshotStream { reply } => {
719 if let Some(ctx) = self.context_ref.as_ref() {
727 ctx.tell(ContextMsg::SnapshotStream { reply });
728 } else {
729 let _ = reply.send(None);
730 }
731 }
732 DeviceMsg::SnapshotChildren { reply } => {
733 let _ = reply.send(self.children.clone());
734 }
735 DeviceMsg::WatchGeneration { reply } => {
736 let _ = reply.send(self.state.generation_watch());
737 }
738 DeviceMsg::Stats { reply } => {
739 let _ = reply.send(self.snapshot_load());
740 }
741
742 DeviceMsg::ContextReady { children } => {
743 debug!(device_id = self.config.device_id, "context ready");
744 self.children = Some(children);
745 self.drain_pending();
746 }
747 DeviceMsg::ContextLost => {
748 debug!(device_id = self.config.device_id, "context lost");
749 self.children = None;
750 }
751 }
752 }
753
754 async fn post_stop(&mut self, _ctx: &mut Context<Self>) {
755 debug!(device_id = self.config.device_id, "DeviceActor post_stop");
756 self.state.begin_shutdown();
757 while let Some(work) = self.pending.pop_front() {
759 match work {
760 WorkRequest::Boxed(_) => { }
761 WorkRequest::Sgemm(req) => {
762 let _ = req
763 .reply
764 .send(Err(GpuError::GpuRefStale("device shutting down")));
765 }
766 WorkRequest::SnapshotContext { reply } => {
767 let _ = reply.send(None);
768 }
769 }
770 }
771 }
772}
773
774impl DeviceActor {
775 fn snapshot_load(&self) -> DeviceLoad {
776 DeviceLoad {
777 free_bytes: 0,
778 total_bytes: 0,
779 active_streams: 0,
780 queue_depth: self.pending.len() as u32,
781 compute_cap: (0, 0),
782 }
783 }
784}
785
#[cfg(test)]
// Several tests send deprecated legacy variants (`DeviceMsg::Allocate`,
// `DeviceMsg::AllocateF32`), so deprecation warnings are silenced for the module.
#[allow(deprecated)]
mod tests {
    use super::*;
    use crate::dtype::DType;
    use atomr_config::Config;
    use atomr_core::actor::ActorSystem;
    use std::time::Duration;

    /// Pins every flag's bit position; renumbering any flag fails this test.
    #[test]
    fn enabled_libraries_bit_values_are_stable() {
        assert_eq!(EnabledLibraries::BLAS.bits(), 1 << 0);
        assert_eq!(EnabledLibraries::CUDNN.bits(), 1 << 1);
        assert_eq!(EnabledLibraries::CUFFT.bits(), 1 << 2);
        assert_eq!(EnabledLibraries::CURAND.bits(), 1 << 3);
        assert_eq!(EnabledLibraries::CUSOLVER.bits(), 1 << 4);
        assert_eq!(EnabledLibraries::CUBLASLT.bits(), 1 << 5);
        assert_eq!(EnabledLibraries::NVRTC.bits(), 1 << 6);
        assert_eq!(EnabledLibraries::CUTENSOR.bits(), 1 << 7);
        assert_eq!(EnabledLibraries::CUSPARSE.bits(), 1 << 8);
        assert_eq!(EnabledLibraries::NCCL.bits(), 1 << 9);
        assert_eq!(EnabledLibraries::CUTLASS.bits(), 1 << 10);
        assert_eq!(EnabledLibraries::TENSORRT.bits(), 1 << 11);
        assert_eq!(EnabledLibraries::FLASHATTN.bits(), 1 << 12);
        assert_eq!(EnabledLibraries::CUB_THRUST.bits(), 1 << 13);
        assert_eq!(EnabledLibraries::TELEMETRY.bits(), 1 << 14);
    }

    /// A mixed flag set survives `bits()` -> `from_bits()` unchanged.
    #[test]
    fn enabled_libraries_round_trip_via_bits() {
        let original = EnabledLibraries::BLAS
            | EnabledLibraries::CUTENSOR
            | EnabledLibraries::FLASHATTN
            | EnabledLibraries::TELEMETRY;
        let bits = original.bits();
        let restored =
            EnabledLibraries::from_bits(bits).expect("known bits round-trip through from_bits");
        assert_eq!(original, restored);
        assert!(restored.contains(EnabledLibraries::FLASHATTN));
        assert!(!restored.contains(EnabledLibraries::CUDNN));
    }

    /// `ALL` must contain every individually defined flag.
    #[test]
    fn enabled_libraries_all_contains_every_phase_0_8_bit() {
        let all = EnabledLibraries::ALL;
        for bit in [
            EnabledLibraries::BLAS,
            EnabledLibraries::CUDNN,
            EnabledLibraries::CUFFT,
            EnabledLibraries::CURAND,
            EnabledLibraries::CUSOLVER,
            EnabledLibraries::CUBLASLT,
            EnabledLibraries::NVRTC,
            EnabledLibraries::CUTENSOR,
            EnabledLibraries::CUSPARSE,
            EnabledLibraries::NCCL,
            EnabledLibraries::CUTLASS,
            EnabledLibraries::TENSORRT,
            EnabledLibraries::FLASHATTN,
            EnabledLibraries::CUB_THRUST,
            EnabledLibraries::TELEMETRY,
        ] {
            assert!(all.contains(bit), "ALL missing {bit:?}");
        }
    }

    /// Exercises the registry map type used by `KernelChildren::extras`
    /// directly, so no actor system is needed. Building a real
    /// `KernelChildren` requires an `ActorRef<BlasMsg>`.
    #[test]
    fn kernel_children_extras_register_and_retrieve_by_type() {
        let extras: Arc<RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>> =
            Arc::new(RwLock::new(HashMap::new()));

        // Local stand-ins with the same insert/lookup semantics as
        // `register_extra` / `extra`.
        fn register<T: Any + Send + Sync>(
            map: &Arc<RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
            v: T,
        ) {
            map.write().insert(TypeId::of::<T>(), Arc::new(v));
        }
        fn lookup<T: Any + Send + Sync + Clone>(
            map: &Arc<RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
        ) -> Option<T> {
            map.read()
                .get(&TypeId::of::<T>())
                .and_then(|v| v.clone().downcast::<T>().ok())
                .map(|arc| (*arc).clone())
        }

        #[derive(Clone, PartialEq, Eq, Debug)]
        struct CutlassRef(u32);
        #[derive(Clone, PartialEq, Eq, Debug)]
        struct TensorRtRef(&'static str);

        register(&extras, CutlassRef(7));
        register(&extras, TensorRtRef("trt"));

        assert_eq!(lookup::<CutlassRef>(&extras), Some(CutlassRef(7)));
        assert_eq!(lookup::<TensorRtRef>(&extras), Some(TensorRtRef("trt")));
        // A type that was never registered is absent.
        #[derive(Clone)]
        struct Unknown;
        assert!(lookup::<Unknown>(&extras).is_none());
        assert_eq!(extras.read().len(), 2);

        // Re-registering a type replaces the value rather than adding an entry.
        register(&extras, CutlassRef(99));
        assert_eq!(lookup::<CutlassRef>(&extras), Some(CutlassRef(99)));
        assert_eq!(extras.read().len(), 2);
    }

    /// Extras registered on a snapshotted `KernelChildren` are visible through
    /// its clones.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn kernel_children_extras_via_snapshot() {
        let sys = ActorSystem::create("kc_extras", Config::empty())
            .await
            .unwrap();
        let dev = sys
            .actor_of(DeviceActor::props(DeviceConfig::mock(0)), "dev0")
            .unwrap();

        // Poll up to 50 times, 20 ms apart, until `ContextReady` has installed
        // the children.
        let mut snap: Option<KernelChildren> = None;
        for _ in 0..50 {
            let (tx, rx) = oneshot::channel();
            dev.tell(DeviceMsg::SnapshotChildren { reply: tx });
            if let Ok(Some(c)) = rx.await {
                snap = Some(c);
                break;
            }
            tokio::time::sleep(Duration::from_millis(20)).await;
        }
        let children = snap.expect("KernelChildren snapshot should arrive in mock mode");
        assert_eq!(children.extras_len(), 0);

        #[derive(Clone, Debug, PartialEq, Eq)]
        struct FakeCutlassRef(u64);
        children.register_extra(FakeCutlassRef(42));
        assert_eq!(children.extras_len(), 1);
        assert_eq!(children.extra::<FakeCutlassRef>(), Some(FakeCutlassRef(42)));
        let cloned = children.clone();
        assert_eq!(cloned.extras_len(), 1);
        assert_eq!(cloned.extra::<FakeCutlassRef>(), Some(FakeCutlassRef(42)));

        sys.terminate().await;
    }

    /// An allocation sent right after spawning must still get a reply; it is
    /// not lost while the device may be waiting for `ContextReady`.
    /// NOTE(review): the `Unrecoverable` result reflects what the mock context
    /// returns for allocations — confirm in `ContextActor`.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn pending_work_drains_on_context_ready() {
        let sys = ActorSystem::create("test", Config::empty()).await.unwrap();
        let dev = sys
            .actor_of(DeviceActor::props(DeviceConfig::mock(0)), "dev0")
            .unwrap();

        let (tx, rx) = oneshot::channel();
        dev.tell(DeviceMsg::Allocate { len: 16, reply: tx });
        let res = tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .expect("alloc reply should arrive within timeout")
            .expect("oneshot dropped");
        assert!(matches!(res, Err(GpuError::Unrecoverable(_))));

        sys.terminate().await;
    }

    /// The generic `DeviceMsg::alloc` constructor is routed and answered.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn alloc_dispatch_via_typed_constructor() {
        let sys = ActorSystem::create("test", Config::empty()).await.unwrap();
        let dev = sys
            .actor_of(DeviceActor::props(DeviceConfig::mock(0)), "dev1")
            .unwrap();

        let (tx, rx) = oneshot::channel::<Result<GpuRef<f32>, GpuError>>();
        dev.tell(DeviceMsg::alloc::<f32>(64, tx));
        let res = tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .expect("alloc reply within timeout")
            .expect("oneshot dropped");
        assert!(matches!(res, Err(GpuError::Unrecoverable(_))));

        sys.terminate().await;
    }

    /// A boxed `AllocReq<T>` reports `T`'s dtype and the requested length.
    #[test]
    fn alloc_dispatch_dtype_kind_correct() {
        let (tx, _rx) = oneshot::channel::<Result<GpuRef<f32>, GpuError>>();
        let boxed: Box<dyn AllocDispatch> = Box::new(AllocReq::<f32> { len: 4, reply: tx });
        assert_eq!(boxed.dtype(), DType::F32);
        assert_eq!(boxed.len(), 4);

        let (tx, _rx) = oneshot::channel::<Result<GpuRef<i32>, GpuError>>();
        let boxed: Box<dyn AllocDispatch> = Box::new(AllocReq::<i32> { len: 7, reply: tx });
        assert_eq!(boxed.dtype(), DType::I32);

        let (tx, _rx) = oneshot::channel::<Result<GpuRef<u8>, GpuError>>();
        let boxed: Box<dyn AllocDispatch> = Box::new(AllocReq::<u8> { len: 1, reply: tx });
        assert_eq!(boxed.dtype(), DType::U8);
    }

    /// The deprecated `AllocateF32` variant still goes through translation and
    /// gets a reply.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn deprecated_allocate_f32_still_works() {
        let sys = ActorSystem::create("test", Config::empty()).await.unwrap();
        let dev = sys
            .actor_of(DeviceActor::props(DeviceConfig::mock(0)), "dev2")
            .unwrap();

        let (tx, rx) = oneshot::channel::<Result<GpuRef<f32>, GpuError>>();
        dev.tell(DeviceMsg::AllocateF32 { len: 8, reply: tx });
        let res = tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .expect("alloc reply within timeout")
            .expect("oneshot dropped");
        assert!(matches!(res, Err(GpuError::Unrecoverable(_))));

        sys.terminate().await;
    }

    /// Any `CopyToHostDispatch` implementor can be boxed into
    /// `DeviceMsg::CopyToHost`, and it keeps its dtype.
    #[test]
    fn copy_to_host_typed() {
        struct Stub<T: CudaDtype>(std::marker::PhantomData<T>);
        impl<T: CudaDtype> CopyToHostDispatch for Stub<T> {
            fn dtype(&self) -> DType {
                T::KIND
            }
            fn run(
                self: Box<Self>,
                _stream: Arc<cudarc::driver::CudaStream>,
                _completion: Arc<dyn crate::completion::CompletionStrategy>,
            ) {
            }
        }

        let boxed: Box<dyn CopyToHostDispatch> = Box::new(Stub::<f32>(std::marker::PhantomData));
        assert_eq!(boxed.dtype(), DType::F32);
        let boxed: Box<dyn CopyToHostDispatch> = Box::new(Stub::<i32>(std::marker::PhantomData));
        assert_eq!(boxed.dtype(), DType::I32);

        let msg = DeviceMsg::CopyToHost(Box::new(Stub::<u32>(std::marker::PhantomData)));
        match msg {
            DeviceMsg::CopyToHost(b) => assert_eq!(b.dtype(), DType::U32),
            _ => panic!("expected CopyToHost variant"),
        }
    }

    /// Any `CopyFromHostDispatch` implementor can be boxed into
    /// `DeviceMsg::CopyFromHost`, and it keeps its dtype.
    #[test]
    fn copy_from_host_typed() {
        struct Stub<T: CudaDtype>(std::marker::PhantomData<T>);
        impl<T: CudaDtype> CopyFromHostDispatch for Stub<T> {
            fn dtype(&self) -> DType {
                T::KIND
            }
            fn run(
                self: Box<Self>,
                _stream: Arc<cudarc::driver::CudaStream>,
                _completion: Arc<dyn crate::completion::CompletionStrategy>,
            ) {
            }
        }

        let boxed: Box<dyn CopyFromHostDispatch> = Box::new(Stub::<u8>(std::marker::PhantomData));
        assert_eq!(boxed.dtype(), DType::U8);

        let msg = DeviceMsg::CopyFromHost(Box::new(Stub::<f64>(std::marker::PhantomData)));
        match msg {
            DeviceMsg::CopyFromHost(b) => assert_eq!(b.dtype(), DType::F64),
            _ => panic!("expected CopyFromHost variant"),
        }
    }
}