Skip to main content

atomr_accel_cuda/sys/
cudnn.rs

1//! Thin Rust safety layer over the cuDNN v9 backend graph FFI.
2//!
3//! cudarc 0.19.4 ships the raw bindings for `cudnnBackendCreateDescriptor`,
4//! `cudnnBackendSetAttribute`, `cudnnBackendFinalize`,
5//! `cudnnBackendExecute`, etc. but no safe wrapper. This module owns
6//! that wrapping so the cuDNN actor's graph builder can call into the
7//! backend API without scattering `unsafe` everywhere.
8//!
9//! Scope: a [`BackendDescriptor`] RAII handle plus typed `set_*` helpers
10//! covering the attribute kinds the cuDNN actor uses (i64 / i64 array
11//! / f32 / f64 / pointer / `BackendDescriptor` / enum).
12//!
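//! Nothing in this module touches cuDNN until a descriptor is created.
//! From then on every operation calls straight into the cuDNN backend
//! FFI: [`BackendDescriptor::create`], `finalize`, each `set_*` helper,
//! `Drop`, and [`backend_execute`]. All of them therefore need a loadable
//! cuDNN runtime, and [`backend_execute`] additionally needs a real
//! `cudnnHandle_t`.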
16
17#![allow(dead_code)]
18
19use std::ffi::c_void;
20
21use cudarc::cudnn::sys as cudnn_sys;
22
23use crate::error::GpuError;
24
/// Library name recorded in the `lib` field of every
/// `GpuError::LibraryError` built by `check` in this module.
const LIB: &str = "cudnn";
26
27fn check(s: cudnn_sys::cudnnStatus_t, what: &'static str) -> Result<(), GpuError> {
28    if s == cudnn_sys::cudnnStatus_t::CUDNN_STATUS_SUCCESS {
29        Ok(())
30    } else {
31        Err(GpuError::LibraryError {
32            lib: LIB,
33            msg: format!("{what}: cudnnStatus={:?}", s),
34        })
35    }
36}
37
38/// RAII wrapper around a `cudnnBackendDescriptor_t` that destroys the
39/// descriptor on drop.
40pub struct BackendDescriptor {
41    raw: cudnn_sys::cudnnBackendDescriptor_t,
42    finalized: bool,
43}
44
45impl std::fmt::Debug for BackendDescriptor {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        f.debug_struct("BackendDescriptor")
48            .field("raw", &self.raw)
49            .field("finalized", &self.finalized)
50            .finish()
51    }
52}
53
54unsafe impl Send for BackendDescriptor {}
55
56impl BackendDescriptor {
57    /// `cudnnBackendCreateDescriptor`.
58    pub fn create(kind: cudnn_sys::cudnnBackendDescriptorType_t) -> Result<Self, GpuError> {
59        let mut raw: cudnn_sys::cudnnBackendDescriptor_t = std::ptr::null_mut();
60        let s = unsafe { cudnn_sys::cudnnBackendCreateDescriptor(kind, &mut raw) };
61        check(s, "cudnnBackendCreateDescriptor")?;
62        Ok(Self {
63            raw,
64            finalized: false,
65        })
66    }
67
68    /// Raw handle (only valid until `Drop`).
69    pub fn as_raw(&self) -> cudnn_sys::cudnnBackendDescriptor_t {
70        self.raw
71    }
72
73    /// `cudnnBackendFinalize`.
74    pub fn finalize(&mut self) -> Result<(), GpuError> {
75        if self.finalized {
76            return Ok(());
77        }
78        let s = unsafe { cudnn_sys::cudnnBackendFinalize(self.raw) };
79        check(s, "cudnnBackendFinalize")?;
80        self.finalized = true;
81        Ok(())
82    }
83
84    /// True iff `finalize()` has succeeded.
85    pub fn is_finalized(&self) -> bool {
86        self.finalized
87    }
88
89    /// Generic `cudnnBackendSetAttribute` — caller supplies the
90    /// element type and a `*const c_void` pointer to a contiguous
91    /// array of `count` elements.
92    ///
93    /// # Safety
94    /// `data` must point to `count` valid elements of the type
95    /// expected by `attr_type`. The descriptor must outlive the call.
96    pub unsafe fn set_attribute_raw(
97        &mut self,
98        name: cudnn_sys::cudnnBackendAttributeName_t,
99        attr_type: cudnn_sys::cudnnBackendAttributeType_t,
100        count: i64,
101        data: *const c_void,
102    ) -> Result<(), GpuError> {
103        let s = unsafe {
104            cudnn_sys::cudnnBackendSetAttribute(
105                self.raw,
106                name,
107                attr_type,
108                count,
109                data as *mut c_void,
110            )
111        };
112        check(s, "cudnnBackendSetAttribute")
113    }
114
115    /// Set a single i64 attribute.
116    pub fn set_i64(
117        &mut self,
118        name: cudnn_sys::cudnnBackendAttributeName_t,
119        value: i64,
120    ) -> Result<(), GpuError> {
121        let v = value;
122        unsafe {
123            self.set_attribute_raw(
124                name,
125                cudnn_sys::cudnnBackendAttributeType_t::CUDNN_TYPE_INT64,
126                1,
127                &v as *const i64 as *const c_void,
128            )
129        }
130    }
131
132    /// Set an array of i64 attributes.
133    pub fn set_i64_array(
134        &mut self,
135        name: cudnn_sys::cudnnBackendAttributeName_t,
136        values: &[i64],
137    ) -> Result<(), GpuError> {
138        unsafe {
139            self.set_attribute_raw(
140                name,
141                cudnn_sys::cudnnBackendAttributeType_t::CUDNN_TYPE_INT64,
142                values.len() as i64,
143                values.as_ptr() as *const c_void,
144            )
145        }
146    }
147
148    /// Set a single f32 attribute (e.g. an alpha scaling parameter).
149    pub fn set_f32(
150        &mut self,
151        name: cudnn_sys::cudnnBackendAttributeName_t,
152        value: f32,
153    ) -> Result<(), GpuError> {
154        let v = value;
155        unsafe {
156            self.set_attribute_raw(
157                name,
158                cudnn_sys::cudnnBackendAttributeType_t::CUDNN_TYPE_FLOAT,
159                1,
160                &v as *const f32 as *const c_void,
161            )
162        }
163    }
164
165    /// Set a single f64 attribute.
166    pub fn set_f64(
167        &mut self,
168        name: cudnn_sys::cudnnBackendAttributeName_t,
169        value: f64,
170    ) -> Result<(), GpuError> {
171        let v = value;
172        unsafe {
173            self.set_attribute_raw(
174                name,
175                cudnn_sys::cudnnBackendAttributeType_t::CUDNN_TYPE_DOUBLE,
176                1,
177                &v as *const f64 as *const c_void,
178            )
179        }
180    }
181
182    /// Set a single device-pointer attribute.
183    pub fn set_dev_ptr(
184        &mut self,
185        name: cudnn_sys::cudnnBackendAttributeName_t,
186        ptr: *mut c_void,
187    ) -> Result<(), GpuError> {
188        let p = ptr;
189        unsafe {
190            self.set_attribute_raw(
191                name,
192                cudnn_sys::cudnnBackendAttributeType_t::CUDNN_TYPE_VOID_PTR,
193                1,
194                &p as *const *mut c_void as *const c_void,
195            )
196        }
197    }
198
199    /// Set a single sub-descriptor reference.
200    pub fn set_descriptor(
201        &mut self,
202        name: cudnn_sys::cudnnBackendAttributeName_t,
203        sub: &BackendDescriptor,
204    ) -> Result<(), GpuError> {
205        let p = sub.raw;
206        unsafe {
207            self.set_attribute_raw(
208                name,
209                cudnn_sys::cudnnBackendAttributeType_t::CUDNN_TYPE_BACKEND_DESCRIPTOR,
210                1,
211                &p as *const _ as *const c_void,
212            )
213        }
214    }
215
216    /// Set an array of sub-descriptor references.
217    pub fn set_descriptors(
218        &mut self,
219        name: cudnn_sys::cudnnBackendAttributeName_t,
220        subs: &[&BackendDescriptor],
221    ) -> Result<(), GpuError> {
222        let raws: Vec<cudnn_sys::cudnnBackendDescriptor_t> = subs.iter().map(|s| s.raw).collect();
223        unsafe {
224            self.set_attribute_raw(
225                name,
226                cudnn_sys::cudnnBackendAttributeType_t::CUDNN_TYPE_BACKEND_DESCRIPTOR,
227                raws.len() as i64,
228                raws.as_ptr() as *const c_void,
229            )
230        }
231    }
232
233    /// Set a `cudnnDataType_t` attribute.
234    pub fn set_data_type(
235        &mut self,
236        name: cudnn_sys::cudnnBackendAttributeName_t,
237        dt: cudnn_sys::cudnnDataType_t,
238    ) -> Result<(), GpuError> {
239        let v = dt;
240        unsafe {
241            self.set_attribute_raw(
242                name,
243                cudnn_sys::cudnnBackendAttributeType_t::CUDNN_TYPE_DATA_TYPE,
244                1,
245                &v as *const _ as *const c_void,
246            )
247        }
248    }
249}
250
251impl Drop for BackendDescriptor {
252    fn drop(&mut self) {
253        if !self.raw.is_null() {
254            // Best-effort destroy. Ignore status — there is nothing
255            // sensible to do on failure during drop.
256            unsafe {
257                let _ = cudnn_sys::cudnnBackendDestroyDescriptor(self.raw);
258            }
259            self.raw = std::ptr::null_mut();
260        }
261    }
262}
263
264/// Run a finalised execution plan against a finalised variant pack.
265///
266/// # Safety
267/// Both descriptors must be finalised. `handle` must be valid for the
268/// stream the plan was built against.
269pub unsafe fn backend_execute(
270    handle: cudnn_sys::cudnnHandle_t,
271    plan: &BackendDescriptor,
272    variant_pack: &BackendDescriptor,
273) -> Result<(), GpuError> {
274    let s = unsafe { cudnn_sys::cudnnBackendExecute(handle, plan.raw, variant_pack.raw) };
275    check(s, "cudnnBackendExecute")
276}
277
#[cfg(test)]
mod tests {
    // Intentionally empty. Every `BackendDescriptor` operation here (create,
    // set_*, finalize, drop, execute) calls into cuDNN, so exercising them
    // needs a loaded cuDNN runtime. The round-trip path is meant to be
    // covered by the cuDNN actor's tests, under host builds, via the spec
    // layer in `kernel::cudnn::graph`.
    //
    // TODO(review): `check` performs no FFI call and could be unit-tested
    // here without a cuDNN runtime.
}