1use std::sync::Arc;
16
17use async_trait::async_trait;
18use atomr_core::actor::{Actor, ActorRef, Context, Props};
19use cudarc::driver::DeviceRepr;
20use cudarc::driver::ValidAsZeroBits;
21use tokio::sync::oneshot;
22use tracing::{debug, info, warn};
23
24use crate::completion::{CompletionStrategy, HostFnCompletion};
25use crate::error::{device_supervisor_strategy, GpuError, CONTEXT_POISONED_TAG};
26use crate::gpu_ref::GpuRef;
27use crate::kernel::{envelope, BlasActor};
28use crate::stream::{PerActorAllocator, StreamAllocator};
29
30use super::alloc_dispatch::{AllocDispatch, CopyFromHostDispatch, CopyToHostDispatch};
31use super::alloc_msg::HostBuf;
32use super::device_actor::{DeviceConfig, DeviceMsg, EnabledLibraries, KernelChildren};
33use super::state::DeviceState;
34
35pub enum ContextMsg {
36 Init,
39
40 SnapshotStream {
46 reply: oneshot::Sender<Option<Arc<cudarc::driver::CudaStream>>>,
47 },
48
49 Alloc(Box<dyn AllocDispatch>),
56 CopyToHost(Box<dyn CopyToHostDispatch>),
58 CopyFromHost(Box<dyn CopyFromHostDispatch>),
60
61 AllocateF32 {
67 len: usize,
68 reply: oneshot::Sender<Result<GpuRef<f32>, GpuError>>,
69 },
70 AllocateF64 {
71 len: usize,
72 reply: oneshot::Sender<Result<GpuRef<f64>, GpuError>>,
73 },
74 AllocateI8 {
75 len: usize,
76 reply: oneshot::Sender<Result<GpuRef<i8>, GpuError>>,
77 },
78 AllocateI32 {
79 len: usize,
80 reply: oneshot::Sender<Result<GpuRef<i32>, GpuError>>,
81 },
82 AllocateI64 {
83 len: usize,
84 reply: oneshot::Sender<Result<GpuRef<i64>, GpuError>>,
85 },
86 AllocateU8 {
87 len: usize,
88 reply: oneshot::Sender<Result<GpuRef<u8>, GpuError>>,
89 },
90 AllocateU32 {
91 len: usize,
92 reply: oneshot::Sender<Result<GpuRef<u32>, GpuError>>,
93 },
94 AllocateU64 {
95 len: usize,
96 reply: oneshot::Sender<Result<GpuRef<u64>, GpuError>>,
97 },
98 #[cfg(feature = "f16")]
99 AllocateF16 {
100 len: usize,
101 reply: oneshot::Sender<Result<GpuRef<half::f16>, GpuError>>,
102 },
103 #[cfg(feature = "f16")]
104 AllocateBf16 {
105 len: usize,
106 reply: oneshot::Sender<Result<GpuRef<half::bf16>, GpuError>>,
107 },
108
109 CopyToHostF32 {
112 src: GpuRef<f32>,
113 dst: HostBuf<f32>,
114 reply: oneshot::Sender<Result<HostBuf<f32>, GpuError>>,
115 },
116 CopyFromHostF32 {
117 src: HostBuf<f32>,
118 dst: GpuRef<f32>,
119 reply: oneshot::Sender<Result<HostBuf<f32>, GpuError>>,
120 },
121 CopyToHostF64 {
122 src: GpuRef<f64>,
123 dst: HostBuf<f64>,
124 reply: oneshot::Sender<Result<HostBuf<f64>, GpuError>>,
125 },
126 CopyFromHostF64 {
127 src: HostBuf<f64>,
128 dst: GpuRef<f64>,
129 reply: oneshot::Sender<Result<HostBuf<f64>, GpuError>>,
130 },
131 CopyToHostI32 {
132 src: GpuRef<i32>,
133 dst: HostBuf<i32>,
134 reply: oneshot::Sender<Result<HostBuf<i32>, GpuError>>,
135 },
136 CopyFromHostI32 {
137 src: HostBuf<i32>,
138 dst: GpuRef<i32>,
139 reply: oneshot::Sender<Result<HostBuf<i32>, GpuError>>,
140 },
141 CopyToHostU32 {
142 src: GpuRef<u32>,
143 dst: HostBuf<u32>,
144 reply: oneshot::Sender<Result<HostBuf<u32>, GpuError>>,
145 },
146 CopyFromHostU32 {
147 src: HostBuf<u32>,
148 dst: GpuRef<u32>,
149 reply: oneshot::Sender<Result<HostBuf<u32>, GpuError>>,
150 },
151 CopyToHostU8 {
152 src: GpuRef<u8>,
153 dst: HostBuf<u8>,
154 reply: oneshot::Sender<Result<HostBuf<u8>, GpuError>>,
155 },
156 CopyFromHostU8 {
157 src: HostBuf<u8>,
158 dst: GpuRef<u8>,
159 reply: oneshot::Sender<Result<HostBuf<u8>, GpuError>>,
160 },
161}
162
/// Actor that owns the per-device context resources and spawns the kernel
/// child actors.
///
/// Instances are produced by the factory built in `ContextActor::props`;
/// `stream`, `allocator` and `children` start out as `None` and are filled
/// in while handling `ContextMsg::Init`.
pub struct ContextActor {
    /// Shared device state: its generation is bumped on every init, and the
    /// context is installed on real init and cleared in `post_stop`.
    state: Arc<DeviceState>,
    /// Device configuration; `device_id`, `mock_mode` and
    /// `enabled_libraries` are read from it.
    config: DeviceConfig,
    /// Parent device actor, notified with `ContextReady` / `ContextLost`.
    parent: ActorRef<DeviceMsg>,
    /// Stream used for allocations and copies. Set during real
    /// initialisation, never set in mock mode, reset to `None` in `post_stop`.
    stream: Option<Arc<cudarc::driver::CudaStream>>,
    /// Stream allocator handed to the kernel children. Set during real
    /// initialisation, reset to `None` in `post_stop`.
    allocator: Option<Arc<dyn StreamAllocator>>,
    /// Completion strategy passed (by clone) to kernel children and copy
    /// helpers.
    completion: Arc<dyn CompletionStrategy>,
    /// Handles to the spawned kernel children, recorded once init finishes.
    children: Option<KernelChildren>,
}
178
179impl ContextActor {
180 pub fn props(
181 state: Arc<DeviceState>,
182 config: DeviceConfig,
183 parent: ActorRef<DeviceMsg>,
184 ) -> Props<Self> {
185 let s = state.clone();
186 let c = config.clone();
187 let p = parent.clone();
188 let completion: Arc<dyn CompletionStrategy> = Arc::new(HostFnCompletion::new());
189 Props::create(move || ContextActor {
190 state: s.clone(),
191 config: c.clone(),
192 parent: p.clone(),
193 stream: None,
194 allocator: None,
195 completion: completion.clone(),
196 children: None,
197 })
198 .with_supervisor_strategy(device_supervisor_strategy())
199 }
200
201 async fn run_init(&mut self, ctx: &mut Context<Self>) {
204 let device_id = self.config.device_id;
205
206 if self.config.mock_mode {
207 self.state.bump_generation();
208 let stub = ctx
209 .spawn::<BlasActor>(BlasActor::mock_props(), "blas")
210 .unwrap_or_else(|e| panic!("Unrecoverable: spawn mock BlasActor: {e}"));
211 #[allow(unused_mut)]
212 let mut children = KernelChildren::new(stub);
213 #[cfg(feature = "cusolver")]
214 {
215 if self
216 .config
217 .enabled_libraries
218 .contains(EnabledLibraries::CUSOLVER)
219 {
220 let solver_stub = ctx
221 .spawn::<crate::kernel::SolverActor>(
222 crate::kernel::SolverActor::mock_props(),
223 "solver",
224 )
225 .unwrap_or_else(|e| panic!("Unrecoverable: spawn mock SolverActor: {e}"));
226 children.solver = Some(solver_stub);
227 }
228 }
229 #[cfg(feature = "nvrtc")]
230 {
231 if self
232 .config
233 .enabled_libraries
234 .contains(EnabledLibraries::NVRTC)
235 {
236 let nvrtc_stub = ctx
237 .spawn::<crate::kernel::NvrtcActor>(
238 crate::kernel::NvrtcActor::mock_props(),
239 "nvrtc",
240 )
241 .unwrap_or_else(|e| panic!("Unrecoverable: spawn mock NvrtcActor: {e}"));
242 children.nvrtc = Some(nvrtc_stub);
243 }
244 }
245 self.children = Some(children.clone());
246 self.parent.tell(DeviceMsg::ContextReady { children });
247 info!(device_id, "ContextActor (mock) ready");
248 return;
249 }
250
251 let cuda_ctx = match cudarc::driver::CudaContext::new(device_id as usize) {
252 Ok(c) => c,
253 Err(e) => {
254 panic!("{CONTEXT_POISONED_TAG}: CudaContext::new({device_id}) failed: {e}");
255 }
256 };
257 let stream = match cuda_ctx.new_stream() {
258 Ok(s) => s,
259 Err(e) => {
260 panic!("{CONTEXT_POISONED_TAG}: new_stream failed on device {device_id}: {e}");
261 }
262 };
263
264 self.state.bump_generation();
265 self.state.install_context(cuda_ctx.clone());
266 self.stream = Some(stream.clone());
267
268 let allocator: Arc<dyn StreamAllocator> =
271 Arc::new(PerActorAllocator::with_context(cuda_ctx.clone()));
272 self.allocator = Some(allocator.clone());
273
274 let libs = self.config.enabled_libraries;
275
276 let blas_stream = if libs.contains(EnabledLibraries::BLAS) {
278 allocator.acquire(Default::default())
279 } else {
280 stream.clone()
281 };
282 let blas_props = BlasActor::props(
283 blas_stream.clone(),
284 allocator.clone(),
285 self.completion.clone(),
286 self.state.clone(),
287 );
288 let _ = blas_props; let blas_alloc = crate::stream::PerActorAllocator::new(blas_stream.clone());
300 let blas_props = BlasActor::props_legacy(
301 blas_stream.clone(),
302 blas_alloc,
303 HostFnCompletion::new(),
304 self.state.clone(),
305 );
306 let blas_ref = ctx
307 .spawn::<BlasActor>(blas_props, "blas")
308 .unwrap_or_else(|e| panic!("Unrecoverable: spawn BlasActor: {e}"));
309
310 #[cfg(feature = "cudnn")]
311 let cudnn_ref = if libs.contains(EnabledLibraries::CUDNN) {
312 let s = allocator.acquire(Default::default());
313 let props = crate::kernel::CudnnActor::props(
314 s,
315 allocator.clone(),
316 self.completion.clone(),
317 self.state.clone(),
318 );
319 Some(
320 ctx.spawn::<crate::kernel::CudnnActor>(props, "cudnn")
321 .unwrap_or_else(|e| panic!("Unrecoverable: spawn CudnnActor: {e}")),
322 )
323 } else {
324 None
325 };
326
327 #[cfg(feature = "cufft")]
328 let fft_ref = if libs.contains(EnabledLibraries::CUFFT) {
329 let s = allocator.acquire(Default::default());
330 let props = crate::kernel::FftActor::props(
331 s,
332 allocator.clone(),
333 self.completion.clone(),
334 self.state.clone(),
335 cuda_ctx.clone(),
336 );
337 Some(
338 ctx.spawn::<crate::kernel::FftActor>(props, "fft")
339 .unwrap_or_else(|e| panic!("Unrecoverable: spawn FftActor: {e}")),
340 )
341 } else {
342 None
343 };
344
345 #[cfg(feature = "curand")]
346 let rng_ref = if libs.contains(EnabledLibraries::CURAND) {
347 let s = allocator.acquire(Default::default());
348 let props = crate::kernel::RngActor::props(
349 s,
350 allocator.clone(),
351 self.completion.clone(),
352 self.state.clone(),
353 0,
354 );
355 Some(
356 ctx.spawn::<crate::kernel::RngActor>(props, "rng")
357 .unwrap_or_else(|e| panic!("Unrecoverable: spawn RngActor: {e}")),
358 )
359 } else {
360 None
361 };
362
363 #[cfg(feature = "cusolver")]
364 let solver_ref = if libs.contains(EnabledLibraries::CUSOLVER) {
365 let s = allocator.acquire(Default::default());
366 let props = crate::kernel::SolverActor::props(
367 s,
368 allocator.clone(),
369 self.completion.clone(),
370 self.state.clone(),
371 );
372 Some(
373 ctx.spawn::<crate::kernel::SolverActor>(props, "solver")
374 .unwrap_or_else(|e| panic!("Unrecoverable: spawn SolverActor: {e}")),
375 )
376 } else {
377 None
378 };
379
380 #[cfg(feature = "nvrtc")]
381 let nvrtc_ref = if libs.contains(EnabledLibraries::NVRTC) {
382 let s = allocator.acquire(Default::default());
383 let props = crate::kernel::NvrtcActor::props(
384 s,
385 allocator.clone(),
386 self.completion.clone(),
387 self.state.clone(),
388 cuda_ctx.clone(),
389 );
390 Some(
391 ctx.spawn::<crate::kernel::NvrtcActor>(props, "nvrtc")
392 .unwrap_or_else(|e| panic!("Unrecoverable: spawn NvrtcActor: {e}")),
393 )
394 } else {
395 None
396 };
397
398 #[allow(unused_mut)]
399 let mut children = KernelChildren::new(blas_ref);
400 #[cfg(feature = "cudnn")]
401 {
402 children.cudnn = cudnn_ref;
403 }
404 #[cfg(feature = "cufft")]
405 {
406 children.fft = fft_ref;
407 }
408 #[cfg(feature = "curand")]
409 {
410 children.rng = rng_ref;
411 }
412 #[cfg(feature = "cusolver")]
413 {
414 children.solver = solver_ref;
415 }
416 #[cfg(feature = "nvrtc")]
417 {
418 children.nvrtc = nvrtc_ref;
419 }
420 self.children = Some(children.clone());
421 self.parent.tell(DeviceMsg::ContextReady { children });
422 info!(
423 device_id,
424 generation = self.state.generation(),
425 "ContextActor ready"
426 );
427 }
428
429 fn alloc<T: DeviceRepr + ValidAsZeroBits>(&self, len: usize) -> Result<GpuRef<T>, GpuError> {
432 if self.config.mock_mode {
433 return Err(GpuError::Unrecoverable(
434 "alloc not supported in mock mode".into(),
435 ));
436 }
437 let Some(stream) = self.stream.clone() else {
438 return Err(GpuError::GpuRefStale("context not ready"));
439 };
440 match stream.alloc_zeros::<T>(len) {
441 Ok(slice) => Ok(GpuRef::<T>::new(Arc::new(slice), &self.state)),
442 Err(e) => Err(GpuError::OutOfMemory(format!("alloc {len}: {e}"))),
443 }
444 }
445}
446
447pub(super) fn run_copy_to_host<T: DeviceRepr + Send + 'static>(
450 src: GpuRef<T>,
451 mut dst: HostBuf<T>,
452 stream: Arc<cudarc::driver::CudaStream>,
453 completion: Arc<dyn CompletionStrategy>,
454 reply: oneshot::Sender<Result<HostBuf<T>, GpuError>>,
455) {
456 let src_slice = match src.access() {
457 Ok(s) => s.clone(),
458 Err(e) => {
459 let _ = reply.send(Err(e));
460 return;
461 }
462 };
463 if src_slice.len() != dst.len() {
464 let _ = reply.send(Err(GpuError::Unrecoverable(format!(
465 "memcpy len mismatch: src={} dst={}",
466 src_slice.len(),
467 dst.len()
468 ))));
469 return;
470 }
471
472 let res = match &mut dst {
475 HostBuf::Owned(v) => stream.memcpy_dtoh(&*src_slice, v.as_mut_slice()),
476 HostBuf::Pinned(p) => stream.memcpy_dtoh(&*src_slice, p.as_mut_slice()),
477 };
478 if let Err(e) = res {
479 let _ = reply.send(Err(GpuError::LibraryError {
480 lib: "driver",
481 msg: format!("memcpy_dtoh: {e}"),
482 }));
483 return;
484 }
485
486 envelope::run_kernel("driver", &stream, &completion, dst, reply, move || {
488 Ok::<_, GpuError>((src_slice,))
489 });
490}
491
492pub(super) fn run_copy_from_host<T: DeviceRepr + Send + 'static>(
493 src: HostBuf<T>,
494 dst: GpuRef<T>,
495 stream: Arc<cudarc::driver::CudaStream>,
496 completion: Arc<dyn CompletionStrategy>,
497 reply: oneshot::Sender<Result<HostBuf<T>, GpuError>>,
498) {
499 let dst_slice = match dst.access() {
500 Ok(s) => s.clone(),
501 Err(e) => {
502 let _ = reply.send(Err(e));
503 return;
504 }
505 };
506 if dst_slice.len() != src.len() {
507 let _ = reply.send(Err(GpuError::Unrecoverable(format!(
508 "memcpy len mismatch: src={} dst={}",
509 src.len(),
510 dst_slice.len()
511 ))));
512 return;
513 }
514 let mut dst_owned = match Arc::try_unwrap(dst_slice) {
515 Ok(s) => s,
516 Err(_) => {
517 let _ = reply.send(Err(GpuError::Unrecoverable(
518 "H2D destination has multiple live references".into(),
519 )));
520 return;
521 }
522 };
523 let res = match &src {
524 HostBuf::Owned(v) => stream.memcpy_htod(v.as_slice(), &mut dst_owned),
525 HostBuf::Pinned(p) => stream.memcpy_htod(p.as_slice(), &mut dst_owned),
526 };
527 if let Err(e) = res {
528 let _ = reply.send(Err(GpuError::LibraryError {
529 lib: "driver",
530 msg: format!("memcpy_htod: {e}"),
531 }));
532 return;
533 }
534 dst.record_write(&stream);
535 envelope::run_kernel("driver", &stream, &completion, src, reply, move || {
536 Ok::<_, GpuError>((dst_owned,))
537 });
538}
539
540#[async_trait]
541impl Actor for ContextActor {
542 type Msg = ContextMsg;
543
544 async fn pre_start(&mut self, ctx: &mut Context<Self>) {
545 ctx.self_ref().tell(ContextMsg::Init);
546 }
547
548 async fn handle(&mut self, ctx: &mut Context<Self>, msg: ContextMsg) {
549 match msg {
550 ContextMsg::Init => self.run_init(ctx).await,
551
552 ContextMsg::SnapshotStream { reply } => {
553 let _ = reply.send(self.stream.clone());
554 }
555
556 ContextMsg::Alloc(boxed) => {
558 boxed.run(self.stream.as_ref(), &self.state, self.config.mock_mode);
559 }
560 ContextMsg::CopyToHost(boxed) => {
561 let stream = self.stream.clone().expect("ctx not ready");
562 boxed.run(stream, self.completion.clone());
563 }
564 ContextMsg::CopyFromHost(boxed) => {
565 let stream = self.stream.clone().expect("ctx not ready");
566 boxed.run(stream, self.completion.clone());
567 }
568
569 ContextMsg::AllocateF32 { len, reply } => {
570 let _ = reply.send(self.alloc::<f32>(len));
571 }
572 ContextMsg::AllocateF64 { len, reply } => {
573 let _ = reply.send(self.alloc::<f64>(len));
574 }
575 ContextMsg::AllocateI8 { len, reply } => {
576 let _ = reply.send(self.alloc::<i8>(len));
577 }
578 ContextMsg::AllocateI32 { len, reply } => {
579 let _ = reply.send(self.alloc::<i32>(len));
580 }
581 ContextMsg::AllocateI64 { len, reply } => {
582 let _ = reply.send(self.alloc::<i64>(len));
583 }
584 ContextMsg::AllocateU8 { len, reply } => {
585 let _ = reply.send(self.alloc::<u8>(len));
586 }
587 ContextMsg::AllocateU32 { len, reply } => {
588 let _ = reply.send(self.alloc::<u32>(len));
589 }
590 ContextMsg::AllocateU64 { len, reply } => {
591 let _ = reply.send(self.alloc::<u64>(len));
592 }
593 #[cfg(feature = "f16")]
594 ContextMsg::AllocateF16 { len, reply } => {
595 let _ = reply.send(self.alloc::<half::f16>(len));
596 }
597 #[cfg(feature = "f16")]
598 ContextMsg::AllocateBf16 { len, reply } => {
599 let _ = reply.send(self.alloc::<half::bf16>(len));
600 }
601
602 ContextMsg::CopyToHostF32 { src, dst, reply } => {
603 let stream = self.stream.clone().expect("ctx not ready");
604 run_copy_to_host(src, dst, stream, self.completion.clone(), reply);
605 }
606 ContextMsg::CopyFromHostF32 { src, dst, reply } => {
607 let stream = self.stream.clone().expect("ctx not ready");
608 run_copy_from_host(src, dst, stream, self.completion.clone(), reply);
609 }
610 ContextMsg::CopyToHostF64 { src, dst, reply } => {
611 let stream = self.stream.clone().expect("ctx not ready");
612 run_copy_to_host(src, dst, stream, self.completion.clone(), reply);
613 }
614 ContextMsg::CopyFromHostF64 { src, dst, reply } => {
615 let stream = self.stream.clone().expect("ctx not ready");
616 run_copy_from_host(src, dst, stream, self.completion.clone(), reply);
617 }
618 ContextMsg::CopyToHostI32 { src, dst, reply } => {
619 let stream = self.stream.clone().expect("ctx not ready");
620 run_copy_to_host(src, dst, stream, self.completion.clone(), reply);
621 }
622 ContextMsg::CopyFromHostI32 { src, dst, reply } => {
623 let stream = self.stream.clone().expect("ctx not ready");
624 run_copy_from_host(src, dst, stream, self.completion.clone(), reply);
625 }
626 ContextMsg::CopyToHostU32 { src, dst, reply } => {
627 let stream = self.stream.clone().expect("ctx not ready");
628 run_copy_to_host(src, dst, stream, self.completion.clone(), reply);
629 }
630 ContextMsg::CopyFromHostU32 { src, dst, reply } => {
631 let stream = self.stream.clone().expect("ctx not ready");
632 run_copy_from_host(src, dst, stream, self.completion.clone(), reply);
633 }
634 ContextMsg::CopyToHostU8 { src, dst, reply } => {
635 let stream = self.stream.clone().expect("ctx not ready");
636 run_copy_to_host(src, dst, stream, self.completion.clone(), reply);
637 }
638 ContextMsg::CopyFromHostU8 { src, dst, reply } => {
639 let stream = self.stream.clone().expect("ctx not ready");
640 run_copy_from_host(src, dst, stream, self.completion.clone(), reply);
641 }
642 }
643 }
644
645 async fn post_restart(&mut self, ctx: &mut Context<Self>, err: &str) {
646 warn!(device_id = self.config.device_id, %err, "ContextActor post_restart");
647 self.parent.tell(DeviceMsg::ContextLost);
648 ctx.self_ref().tell(ContextMsg::Init);
649 }
650
651 async fn post_stop(&mut self, _ctx: &mut Context<Self>) {
652 debug!(device_id = self.config.device_id, "ContextActor post_stop");
653 self.stream = None;
654 self.allocator = None;
655 self.children = None;
656 self.state.clear_context();
657 self.parent.tell(DeviceMsg::ContextLost);
658 }
659}