1use std::sync::Arc;
27
28use async_trait::async_trait;
29use atomr_core::actor::{Actor, Context, Props};
30use cudarc::cublas::CudaBlas;
31use cudarc::driver::sys as driver_sys;
32use cudarc::driver::CudaGraph;
33use parking_lot::Mutex;
34use tokio::sync::oneshot;
35
36use crate::completion::CompletionStrategy;
37use crate::device::DeviceState;
38use crate::error::GpuError;
39
40pub mod record;
41
42#[cfg(feature = "cufft")]
43pub use record::fft_r2c::FftR2COp;
44pub use record::memcpy::MemcpyOp;
45#[cfg(feature = "curand")]
46pub use record::rng_fill_uniform::RngFillUniformOp;
47pub use record::sgemm::SgemmOp;
48
49pub mod child;
50#[cfg(feature = "graphs-conditional")]
51pub mod conditional;
52pub mod dot;
53pub mod exec_update;
54
55pub use child::ChildGraphOp;
56pub use dot::{export_dot, DotFlags};
57pub use exec_update::{exec_update, GraphExecUpdateOutcome};
58
/// Label placed in the `lib` field of the `GpuError::LibraryError` values
/// built in this file.
const LIB: &str = "graph";
60
61#[doc(hidden)]
74pub struct MockGraphRecordCtx {
75 parent_graph: driver_sys::CUgraph,
76 stream: Option<Arc<cudarc::driver::CudaStream>>,
77}
78
79impl MockGraphRecordCtx {
80 pub fn new(parent_graph: driver_sys::CUgraph) -> Self {
81 Self {
82 parent_graph,
83 stream: None,
84 }
85 }
86
87 pub fn with_stream(mut self, stream: Arc<cudarc::driver::CudaStream>) -> Self {
88 self.stream = Some(stream);
89 self
90 }
91
92 pub fn parent_graph(&self) -> driver_sys::CUgraph {
93 self.parent_graph
94 }
95
96 pub fn stream(&self) -> Option<&Arc<cudarc::driver::CudaStream>> {
97 self.stream.as_ref()
98 }
99
100 pub fn as_ctx(&self) -> GraphRecordCtx<'_> {
102 GraphRecordCtx {
103 stream: self.stream.as_ref(),
104 blas: None,
105 #[cfg(feature = "curand")]
106 rng: None,
107 #[cfg(feature = "cufft")]
108 fft: None,
109 parent_graph: Some(self.parent_graph),
110 }
111 }
112}
113
/// An operation that can record itself given shared (`&`) access to both
/// itself and a [`GraphRecordCtx`].
pub trait GraphOpRecord {
    /// Record this operation using the handles available in `ctx`.
    ///
    /// Returns a `GpuError` if the operation cannot be recorded.
    fn record(&self, ctx: &GraphRecordCtx<'_>) -> Result<(), GpuError>;
}
120
121pub struct SendGraph(Arc<CudaGraph>);
125unsafe impl Send for SendGraph {}
126unsafe impl Sync for SendGraph {}
127
128impl Clone for SendGraph {
129 fn clone(&self) -> Self {
130 Self(self.0.clone())
131 }
132}
133
134#[derive(Clone)]
135pub struct GraphHandle {
136 graph: Option<SendGraph>,
137 generation: u64,
138 #[doc(hidden)]
142 synthetic_cu_graph: driver_sys::CUgraph,
143 #[doc(hidden)]
144 synthetic_cu_graph_exec: driver_sys::CUgraphExec,
145}
146
147unsafe impl Send for GraphHandle {}
153unsafe impl Sync for GraphHandle {}
154
155impl GraphHandle {
156 pub fn from_graph(graph: Arc<CudaGraph>, state: &Arc<DeviceState>) -> Self {
159 Self {
160 graph: Some(SendGraph(graph)),
161 generation: state.generation(),
162 synthetic_cu_graph: std::ptr::null_mut(),
163 synthetic_cu_graph_exec: std::ptr::null_mut(),
164 }
165 }
166
167 pub fn generation(&self) -> u64 {
168 self.generation
169 }
170
171 pub fn cu_graph(&self) -> driver_sys::CUgraph {
179 if let Some(g) = self.graph.as_ref() {
180 g.0.cu_graph()
181 } else {
182 self.synthetic_cu_graph
183 }
184 }
185
186 pub fn cu_graph_exec(&self) -> driver_sys::CUgraphExec {
192 if let Some(g) = self.graph.as_ref() {
193 g.0.cu_graph_exec()
194 } else {
195 self.synthetic_cu_graph_exec
196 }
197 }
198
199 #[doc(hidden)]
204 pub fn synthetic_for_tests() -> Self {
205 Self {
206 graph: None,
207 generation: 0,
208 synthetic_cu_graph: std::ptr::null_mut(),
209 synthetic_cu_graph_exec: std::ptr::null_mut(),
210 }
211 }
212}
213
214pub struct GraphRecordCtx<'a> {
227 pub stream: Option<&'a Arc<cudarc::driver::CudaStream>>,
232 pub blas: Option<&'a CudaBlas>,
235 #[cfg(feature = "curand")]
237 pub rng: Option<&'a cudarc::curand::CudaRng>,
238 #[cfg(feature = "cufft")]
240 pub fft: Option<&'a cudarc::cufft::CudaFft>,
241 pub parent_graph: Option<driver_sys::CUgraph>,
245}
246
247impl<'a> GraphRecordCtx<'a> {
248 pub fn require_stream(&self) -> Result<&'a Arc<cudarc::driver::CudaStream>, GpuError> {
251 self.stream.ok_or_else(|| {
252 GpuError::Unrecoverable("GraphRecordCtx: no captured stream available".into())
253 })
254 }
255
256 pub fn parent_graph(&self) -> driver_sys::CUgraph {
261 self.parent_graph.unwrap_or(std::ptr::null_mut())
262 }
263}
264
/// A single recordable step of a graph script.
///
/// Implementors must be `Send + 'static` so boxed ops can be carried inside
/// a [`GraphMsg::Record`] message.
pub trait GraphOp: Send + 'static {
    /// Record this op using the handles in `ctx`.
    ///
    /// Returning an error aborts the recording that invoked the op.
    fn record(&mut self, ctx: &mut GraphRecordCtx<'_>) -> Result<(), GpuError>;

    /// Static name of this op; `"graph_op"` unless overridden.
    fn op_name(&self) -> &'static str {
        "graph_op"
    }
}
284
285impl GraphOp for Box<dyn GraphOp> {
286 fn record(&mut self, ctx: &mut GraphRecordCtx<'_>) -> Result<(), GpuError> {
287 (**self).record(ctx)
288 }
289 fn op_name(&self) -> &'static str {
290 (**self).op_name()
291 }
292}
293
294#[deprecated(
299 since = "0.1.0",
300 note = "construct individual `impl GraphOp` types (e.g. `SgemmOp`, `MemcpyOp`) and \
301 push them as `Box<dyn GraphOp>` instead of using the closed enum"
302)]
303#[allow(deprecated)]
304pub enum GraphOpLegacy {
305 Sgemm(Box<SgemmOp>),
306 Memcpy(Box<MemcpyOp>),
308 #[cfg(feature = "curand")]
310 RngFillUniform(Box<RngFillUniformOp>),
311 #[cfg(feature = "cufft")]
315 FftR2C(Box<FftR2COp>),
316}
317
318#[allow(deprecated)]
319impl GraphOp for GraphOpLegacy {
320 fn record(&mut self, ctx: &mut GraphRecordCtx<'_>) -> Result<(), GpuError> {
321 match self {
322 GraphOpLegacy::Sgemm(b) => b.record(ctx),
323 GraphOpLegacy::Memcpy(m) => m.record(ctx),
324 #[cfg(feature = "curand")]
325 GraphOpLegacy::RngFillUniform(r) => r.record(ctx),
326 #[cfg(feature = "cufft")]
327 GraphOpLegacy::FftR2C(r) => r.record(ctx),
328 }
329 }
330
331 fn op_name(&self) -> &'static str {
332 match self {
333 GraphOpLegacy::Sgemm(b) => b.op_name(),
334 GraphOpLegacy::Memcpy(m) => m.op_name(),
335 #[cfg(feature = "curand")]
336 GraphOpLegacy::RngFillUniform(r) => r.op_name(),
337 #[cfg(feature = "cufft")]
338 GraphOpLegacy::FftR2C(r) => r.op_name(),
339 }
340 }
341}
342
/// Messages handled by [`GraphActor`].
pub enum GraphMsg {
    /// Record `script` into a graph and reply with the resulting handle, or
    /// with the error that stopped the recording.
    Record {
        /// Ops to record, in order.
        script: Vec<Box<dyn GraphOp>>,
        /// Receives the outcome of the recording.
        reply: oneshot::Sender<Result<GraphHandle, GpuError>>,
    },
    /// Launch a previously recorded graph. On a successful launch the reply
    /// carries the result of awaiting completion; otherwise it carries the
    /// error that prevented the launch.
    Launch {
        /// Handle of the graph to launch.
        handle: GraphHandle,
        /// Receives the outcome of the launch.
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
    /// Install the FFT plan exposed to ops through `GraphRecordCtx::fft` in
    /// subsequent recordings. A mock actor discards the plan and still
    /// acknowledges.
    #[cfg(feature = "cufft")]
    SetFftPlan {
        /// Plan to install.
        plan: cudarc::cufft::CudaFft,
        /// Signalled once the message has been handled.
        reply: oneshot::Sender<()>,
    },
}
362
/// Wrapper asserting that a `CudaBlas` handle may be sent to and shared
/// between threads. In this file it is only stored behind a `Mutex`.
struct SendBlas(CudaBlas);
// SAFETY: NOTE(review) — justification not recorded here; confirm that the
// cuBLAS handle is safe to use from whichever thread holds the mutex.
unsafe impl Send for SendBlas {}
unsafe impl Sync for SendBlas {}

/// Wrapper asserting that a `CudaRng` generator may be sent to and shared
/// between threads. In this file it is only stored behind a `Mutex`.
#[cfg(feature = "curand")]
struct SendRng(cudarc::curand::CudaRng);
// SAFETY: NOTE(review) — justification not recorded here; confirm as above.
#[cfg(feature = "curand")]
unsafe impl Send for SendRng {}
#[cfg(feature = "curand")]
unsafe impl Sync for SendRng {}

/// Wrapper asserting that a `CudaFft` plan may be sent to and shared between
/// threads. In this file it is only stored behind a `Mutex`.
#[cfg(feature = "cufft")]
struct SendFft(cudarc::cufft::CudaFft);
// SAFETY: NOTE(review) — justification not recorded here; confirm as above.
#[cfg(feature = "cufft")]
unsafe impl Send for SendFft {}
#[cfg(feature = "cufft")]
unsafe impl Sync for SendFft {}
380
/// Actor that records op scripts into graphs and launches them in response
/// to [`GraphMsg`] messages. Build it with [`GraphActor::props`] or
/// [`GraphActor::mock_props`].
pub struct GraphActor {
    /// Real stream-backed state, or the mock placeholder.
    inner: GraphInner,
}
384
/// Backing state of a [`GraphActor`].
#[allow(dead_code)]
enum GraphInner {
    /// Backed by a real stream.
    Real {
        /// Stream that graphs are captured on and whose completion is awaited
        /// after a launch.
        stream: Arc<cudarc::driver::CudaStream>,
        /// Strategy used to await completion on `stream` after a launch.
        completion: Arc<dyn CompletionStrategy>,
        /// Device state; its generation is stamped on new handles and
        /// compared against a handle's generation before launching.
        state: Arc<DeviceState>,
        /// cuBLAS handle lent to ops while recording; `None` when it could
        /// not be created.
        blas: Option<Mutex<SendBlas>>,
        /// cuRAND generator lent to ops while recording; `None` when it could
        /// not be created.
        #[cfg(feature = "curand")]
        rng: Option<Mutex<SendRng>>,
        /// cuFFT plan lent to ops while recording; starts empty and is filled
        /// by `GraphMsg::SetFftPlan`.
        #[cfg(feature = "cufft")]
        fft: Mutex<Option<SendFft>>,
    },
    /// No GPU resources: `Record` and `Launch` are answered with an error.
    Mock,
}
401
402impl GraphActor {
403 pub fn props(
404 stream: Arc<cudarc::driver::CudaStream>,
405 completion: Arc<dyn CompletionStrategy>,
406 state: Arc<DeviceState>,
407 ) -> Props<Self> {
408 Props::create(move || {
409 let blas = match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
413 CudaBlas::new(stream.clone())
414 })) {
415 Ok(Ok(b)) => Some(Mutex::new(SendBlas(b))),
416 _ => None,
417 };
418 #[cfg(feature = "curand")]
419 let rng = match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
420 cudarc::curand::CudaRng::new(0, stream.clone())
421 })) {
422 Ok(Ok(r)) => Some(Mutex::new(SendRng(r))),
423 _ => None,
424 };
425 GraphActor {
426 inner: GraphInner::Real {
427 stream: stream.clone(),
428 completion: completion.clone(),
429 state: state.clone(),
430 blas,
431 #[cfg(feature = "curand")]
432 rng,
433 #[cfg(feature = "cufft")]
434 fft: Mutex::new(None),
435 },
436 }
437 })
438 }
439
440 pub fn mock_props() -> Props<Self> {
441 Props::create(|| GraphActor {
442 inner: GraphInner::Mock,
443 })
444 }
445}
446
447fn run_record(
448 stream: &Arc<cudarc::driver::CudaStream>,
449 state: &Arc<DeviceState>,
450 blas: &Option<Mutex<SendBlas>>,
451 #[cfg(feature = "curand")] rng: &Option<Mutex<SendRng>>,
452 #[cfg(feature = "cufft")] fft: &Mutex<Option<SendFft>>,
453 mut script: Vec<Box<dyn GraphOp>>,
454) -> Result<GraphHandle, GpuError> {
455 let begin_res = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
457 stream.begin_capture(driver_sys::CUstreamCaptureMode::CU_STREAM_CAPTURE_MODE_GLOBAL)
458 }));
459 match begin_res {
460 Ok(Ok(())) => {}
461 Ok(Err(e)) => {
462 return Err(GpuError::LibraryError {
463 lib: LIB,
464 msg: format!("begin_capture: {e}"),
465 });
466 }
467 Err(_) => {
468 return Err(GpuError::Unrecoverable(
469 "GraphActor::Record: CUDA driver not loadable".into(),
470 ));
471 }
472 }
473
474 let bail = |e: GpuError, stream: &Arc<cudarc::driver::CudaStream>| -> GpuError {
476 let _ = stream.end_capture(
477 driver_sys::CUgraphInstantiate_flags::CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH,
478 );
479 e
480 };
481
482 let blas_guard = blas.as_ref().map(|m| m.lock());
487 #[cfg(feature = "curand")]
488 let rng_guard = rng.as_ref().map(|m| m.lock());
489 #[cfg(feature = "cufft")]
490 let fft_guard = fft.lock();
491
492 let mut ctx = GraphRecordCtx {
493 stream: Some(stream),
494 blas: blas_guard.as_ref().map(|g| &g.0),
495 #[cfg(feature = "curand")]
496 rng: rng_guard.as_ref().map(|g| &g.0),
497 #[cfg(feature = "cufft")]
498 fft: fft_guard.as_ref().map(|g| &g.0),
499 parent_graph: None,
500 };
501
502 for op in script.iter_mut() {
503 if let Err(e) = op.record(&mut ctx) {
504 drop(ctx);
505 #[cfg(feature = "cufft")]
506 drop(fft_guard);
507 #[cfg(feature = "curand")]
508 drop(rng_guard);
509 drop(blas_guard);
510 return Err(bail(e, stream));
511 }
512 }
513
514 drop(ctx);
515 #[cfg(feature = "cufft")]
516 drop(fft_guard);
517 #[cfg(feature = "curand")]
518 drop(rng_guard);
519 drop(blas_guard);
520
521 let end_res = stream.end_capture(
523 driver_sys::CUgraphInstantiate_flags::CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH,
524 );
525 let cuda_graph = match end_res {
526 Ok(Some(g)) => g,
527 Ok(None) => {
528 return Err(GpuError::LibraryError {
529 lib: LIB,
530 msg: "end_capture returned None".into(),
531 });
532 }
533 Err(e) => {
534 return Err(GpuError::LibraryError {
535 lib: LIB,
536 msg: format!("end_capture: {e}"),
537 });
538 }
539 };
540 Ok(GraphHandle::from_graph(Arc::new(cuda_graph), state))
541}
542
543#[async_trait]
544impl Actor for GraphActor {
545 type Msg = GraphMsg;
546
547 async fn handle(&mut self, _ctx: &mut Context<Self>, msg: GraphMsg) {
548 match &self.inner {
549 GraphInner::Mock => match msg {
550 GraphMsg::Record { reply, .. } => {
551 let _ = reply.send(Err(GpuError::Unrecoverable(
552 "GraphActor in mock mode".into(),
553 )));
554 }
555 GraphMsg::Launch { reply, .. } => {
556 let _ = reply.send(Err(GpuError::Unrecoverable(
557 "GraphActor in mock mode".into(),
558 )));
559 }
560 #[cfg(feature = "cufft")]
561 GraphMsg::SetFftPlan { reply, .. } => {
562 let _ = reply.send(());
563 }
564 },
565 GraphInner::Real {
566 stream,
567 completion,
568 state,
569 blas,
570 #[cfg(feature = "curand")]
571 rng,
572 #[cfg(feature = "cufft")]
573 fft,
574 } => match msg {
575 GraphMsg::Record { script, reply } => {
576 let res = run_record(
577 stream,
578 state,
579 blas,
580 #[cfg(feature = "curand")]
581 rng,
582 #[cfg(feature = "cufft")]
583 fft,
584 script,
585 );
586 let _ = reply.send(res);
587 }
588 #[cfg(feature = "cufft")]
589 GraphMsg::SetFftPlan { plan, reply } => {
590 *fft.lock() = Some(SendFft(plan));
591 let _ = reply.send(());
592 }
593 GraphMsg::Launch { handle, reply } => {
594 if handle.generation != state.generation() {
595 let _ = reply.send(Err(GpuError::GpuRefStale(
596 "graph captured against rebuilt context",
597 )));
598 return;
599 }
600 let Some(graph) = handle.graph.as_ref() else {
601 let _ = reply.send(Err(GpuError::Unrecoverable(
602 "GraphActor::Launch: synthetic GraphHandle has no captured graph"
603 .into(),
604 )));
605 return;
606 };
607 let res = graph.0.launch().map_err(|e| GpuError::LibraryError {
608 lib: LIB,
609 msg: format!("launch: {e}"),
610 });
611 if let Err(e) = res {
612 let _ = reply.send(Err(e));
613 return;
614 }
615 let stream = stream.clone();
616 let completion = completion.clone();
617 tokio::spawn(async move {
618 let r = completion.await_completion(&stream).await;
619 let _ = reply.send(r);
620 });
621 }
622 },
623 }
624 }
625}
626
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex as StdMutex;

    /// Op that appends its own name to a shared trace each time it records.
    struct MockOp {
        name: &'static str,
        trace: Arc<StdMutex<Vec<&'static str>>>,
    }

    impl GraphOp for MockOp {
        fn record(&mut self, _ctx: &mut GraphRecordCtx<'_>) -> Result<(), GpuError> {
            let mut log = self.trace.lock().unwrap();
            log.push(self.name);
            Ok(())
        }

        fn op_name(&self) -> &'static str {
            self.name
        }
    }

    /// Op that increments a shared counter each time it records.
    struct CounterOp {
        count: Arc<StdMutex<u32>>,
    }

    impl GraphOp for CounterOp {
        fn record(&mut self, _ctx: &mut GraphRecordCtx<'_>) -> Result<(), GpuError> {
            let mut hits = self.count.lock().unwrap();
            *hits += 1;
            Ok(())
        }

        fn op_name(&self) -> &'static str {
            "counter_op"
        }
    }

    /// Context with no stream, no library handles and no parent graph.
    fn no_gpu_ctx<'a>() -> GraphRecordCtx<'a> {
        GraphRecordCtx {
            stream: None,
            blas: None,
            #[cfg(feature = "curand")]
            rng: None,
            #[cfg(feature = "cufft")]
            fft: None,
            parent_graph: None,
        }
    }

    /// Ops defined outside the built-in set can be boxed into one script,
    /// report their own names, and record in script order.
    #[test]
    fn external_graph_op_impls_can_be_appended_and_recorded() {
        let trace = Arc::new(StdMutex::new(Vec::<&'static str>::new()));
        let count = Arc::new(StdMutex::new(0u32));

        let mut script: Vec<Box<dyn GraphOp>> = vec![
            Box::new(MockOp {
                name: "first_mock",
                trace: Arc::clone(&trace),
            }),
            Box::new(CounterOp {
                count: Arc::clone(&count),
            }),
            Box::new(MockOp {
                name: "second_mock",
                trace: Arc::clone(&trace),
            }),
            Box::new(CounterOp {
                count: Arc::clone(&count),
            }),
        ];

        let expected_names = ["first_mock", "counter_op", "second_mock", "counter_op"];
        assert_eq!(script.len(), expected_names.len());
        for (op, expected) in script.iter().zip(expected_names.iter()) {
            assert_eq!(op.op_name(), *expected);
        }

        let mut ctx = no_gpu_ctx();
        script
            .iter_mut()
            .for_each(|op| op.record(&mut ctx).expect("mock op must record"));

        assert_eq!(
            *trace.lock().unwrap(),
            vec!["first_mock", "second_mock"],
            "MockOp::record should append its name in script order"
        );
        assert_eq!(*count.lock().unwrap(), 2, "CounterOp ran twice");
    }

    /// A context without a stream yields an error rather than panicking.
    #[test]
    fn require_stream_returns_clean_error_in_no_gpu_ctx() {
        let ctx = no_gpu_ctx();
        let outcome = ctx.require_stream();
        assert!(matches!(outcome.unwrap_err(), GpuError::Unrecoverable(_)));
    }

    /// Despite its name, this test does not construct a `GraphOpLegacy`; it
    /// checks that a `Box<dyn GraphOp>` forwards `record` and `op_name` to
    /// the op it contains.
    #[test]
    fn graph_op_legacy_dispatches_to_inner_op() {
        let trace = Arc::new(StdMutex::new(Vec::<&'static str>::new()));

        let mut boxed: Box<dyn GraphOp> = Box::new(MockOp {
            name: "via_box_dyn",
            trace: Arc::clone(&trace),
        });

        let mut ctx = no_gpu_ctx();
        boxed.record(&mut ctx).unwrap();

        assert_eq!(*trace.lock().unwrap(), vec!["via_box_dyn"]);
        assert_eq!(boxed.op_name(), "via_box_dyn");
    }
}