atomr_accel_cuda/graph/
exec_update.rs1use cudarc::driver::sys as driver_sys;
10
11use crate::error::GpuError;
12use crate::graph::GraphHandle;
13
/// Library tag placed in the `lib` field of the `GpuError::LibraryError`
/// values built in this file.
const LIB: &str = "graph";
15
/// Coarse classification of a graph-exec update attempt.
///
/// Values are produced either by converting a raw
/// `CUgraphExecUpdateResult` or directly by `exec_update`.
#[derive(Clone, Copy, Debug)]
#[derive(PartialEq, Eq)]
pub enum GraphExecUpdateOutcome {
    /// The driver reported result code `0` (update success).
    Success,
    /// The driver reported a result code in `2..=8`, or the update call
    /// itself returned the "graph exec update failure" status.
    TopologyChanged,
    /// Any result code not covered by the two variants above.
    Other,
}
27
28impl From<driver_sys::CUgraphExecUpdateResult> for GraphExecUpdateOutcome {
29 fn from(r: driver_sys::CUgraphExecUpdateResult) -> Self {
30 match r as u32 {
32 0 => GraphExecUpdateOutcome::Success,
33 2..=8 => GraphExecUpdateOutcome::TopologyChanged,
39 _ => GraphExecUpdateOutcome::Other,
40 }
41 }
42}
43
44pub fn exec_update(
48 exec: &GraphHandle,
49 new_graph_cu: driver_sys::CUgraph,
50) -> Result<GraphExecUpdateOutcome, GpuError> {
51 let exec_handle = exec.cu_graph_exec();
52 let probe = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
53 let mut info = driver_sys::CUgraphExecUpdateResultInfo_st {
54 result: driver_sys::CUgraphExecUpdateResult_enum::CU_GRAPH_EXEC_UPDATE_SUCCESS,
55 errorNode: std::ptr::null_mut(),
56 errorFromNode: std::ptr::null_mut(),
57 };
58 let s = unsafe {
61 driver_sys::cuGraphExecUpdate_v2(exec_handle, new_graph_cu, &mut info as *mut _)
62 };
63 (s, info.result)
64 }));
65 match probe {
66 Ok((s, result)) => {
67 if s == driver_sys::cudaError_enum::CUDA_SUCCESS {
68 Ok(GraphExecUpdateOutcome::from(result))
69 } else if s == driver_sys::cudaError_enum::CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE {
70 Ok(GraphExecUpdateOutcome::TopologyChanged)
73 } else {
74 Err(GpuError::LibraryError {
75 lib: LIB,
76 msg: format!("cuGraphExecUpdate_v2: {s:?}"),
77 })
78 }
79 }
80 Err(_) => Err(GpuError::Unrecoverable(
81 "exec_update: CUDA driver not loadable".into(),
82 )),
83 }
84}
85
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the raw-result conversion for the success code and for the
    /// first code in the topology range.
    #[test]
    fn outcome_classification_round_trip() {
        let success: GraphExecUpdateOutcome =
            driver_sys::CUgraphExecUpdateResult_enum::CU_GRAPH_EXEC_UPDATE_SUCCESS.into();
        assert_eq!(success, GraphExecUpdateOutcome::Success);

        // Build the value from its raw number instead of naming a variant.
        // NOTE(review): relies on `2` being a valid discriminant of the
        // driver enum.
        let raw_topology = unsafe {
            std::mem::transmute::<u32, driver_sys::CUgraphExecUpdateResult>(2)
        };
        let topology: GraphExecUpdateOutcome = raw_topology.into();
        assert_eq!(topology, GraphExecUpdateOutcome::TopologyChanged);
    }

    /// Calls `exec_update` with a synthetic handle and a null graph and
    /// accepts any of the three documented result shapes.
    /// NOTE(review): presumably this lets the test pass both with and
    /// without a usable CUDA driver — confirm.
    #[test]
    fn param_rebind_round_trip() {
        let exec = GraphHandle::synthetic_for_tests();
        let r = exec_update(&exec, std::ptr::null_mut());
        assert!(
            matches!(
                r,
                Ok(_) | Err(GpuError::Unrecoverable(_)) | Err(GpuError::LibraryError { .. })
            ),
            "unexpected: {r:?}"
        );
    }
}