atomr_accel_cuda/graph/conditional.rs
1//! Conditional graph nodes (`cudaGraphConditionalNode`, CUDA 12.4+).
2//!
3//! Gated behind the `graphs-conditional` Cargo feature. Even with the
4//! feature on, the [`GraphActor`] runtime probe disables the path on
5//! older drivers — `IfNodeDescriptor::record` returns
6//! `Unrecoverable("conditional graphs unsupported")` if the
7//! `cuGraphConditionalHandleCreate` symbol can't be resolved.
8//!
9//! Public surface:
10//! - [`ConditionalKind`] — `If` or `While`.
11//! - [`IfNodeDescriptor`] / [`WhileNodeDescriptor`] — typed
12//! descriptors carrying the inner graph that will be replayed when
13//! the predicate is non-zero.
14
15#![cfg(feature = "graphs-conditional")]
16
17use std::sync::Arc;
18
19use cudarc::driver::sys as driver_sys;
20use cudarc::driver::CudaContext;
21
22use crate::error::GpuError;
23
/// Library tag written into the `lib` field of every
/// `GpuError::LibraryError` produced by `check` in this module.
const LIB: &str = "graph";
25
/// Which flavour of conditional node to build. Each variant corresponds
/// to one of the driver's `CU_GRAPH_COND_TYPE_*` constants.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConditionalKind {
    /// Runs the body graph no more than once, and only when the
    /// conditional handle holds a non-zero value.
    If,
    /// Keeps re-running the body graph for as long as the conditional
    /// handle holds a non-zero value. Terminating the loop is the body
    /// graph's job: it has to clear the handle itself.
    While,
}
37
38impl ConditionalKind {
39 fn raw(self) -> driver_sys::CUgraphConditionalNodeType {
40 match self {
41 ConditionalKind::If => driver_sys::CUgraphConditionalNodeType::CU_GRAPH_COND_TYPE_IF,
42 ConditionalKind::While => {
43 driver_sys::CUgraphConditionalNodeType::CU_GRAPH_COND_TYPE_WHILE
44 }
45 }
46 }
47}
48
/// Descriptor for an `If`-style conditional node.
///
/// The inner (body) graph is allocated by CUDA through an out-parameter
/// when the node is created; callers receive its handle back so they can
/// populate it before the executable graph is instantiated.
///
/// Derives `Debug`/`PartialEq`/`Eq`/`Default` so descriptors can be
/// logged, compared, and default-constructed like other public types.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct IfNodeDescriptor {
    /// Initial value for the node's conditional handle.
    /// NOTE(review): presumably forwarded as `defaultLaunchValue` to
    /// `cuGraphConditionalHandleCreate`; confirm at the call site.
    pub default_value: u32,
}
57
/// Descriptor for a `While`-style conditional node.
///
/// Derives `Debug`/`PartialEq`/`Eq`/`Default` so descriptors can be
/// logged, compared, and default-constructed like other public types.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct WhileNodeDescriptor {
    /// Initial value for the node's conditional handle.
    /// NOTE(review): presumably forwarded as `defaultLaunchValue` to
    /// `cuGraphConditionalHandleCreate`; confirm at the call site.
    pub default_value: u32,
}
63
64/// Build the raw `CUDA_CONDITIONAL_NODE_PARAMS` for `kind` against
65/// `parent`'s context. The returned struct embeds an out-pointer
66/// (`phGraph_out`) that CUDA fills with the inner graph; the actor
67/// adds it to the parent via `cuGraphAddNode_v2`.
68pub fn build_params(
69 kind: ConditionalKind,
70 handle: driver_sys::CUgraphConditionalHandle,
71 ctx: &Arc<CudaContext>,
72 inner_graph_out: *mut driver_sys::CUgraph,
73) -> driver_sys::CUDA_CONDITIONAL_NODE_PARAMS {
74 driver_sys::CUDA_CONDITIONAL_NODE_PARAMS {
75 handle,
76 type_: kind.raw(),
77 size: 1,
78 phGraph_out: inner_graph_out,
79 ctx: ctx.cu_ctx(),
80 }
81}
82
83/// Probe whether the running CUDA driver supports conditional graphs.
84/// Returns `Ok(true)` if `cuGraphConditionalHandleCreate` is callable
85/// (i.e. CUDA ≥ 12.4 with the symbol present); `Ok(false)` otherwise;
86/// `Err(GpuError::Unrecoverable)` if the loader panics.
87pub fn driver_supports_conditional() -> Result<bool, GpuError> {
88 // We probe by attempting to call the symbol with bogus args; CUDA
89 // returns CUDA_ERROR_INVALID_VALUE if the symbol is loadable but
90 // rejects the args, and CUDA_ERROR_NOT_SUPPORTED on older drivers.
91 let probe = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
92 let mut h: driver_sys::CUgraphConditionalHandle = 0;
93 // SAFETY: out-pointer; the other args are intentional null/0
94 // probes — CUDA validates and returns an error code.
95 let s = unsafe {
96 driver_sys::cuGraphConditionalHandleCreate(
97 &mut h as *mut _,
98 std::ptr::null_mut(),
99 std::ptr::null_mut(),
100 0,
101 0,
102 )
103 };
104 s
105 }));
106 match probe {
107 Ok(s) => match s {
108 driver_sys::cudaError_enum::CUDA_ERROR_NOT_SUPPORTED => Ok(false),
109 // Anything else (success, INVALID_VALUE, INVALID_CONTEXT)
110 // means the symbol is at least linked.
111 _ => Ok(true),
112 },
113 Err(_) => Err(GpuError::Unrecoverable(
114 "conditional probe: CUDA driver not loadable".into(),
115 )),
116 }
117}
118
119/// Helper: lift a `CUresult` into our error taxonomy.
120pub(crate) fn check(s: driver_sys::CUresult, op: &str) -> Result<(), GpuError> {
121 if s == driver_sys::cudaError_enum::CUDA_SUCCESS {
122 Ok(())
123 } else {
124 Err(GpuError::LibraryError {
125 lib: LIB,
126 msg: format!("{op}: {s:?}"),
127 })
128 }
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn if_node_descriptor_compiles() {
        let d = IfNodeDescriptor { default_value: 1 };
        assert_eq!(d.default_value, 1);
        assert_eq!(ConditionalKind::If, ConditionalKind::If);
        assert_ne!(ConditionalKind::If, ConditionalKind::While);
        // raw() must map the variants to distinct CUDA constants; assert
        // it instead of merely calling it.
        assert_ne!(ConditionalKind::If.raw(), ConditionalKind::While.raw());
    }

    #[test]
    fn raw_maps_to_matching_cuda_constant() {
        assert_eq!(
            ConditionalKind::If.raw(),
            driver_sys::CUgraphConditionalNodeType::CU_GRAPH_COND_TYPE_IF
        );
        assert_eq!(
            ConditionalKind::While.raw(),
            driver_sys::CUgraphConditionalNodeType::CU_GRAPH_COND_TYPE_WHILE
        );
    }

    #[test]
    fn while_node_descriptor_compiles() {
        let d = WhileNodeDescriptor { default_value: 0 };
        assert_eq!(d.default_value, 0);
    }

    #[test]
    fn check_passes_success_through() {
        assert!(check(driver_sys::cudaError_enum::CUDA_SUCCESS, "cuTest").is_ok());
    }

    #[test]
    fn check_tags_failures_with_lib_and_op() {
        let r = check(
            driver_sys::cudaError_enum::CUDA_ERROR_INVALID_VALUE,
            "cuTest",
        );
        match r {
            Err(GpuError::LibraryError { lib, msg, .. }) => {
                assert_eq!(lib, LIB);
                assert!(msg.starts_with("cuTest: "), "unexpected message: {msg}");
                assert!(
                    msg.contains("CUDA_ERROR_INVALID_VALUE"),
                    "unexpected message: {msg}"
                );
            }
            other => panic!("unexpected: {other:?}"),
        }
    }

    #[test]
    fn driver_probe_returns_typed_result() {
        // If the driver library cannot be loaded, the probe reports
        // Unrecoverable; otherwise it returns Ok(true|false). Any loader
        // panic must be caught inside the probe, never escape to here.
        let r = driver_supports_conditional();
        match r {
            Ok(_) => {}
            Err(GpuError::Unrecoverable(_)) => {}
            other => panic!("unexpected: {other:?}"),
        }
    }
}