Skip to main content

atomr_accel_cuda/graph/
child.rs

//! Child-graph composition.
//!
//! Wraps `cuGraphAddChildGraphNode` so that an existing `GraphHandle` can be
//! embedded as a single node in a parent graph being recorded. This is how
//! higher-level pipelines compose: a "data load" sub-graph can be reused
//! across many enclosing per-step graphs.

8use std::sync::Arc;
9
10use cudarc::driver::sys as driver_sys;
11
12use crate::error::GpuError;
13use crate::graph::{GraphHandle, GraphOpRecord, GraphRecordCtx};
14
/// Library tag reported in the `lib` field of `GpuError::LibraryError`
/// values produced by this module.
const LIB: &str = "graph";
16
17/// Op variant for embedding a previously-recorded sub-graph.
18pub struct ChildGraphOp {
19    pub child: GraphHandle,
20}
21
22impl GraphOpRecord for ChildGraphOp {
23    fn record(&self, ctx: &GraphRecordCtx<'_>) -> Result<(), GpuError> {
24        // SAFETY: we pull the raw cu_graph from the wrapped CudaGraph;
25        // CUDA owns it and we pass it through to cuGraphAddChildGraphNode.
26        let parent = ctx.parent_graph();
27        let cu_child = self.child.cu_graph();
28        let mut node: driver_sys::CUgraphNode = std::ptr::null_mut();
29        let s = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| unsafe {
30            driver_sys::cuGraphAddChildGraphNode(
31                &mut node as *mut _,
32                parent,
33                std::ptr::null(),
34                0,
35                cu_child,
36            )
37        }));
38        match s {
39            Ok(s) => {
40                if s == driver_sys::cudaError_enum::CUDA_SUCCESS {
41                    Ok(())
42                } else {
43                    Err(GpuError::LibraryError {
44                        lib: LIB,
45                        msg: format!("cuGraphAddChildGraphNode: {s:?}"),
46                    })
47                }
48            }
49            Err(_) => Err(GpuError::Unrecoverable(
50                "ChildGraphOp::record: CUDA driver not loadable".into(),
51            )),
52        }
53    }
54}
55
56/// Convenience: create a child-graph op from a `GraphHandle` clone.
57pub fn child_graph_op(child: GraphHandle) -> ChildGraphOp {
58    ChildGraphOp { child }
59}
60
61/// Helper used by pipeline builders that want to keep a reference to
62/// the inserted child-graph for later parameter rebinding.
63pub struct ChildGraphInsertion {
64    pub op: ChildGraphOp,
65    pub keep_alive: Arc<()>,
66}
67
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::{GraphHandle, MockGraphRecordCtx};
    use std::sync::Arc;

    /// Confirms that `ChildGraphOp::record` routes into
    /// `cuGraphAddChildGraphNode` and reports failure through `GpuError`
    /// instead of panicking out.
    ///
    /// The mock context gives `record` a null parent graph, and the child is
    /// a synthetic test handle rather than a real CUDA graph. Both raw handles
    /// are passed to the real driver entry point.
    /// - Without a loadable driver, the binding panics and `record` maps that
    ///   to `Unrecoverable`.
    /// - With a driver, the call is expected to fail on the null parent and
    ///   `record` returns `LibraryError`.
    ///
    /// This presumably relies on the driver rejecting the null parent before
    /// touching the child; confirm if this test ever crashes on a GPU host. A
    /// null parent can never succeed, so `Ok(())` is treated as a failure.
    #[test]
    fn child_graph_op_records_into_parent() {
        let child = GraphHandle::synthetic_for_tests();
        let op = child_graph_op(child);
        let parent_graph: driver_sys::CUgraph = std::ptr::null_mut();
        let mock = MockGraphRecordCtx::new(parent_graph);
        let ctx: GraphRecordCtx<'_> = mock.as_ctx();

        match op.record(&ctx) {
            Err(GpuError::Unrecoverable(_)) | Err(GpuError::LibraryError { .. }) => {}
            other => panic!(
                "expected Unrecoverable or LibraryError for a null parent graph, got: {other:?}"
            ),
        }

        // Smoke-check the insertion wrapper by actually building one around
        // the op, instead of constructing an unrelated `Arc`.
        let insertion = ChildGraphInsertion {
            op,
            keep_alive: Arc::new(()),
        };
        assert_eq!(Arc::strong_count(&insertion.keep_alive), 1);
    }
}