Skip to main content

atomr_accel_cuda/graph/
exec_update.rs

//! `cuGraphExecUpdate` — in-place parameter update for an instantiated
//! graph. Lets callers re-bind `GpuRef` pointers without re-running
//! `cuGraphInstantiate`.
//!
//! The driver's raw update result is mapped to a
//! [`GraphExecUpdateOutcome`], which tells the caller whether the update
//! was applied as-is or whether the topology changed and the graph must
//! be re-instantiated.
8
9use cudarc::driver::sys as driver_sys;
10
11use crate::error::GpuError;
12use crate::graph::GraphHandle;
13
/// Library tag placed in the `lib` field of `GpuError::LibraryError`
/// values produced by this module.
const LIB: &str = "graph";
15
/// Coarse classification of a graph-exec update attempt, derived from
/// the driver's `CUgraphExecUpdateResult` value.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GraphExecUpdateOutcome {
    /// The update was applied to the instantiated graph.
    Success,
    /// The new graph's nodes do not line up with the instantiated one;
    /// the caller has to rebuild (re-instantiate) from scratch.
    TopologyChanged,
    /// Any other failure, as classified by the driver.
    Other,
}
27
28impl From<driver_sys::CUgraphExecUpdateResult> for GraphExecUpdateOutcome {
29    fn from(r: driver_sys::CUgraphExecUpdateResult) -> Self {
30        // CU_GRAPH_EXEC_UPDATE_SUCCESS = 0
31        match r as u32 {
32            0 => GraphExecUpdateOutcome::Success,
33            // CUDA_GRAPH_EXEC_UPDATE_ERROR_TOPOLOGY_CHANGED = 2 (older drivers)
34            // CUDA_GRAPH_EXEC_UPDATE_ERROR_NODE_TYPE_CHANGED = 3
35            // We classify everything except success as TopologyChanged
36            // for the conservative "must reinstantiate" path; the Other
37            // bucket is reserved for future granularity.
38            2..=8 => GraphExecUpdateOutcome::TopologyChanged,
39            _ => GraphExecUpdateOutcome::Other,
40        }
41    }
42}
43
44/// Try to apply `new_graph`'s parameters to `exec`'s instantiated
45/// state. Wraps `cuGraphExecUpdate_v2` (CUDA 12+); returns an
46/// `Unrecoverable` on hosts where the symbol isn't loadable.
47pub fn exec_update(
48    exec: &GraphHandle,
49    new_graph_cu: driver_sys::CUgraph,
50) -> Result<GraphExecUpdateOutcome, GpuError> {
51    let exec_handle = exec.cu_graph_exec();
52    let probe = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
53        let mut info = driver_sys::CUgraphExecUpdateResultInfo_st {
54            result: driver_sys::CUgraphExecUpdateResult_enum::CU_GRAPH_EXEC_UPDATE_SUCCESS,
55            errorNode: std::ptr::null_mut(),
56            errorFromNode: std::ptr::null_mut(),
57        };
58        // SAFETY: exec_handle and new_graph_cu are caller-owned;
59        // info is a local out-pointer.
60        let s = unsafe {
61            driver_sys::cuGraphExecUpdate_v2(exec_handle, new_graph_cu, &mut info as *mut _)
62        };
63        (s, info.result)
64    }));
65    match probe {
66        Ok((s, result)) => {
67            if s == driver_sys::cudaError_enum::CUDA_SUCCESS {
68                Ok(GraphExecUpdateOutcome::from(result))
69            } else if s == driver_sys::cudaError_enum::CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE {
70                // Treat as a topology-class error — the caller should
71                // reinstantiate.
72                Ok(GraphExecUpdateOutcome::TopologyChanged)
73            } else {
74                Err(GpuError::LibraryError {
75                    lib: LIB,
76                    msg: format!("cuGraphExecUpdate_v2: {s:?}"),
77                })
78            }
79        }
80        Err(_) => Err(GpuError::Unrecoverable(
81            "exec_update: CUDA driver not loadable".into(),
82        )),
83    }
84}
85
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks one representative driver value per outcome bucket of the
    /// `From<CUgraphExecUpdateResult>` conversion.
    #[test]
    fn outcome_classification_round_trip() {
        // Named driver variants are used instead of transmuting raw
        // integers, so the test needs no `unsafe`.
        use driver_sys::CUgraphExecUpdateResult_enum::*;

        // 0 -> Success.
        assert_eq!(
            GraphExecUpdateOutcome::from(CU_GRAPH_EXEC_UPDATE_SUCCESS),
            GraphExecUpdateOutcome::Success
        );
        // 2 falls in the 2..=8 "must re-instantiate" bucket.
        assert_eq!(
            GraphExecUpdateOutcome::from(CU_GRAPH_EXEC_UPDATE_ERROR_TOPOLOGY_CHANGED),
            GraphExecUpdateOutcome::TopologyChanged
        );
        // 1 (the driver's generic update error) is outside 2..=8 and
        // therefore lands in the Other bucket.
        assert_eq!(
            GraphExecUpdateOutcome::from(CU_GRAPH_EXEC_UPDATE_ERROR),
            GraphExecUpdateOutcome::Other
        );
    }

    /// Smoke test: calling `exec_update` with a synthetic handle and a
    /// null graph must return a value rather than panic.
    #[test]
    fn param_rebind_round_trip() {
        // Mock-mode: with a synthetic GraphHandle and null new graph,
        // the call surfaces Unrecoverable on no-GPU hosts and a
        // LibraryError on real ones. No panic.
        let exec = GraphHandle::synthetic_for_tests();
        let r = exec_update(&exec, std::ptr::null_mut());
        match r {
            Ok(_) => {}
            Err(GpuError::Unrecoverable(_)) => {}
            Err(GpuError::LibraryError { .. }) => {}
            other => panic!("unexpected: {other:?}"),
        }
    }
}