atomr_accel_cuda/graph/
child.rs1use std::sync::Arc;
9
10use cudarc::driver::sys as driver_sys;
11
12use crate::error::GpuError;
13use crate::graph::{GraphHandle, GraphOpRecord, GraphRecordCtx};
14
/// Library tag used as the `lib` field of the `GpuError::LibraryError` values built in this file.
const LIB: &str = "graph";
16
17pub struct ChildGraphOp {
19 pub child: GraphHandle,
20}
21
22impl GraphOpRecord for ChildGraphOp {
23 fn record(&self, ctx: &GraphRecordCtx<'_>) -> Result<(), GpuError> {
24 let parent = ctx.parent_graph();
27 let cu_child = self.child.cu_graph();
28 let mut node: driver_sys::CUgraphNode = std::ptr::null_mut();
29 let s = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| unsafe {
30 driver_sys::cuGraphAddChildGraphNode(
31 &mut node as *mut _,
32 parent,
33 std::ptr::null(),
34 0,
35 cu_child,
36 )
37 }));
38 match s {
39 Ok(s) => {
40 if s == driver_sys::cudaError_enum::CUDA_SUCCESS {
41 Ok(())
42 } else {
43 Err(GpuError::LibraryError {
44 lib: LIB,
45 msg: format!("cuGraphAddChildGraphNode: {s:?}"),
46 })
47 }
48 }
49 Err(_) => Err(GpuError::Unrecoverable(
50 "ChildGraphOp::record: CUDA driver not loadable".into(),
51 )),
52 }
53 }
54}
55
56pub fn child_graph_op(child: GraphHandle) -> ChildGraphOp {
58 ChildGraphOp { child }
59}
60
61pub struct ChildGraphInsertion {
64 pub op: ChildGraphOp,
65 pub keep_alive: Arc<()>,
66}
67
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::{GraphHandle, MockGraphRecordCtx};

    /// Records a `ChildGraphOp` against a mock context whose parent graph is null.
    ///
    /// Only checks that `record` returns one of its designed outcomes: success, a
    /// driver error code (`LibraryError`), or the driver-not-loadable path
    /// (`Unrecoverable`). This lets the test pass on machines with or without a GPU.
    /// NOTE(review): with a null parent graph, `Ok(())` looks unreachable whenever the
    /// driver loads. Consider asserting an error once the mock's behavior is confirmed.
    #[test]
    fn child_graph_op_records_into_parent() {
        let child = GraphHandle::synthetic_for_tests();
        let op = child_graph_op(child);
        let parent_graph: driver_sys::CUgraph = std::ptr::null_mut();
        let mock = MockGraphRecordCtx::new(parent_graph);
        let ctx: GraphRecordCtx<'_> = mock.as_ctx();

        match op.record(&ctx) {
            Ok(()) | Err(GpuError::Unrecoverable(_)) | Err(GpuError::LibraryError { .. }) => {}
            other => panic!("unexpected: {other:?}"),
        }
    }
}