Skip to main content

atomr_accel_cuda/graph/
conditional.rs

//! Conditional graph nodes (`cudaGraphConditionalNode`, CUDA 12.4+).
//!
//! Gated behind the `graphs-conditional` Cargo feature. Even with the
//! feature enabled, the `GraphActor` runtime probe disables this path on
//! older drivers: `IfNodeDescriptor::record` returns
//! `Unrecoverable("conditional graphs unsupported")` if the
//! `cuGraphConditionalHandleCreate` symbol can't be resolved.
//!
//! Public surface:
//! - [`ConditionalKind`] — `If` or `While`.
//! - [`IfNodeDescriptor`] / [`WhileNodeDescriptor`] — typed descriptors
//!   for the two node kinds. They carry only the conditional handle's
//!   default value; the inner graph that is replayed while the predicate
//!   is non-zero is allocated by CUDA when the node is added.
//! - [`build_params`] — builds the raw `CUDA_CONDITIONAL_NODE_PARAMS`.
//! - [`driver_supports_conditional`] — runtime capability probe.
14
15#![cfg(feature = "graphs-conditional")]
16
17use std::sync::Arc;
18
19use cudarc::driver::sys as driver_sys;
20use cudarc::driver::CudaContext;
21
22use crate::error::GpuError;
23
/// Library tag attached to the `GpuError::LibraryError` values produced
/// by `check` in this module.
const LIB: &str = "graph";
25
/// Flavour of conditional node to create. Each variant corresponds to
/// one `CU_GRAPH_COND_TYPE_*` driver constant.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConditionalKind {
    /// Run the inner graph a single time, and only if the conditional
    /// handle holds a non-zero value.
    If,
    /// Keep re-running the inner graph for as long as the handle holds a
    /// non-zero value. Ending the loop is the inner graph's job: it must
    /// reset the handle once it wants to stop.
    While,
}
37
38impl ConditionalKind {
39    fn raw(self) -> driver_sys::CUgraphConditionalNodeType {
40        match self {
41            ConditionalKind::If => driver_sys::CUgraphConditionalNodeType::CU_GRAPH_COND_TYPE_IF,
42            ConditionalKind::While => {
43                driver_sys::CUgraphConditionalNodeType::CU_GRAPH_COND_TYPE_WHILE
44            }
45        }
46    }
47}
48
/// Descriptor for an `If`-style conditional node
/// ([`ConditionalKind::If`]).
///
/// The descriptor does not hold the inner (body) graph itself. CUDA
/// allocates that graph through an out-parameter when the conditional
/// node is created, and callers receive its handle back so they can
/// populate it before the parent graph is instantiated.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IfNodeDescriptor {
    /// Default value for the node's conditional handle.
    /// NOTE(review): presumably forwarded as the `defaultLaunchValue` of
    /// `cuGraphConditionalHandleCreate`; it is not consumed in this file,
    /// so confirm at the call site.
    pub default_value: u32,
}
57
/// Descriptor for a `While`-style conditional node
/// ([`ConditionalKind::While`]).
///
/// As with the `If` descriptor, the loop body graph is not stored here;
/// CUDA allocates it through an out-parameter when the node is created.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WhileNodeDescriptor {
    /// Default value for the node's conditional handle.
    /// NOTE(review): presumably forwarded as the `defaultLaunchValue` of
    /// `cuGraphConditionalHandleCreate`; it is not consumed in this file,
    /// so confirm at the call site.
    pub default_value: u32,
}
63
64/// Build the raw `CUDA_CONDITIONAL_NODE_PARAMS` for `kind` against
65/// `parent`'s context. The returned struct embeds an out-pointer
66/// (`phGraph_out`) that CUDA fills with the inner graph; the actor
67/// adds it to the parent via `cuGraphAddNode_v2`.
68pub fn build_params(
69    kind: ConditionalKind,
70    handle: driver_sys::CUgraphConditionalHandle,
71    ctx: &Arc<CudaContext>,
72    inner_graph_out: *mut driver_sys::CUgraph,
73) -> driver_sys::CUDA_CONDITIONAL_NODE_PARAMS {
74    driver_sys::CUDA_CONDITIONAL_NODE_PARAMS {
75        handle,
76        type_: kind.raw(),
77        size: 1,
78        phGraph_out: inner_graph_out,
79        ctx: ctx.cu_ctx(),
80    }
81}
82
83/// Probe whether the running CUDA driver supports conditional graphs.
84/// Returns `Ok(true)` if `cuGraphConditionalHandleCreate` is callable
85/// (i.e. CUDA ≥ 12.4 with the symbol present); `Ok(false)` otherwise;
86/// `Err(GpuError::Unrecoverable)` if the loader panics.
87pub fn driver_supports_conditional() -> Result<bool, GpuError> {
88    // We probe by attempting to call the symbol with bogus args; CUDA
89    // returns CUDA_ERROR_INVALID_VALUE if the symbol is loadable but
90    // rejects the args, and CUDA_ERROR_NOT_SUPPORTED on older drivers.
91    let probe = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
92        let mut h: driver_sys::CUgraphConditionalHandle = 0;
93        // SAFETY: out-pointer; the other args are intentional null/0
94        // probes — CUDA validates and returns an error code.
95        let s = unsafe {
96            driver_sys::cuGraphConditionalHandleCreate(
97                &mut h as *mut _,
98                std::ptr::null_mut(),
99                std::ptr::null_mut(),
100                0,
101                0,
102            )
103        };
104        s
105    }));
106    match probe {
107        Ok(s) => match s {
108            driver_sys::cudaError_enum::CUDA_ERROR_NOT_SUPPORTED => Ok(false),
109            // Anything else (success, INVALID_VALUE, INVALID_CONTEXT)
110            // means the symbol is at least linked.
111            _ => Ok(true),
112        },
113        Err(_) => Err(GpuError::Unrecoverable(
114            "conditional probe: CUDA driver not loadable".into(),
115        )),
116    }
117}
118
119/// Helper: lift a `CUresult` into our error taxonomy.
120pub(crate) fn check(s: driver_sys::CUresult, op: &str) -> Result<(), GpuError> {
121    if s == driver_sys::cudaError_enum::CUDA_SUCCESS {
122        Ok(())
123    } else {
124        Err(GpuError::LibraryError {
125            lib: LIB,
126            msg: format!("{op}: {s:?}"),
127        })
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn if_node_descriptor_compiles() {
        let d = IfNodeDescriptor { default_value: 1 };
        assert_eq!(d.default_value, 1);
        assert_eq!(ConditionalKind::If, ConditionalKind::If);
        assert_ne!(ConditionalKind::If, ConditionalKind::While);
        // raw() must map each variant to its own CUDA constant, and the
        // two constants must be distinct.
        assert_eq!(
            ConditionalKind::If.raw(),
            driver_sys::CUgraphConditionalNodeType::CU_GRAPH_COND_TYPE_IF
        );
        assert_eq!(
            ConditionalKind::While.raw(),
            driver_sys::CUgraphConditionalNodeType::CU_GRAPH_COND_TYPE_WHILE
        );
        assert_ne!(ConditionalKind::If.raw(), ConditionalKind::While.raw());
    }

    #[test]
    fn while_node_descriptor_compiles() {
        let d = WhileNodeDescriptor { default_value: 0 };
        assert_eq!(d.default_value, 0);
    }

    #[test]
    fn driver_probe_returns_typed_result() {
        // On no-GPU hosts the probe surfaces Unrecoverable; on real
        // hardware it returns Ok(true|false). Either way, no panic.
        let r = driver_supports_conditional();
        match r {
            Ok(_) => {}
            Err(GpuError::Unrecoverable(_)) => {}
            other => panic!("unexpected: {other:?}"),
        }
    }
}