1#![allow(dead_code)]
18
19use std::ffi::c_void;
20
21use cudarc::cudnn::sys as cudnn_sys;
22
23use crate::error::GpuError;
24
/// Library name recorded in the `lib` field of `GpuError::LibraryError` by `check`.
const LIB: &str = "cudnn";
26
27fn check(s: cudnn_sys::cudnnStatus_t, what: &'static str) -> Result<(), GpuError> {
28 if s == cudnn_sys::cudnnStatus_t::CUDNN_STATUS_SUCCESS {
29 Ok(())
30 } else {
31 Err(GpuError::LibraryError {
32 lib: LIB,
33 msg: format!("{what}: cudnnStatus={:?}", s),
34 })
35 }
36}
37
38pub struct BackendDescriptor {
41 raw: cudnn_sys::cudnnBackendDescriptor_t,
42 finalized: bool,
43}
44
45impl std::fmt::Debug for BackendDescriptor {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 f.debug_struct("BackendDescriptor")
48 .field("raw", &self.raw)
49 .field("finalized", &self.finalized)
50 .finish()
51 }
52}
53
// SAFETY: the wrapper owns its raw cuDNN handle (`Drop` destroys it), so moving the
// wrapper to another thread moves that ownership along with it.
// NOTE(review): this assumes cuDNN backend descriptors are not bound to the thread that
// created them, and that no copy obtained via `as_raw` is used concurrently elsewhere —
// confirm against the cuDNN docs.
// `Sync` is not implemented (the raw-pointer field makes the type `!Sync`), so shared
// `&BackendDescriptor` access cannot cross threads.
unsafe impl Send for BackendDescriptor {}
55
56impl BackendDescriptor {
57 pub fn create(kind: cudnn_sys::cudnnBackendDescriptorType_t) -> Result<Self, GpuError> {
59 let mut raw: cudnn_sys::cudnnBackendDescriptor_t = std::ptr::null_mut();
60 let s = unsafe { cudnn_sys::cudnnBackendCreateDescriptor(kind, &mut raw) };
61 check(s, "cudnnBackendCreateDescriptor")?;
62 Ok(Self {
63 raw,
64 finalized: false,
65 })
66 }
67
68 pub fn as_raw(&self) -> cudnn_sys::cudnnBackendDescriptor_t {
70 self.raw
71 }
72
73 pub fn finalize(&mut self) -> Result<(), GpuError> {
75 if self.finalized {
76 return Ok(());
77 }
78 let s = unsafe { cudnn_sys::cudnnBackendFinalize(self.raw) };
79 check(s, "cudnnBackendFinalize")?;
80 self.finalized = true;
81 Ok(())
82 }
83
84 pub fn is_finalized(&self) -> bool {
86 self.finalized
87 }
88
89 pub unsafe fn set_attribute_raw(
97 &mut self,
98 name: cudnn_sys::cudnnBackendAttributeName_t,
99 attr_type: cudnn_sys::cudnnBackendAttributeType_t,
100 count: i64,
101 data: *const c_void,
102 ) -> Result<(), GpuError> {
103 let s = unsafe {
104 cudnn_sys::cudnnBackendSetAttribute(
105 self.raw,
106 name,
107 attr_type,
108 count,
109 data as *mut c_void,
110 )
111 };
112 check(s, "cudnnBackendSetAttribute")
113 }
114
115 pub fn set_i64(
117 &mut self,
118 name: cudnn_sys::cudnnBackendAttributeName_t,
119 value: i64,
120 ) -> Result<(), GpuError> {
121 let v = value;
122 unsafe {
123 self.set_attribute_raw(
124 name,
125 cudnn_sys::cudnnBackendAttributeType_t::CUDNN_TYPE_INT64,
126 1,
127 &v as *const i64 as *const c_void,
128 )
129 }
130 }
131
132 pub fn set_i64_array(
134 &mut self,
135 name: cudnn_sys::cudnnBackendAttributeName_t,
136 values: &[i64],
137 ) -> Result<(), GpuError> {
138 unsafe {
139 self.set_attribute_raw(
140 name,
141 cudnn_sys::cudnnBackendAttributeType_t::CUDNN_TYPE_INT64,
142 values.len() as i64,
143 values.as_ptr() as *const c_void,
144 )
145 }
146 }
147
148 pub fn set_f32(
150 &mut self,
151 name: cudnn_sys::cudnnBackendAttributeName_t,
152 value: f32,
153 ) -> Result<(), GpuError> {
154 let v = value;
155 unsafe {
156 self.set_attribute_raw(
157 name,
158 cudnn_sys::cudnnBackendAttributeType_t::CUDNN_TYPE_FLOAT,
159 1,
160 &v as *const f32 as *const c_void,
161 )
162 }
163 }
164
165 pub fn set_f64(
167 &mut self,
168 name: cudnn_sys::cudnnBackendAttributeName_t,
169 value: f64,
170 ) -> Result<(), GpuError> {
171 let v = value;
172 unsafe {
173 self.set_attribute_raw(
174 name,
175 cudnn_sys::cudnnBackendAttributeType_t::CUDNN_TYPE_DOUBLE,
176 1,
177 &v as *const f64 as *const c_void,
178 )
179 }
180 }
181
182 pub fn set_dev_ptr(
184 &mut self,
185 name: cudnn_sys::cudnnBackendAttributeName_t,
186 ptr: *mut c_void,
187 ) -> Result<(), GpuError> {
188 let p = ptr;
189 unsafe {
190 self.set_attribute_raw(
191 name,
192 cudnn_sys::cudnnBackendAttributeType_t::CUDNN_TYPE_VOID_PTR,
193 1,
194 &p as *const *mut c_void as *const c_void,
195 )
196 }
197 }
198
199 pub fn set_descriptor(
201 &mut self,
202 name: cudnn_sys::cudnnBackendAttributeName_t,
203 sub: &BackendDescriptor,
204 ) -> Result<(), GpuError> {
205 let p = sub.raw;
206 unsafe {
207 self.set_attribute_raw(
208 name,
209 cudnn_sys::cudnnBackendAttributeType_t::CUDNN_TYPE_BACKEND_DESCRIPTOR,
210 1,
211 &p as *const _ as *const c_void,
212 )
213 }
214 }
215
216 pub fn set_descriptors(
218 &mut self,
219 name: cudnn_sys::cudnnBackendAttributeName_t,
220 subs: &[&BackendDescriptor],
221 ) -> Result<(), GpuError> {
222 let raws: Vec<cudnn_sys::cudnnBackendDescriptor_t> = subs.iter().map(|s| s.raw).collect();
223 unsafe {
224 self.set_attribute_raw(
225 name,
226 cudnn_sys::cudnnBackendAttributeType_t::CUDNN_TYPE_BACKEND_DESCRIPTOR,
227 raws.len() as i64,
228 raws.as_ptr() as *const c_void,
229 )
230 }
231 }
232
233 pub fn set_data_type(
235 &mut self,
236 name: cudnn_sys::cudnnBackendAttributeName_t,
237 dt: cudnn_sys::cudnnDataType_t,
238 ) -> Result<(), GpuError> {
239 let v = dt;
240 unsafe {
241 self.set_attribute_raw(
242 name,
243 cudnn_sys::cudnnBackendAttributeType_t::CUDNN_TYPE_DATA_TYPE,
244 1,
245 &v as *const _ as *const c_void,
246 )
247 }
248 }
249}
250
251impl Drop for BackendDescriptor {
252 fn drop(&mut self) {
253 if !self.raw.is_null() {
254 unsafe {
257 let _ = cudnn_sys::cudnnBackendDestroyDescriptor(self.raw);
258 }
259 self.raw = std::ptr::null_mut();
260 }
261 }
262}
263
264pub unsafe fn backend_execute(
270 handle: cudnn_sys::cudnnHandle_t,
271 plan: &BackendDescriptor,
272 variant_pack: &BackendDescriptor,
273) -> Result<(), GpuError> {
274 let s = unsafe { cudnn_sys::cudnnBackendExecute(handle, plan.raw, variant_pack.raw) };
275 check(s, "cudnnBackendExecute")
276}
277
#[cfg(test)]
mod tests {
    // NOTE(review): no tests yet. `check` only compares a status value and formats a
    // message without calling into cuDNN, so it could be unit-tested without a GPU.
}