1use core::ffi::c_void;
17use core::mem::MaybeUninit;
18
19use cudarc::cutensor::sys as ct_sys;
20
21#[derive(Clone, Copy, Debug, PartialEq, Eq)]
25pub struct CutensorError(pub ct_sys::cutensorStatus_t);
26
27impl std::fmt::Display for CutensorError {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 write!(f, "cutensor status: {:?}", self.0)
30 }
31}
32
33impl std::error::Error for CutensorError {}
34
35#[inline]
36fn check(status: ct_sys::cutensorStatus_t) -> Result<(), CutensorError> {
37 match status {
38 ct_sys::cutensorStatus_t::CUTENSOR_STATUS_SUCCESS => Ok(()),
39 e => Err(CutensorError(e)),
40 }
41}
42
43use std::sync::OnceLock;
55
56struct ComputeDescriptors {
57 r_min_32f: ct_sys::cutensorComputeDescriptor_t,
58 r_min_64f: ct_sys::cutensorComputeDescriptor_t,
59 r_min_16f: ct_sys::cutensorComputeDescriptor_t,
60 r_min_16bf: ct_sys::cutensorComputeDescriptor_t,
61 r_min_tf32: ct_sys::cutensorComputeDescriptor_t,
62 r_32f: ct_sys::cutensorComputeDescriptor_t,
63 r_64f: ct_sys::cutensorComputeDescriptor_t,
64 c_32f: ct_sys::cutensorComputeDescriptor_t,
65}
66
67unsafe impl Send for ComputeDescriptors {}
68unsafe impl Sync for ComputeDescriptors {}
69
70static DESCRIPTORS: OnceLock<ComputeDescriptors> = OnceLock::new();
71
72fn load_descriptors() -> ComputeDescriptors {
73 let candidates = [
76 "libcutensor.so.2",
77 "libcutensor.so.1",
78 "libcutensor.so",
79 "cutensor.dll",
80 ];
81 for cand in candidates.iter() {
82 let lib = unsafe { libloading::Library::new(*cand) };
83 let Ok(lib) = lib else { continue };
84 let read = |name: &[u8]| -> Option<ct_sys::cutensorComputeDescriptor_t> {
85 unsafe {
86 let s: libloading::Symbol<*const ct_sys::cutensorComputeDescriptor_t> =
92 lib.get(name).ok()?;
93 Some(**s)
94 }
95 };
96 let r_min_32f = read(b"CUTENSOR_R_MIN_32F\0");
97 let r_min_64f = read(b"CUTENSOR_R_MIN_64F\0");
98 let r_min_16f = read(b"CUTENSOR_R_MIN_16F\0");
99 let r_min_16bf = read(b"CUTENSOR_R_MIN_16BF\0");
100 let r_min_tf32 = read(b"CUTENSOR_R_MIN_TF32\0");
101 let r_32f = read(b"CUTENSOR_R_32F\0");
102 let r_64f = read(b"CUTENSOR_R_64F\0");
103 let c_32f = read(b"CUTENSOR_C_32F\0");
104 if let (Some(a), Some(b)) = (r_min_32f, r_min_64f) {
107 std::mem::forget(lib);
110 return ComputeDescriptors {
111 r_min_32f: a,
112 r_min_64f: b,
113 r_min_16f: r_min_16f.unwrap_or(a),
114 r_min_16bf: r_min_16bf.unwrap_or(a),
115 r_min_tf32: r_min_tf32.unwrap_or(a),
116 r_32f: r_32f.unwrap_or(a),
117 r_64f: r_64f.unwrap_or(b),
118 c_32f: c_32f.unwrap_or(a),
119 };
120 }
121 }
122 panic!(
123 "ContextPoisoned: failed to dlopen libcutensor.so / locate \
124 CUTENSOR_R_MIN_32F (compute descriptor symbol). cuTENSOR \
125 must be installed on the host for cutensor-feature builds."
126 );
127}
128
129#[inline]
130fn descriptors() -> &'static ComputeDescriptors {
131 DESCRIPTORS.get_or_init(load_descriptors)
132}
133
134pub fn r_min_32f() -> ct_sys::cutensorComputeDescriptor_t {
135 descriptors().r_min_32f
136}
137pub fn r_min_64f() -> ct_sys::cutensorComputeDescriptor_t {
138 descriptors().r_min_64f
139}
140pub fn r_min_16f() -> ct_sys::cutensorComputeDescriptor_t {
141 descriptors().r_min_16f
142}
143pub fn r_min_16bf() -> ct_sys::cutensorComputeDescriptor_t {
144 descriptors().r_min_16bf
145}
146pub fn r_min_tf32() -> ct_sys::cutensorComputeDescriptor_t {
147 descriptors().r_min_tf32
148}
149pub fn r_32f() -> ct_sys::cutensorComputeDescriptor_t {
150 descriptors().r_32f
151}
152pub fn r_64f() -> ct_sys::cutensorComputeDescriptor_t {
153 descriptors().r_64f
154}
155pub fn c_32f() -> ct_sys::cutensorComputeDescriptor_t {
156 descriptors().c_32f
157}
158
159pub unsafe fn reduce(
169 handle: ct_sys::cutensorHandle_t,
170 plan: ct_sys::cutensorPlan_t,
171 alpha: *const c_void,
172 a: *const c_void,
173 beta: *const c_void,
174 c: *const c_void,
175 d: *mut c_void,
176 workspace: *mut c_void,
177 workspace_size: u64,
178 stream: ct_sys::cudaStream_t,
179) -> Result<(), CutensorError> {
180 check(ct_sys::cutensorReduce(
181 handle,
182 plan,
183 alpha,
184 a,
185 beta,
186 c,
187 d,
188 workspace,
189 workspace_size,
190 stream,
191 ))
192}
193
194pub unsafe fn create_elementwise_binary(
205 handle: ct_sys::cutensorHandle_t,
206 desc_a: ct_sys::cutensorTensorDescriptor_t,
207 mode_a: *const i32,
208 op_a: ct_sys::cutensorOperator_t,
209 desc_c: ct_sys::cutensorTensorDescriptor_t,
210 mode_c: *const i32,
211 op_c: ct_sys::cutensorOperator_t,
212 desc_d: ct_sys::cutensorTensorDescriptor_t,
213 mode_d: *const i32,
214 op_ac: ct_sys::cutensorOperator_t,
215 desc_compute: ct_sys::cutensorComputeDescriptor_t,
216) -> Result<ct_sys::cutensorOperationDescriptor_t, CutensorError> {
217 let mut desc = MaybeUninit::uninit();
218 check(ct_sys::cutensorCreateElementwiseBinary(
219 handle,
220 desc.as_mut_ptr(),
221 desc_a,
222 mode_a,
223 op_a,
224 desc_c,
225 mode_c,
226 op_c,
227 desc_d,
228 mode_d,
229 op_ac,
230 desc_compute,
231 ))?;
232 Ok(desc.assume_init())
233}
234
235pub unsafe fn elementwise_binary_execute(
240 handle: ct_sys::cutensorHandle_t,
241 plan: ct_sys::cutensorPlan_t,
242 alpha: *const c_void,
243 a: *const c_void,
244 gamma: *const c_void,
245 c: *const c_void,
246 d: *mut c_void,
247 stream: ct_sys::cudaStream_t,
248) -> Result<(), CutensorError> {
249 check(ct_sys::cutensorElementwiseBinaryExecute(
250 handle, plan, alpha, a, gamma, c, d, stream,
251 ))
252}
253
254pub unsafe fn create_elementwise_trinary(
264 handle: ct_sys::cutensorHandle_t,
265 desc_a: ct_sys::cutensorTensorDescriptor_t,
266 mode_a: *const i32,
267 op_a: ct_sys::cutensorOperator_t,
268 desc_b: ct_sys::cutensorTensorDescriptor_t,
269 mode_b: *const i32,
270 op_b: ct_sys::cutensorOperator_t,
271 desc_c: ct_sys::cutensorTensorDescriptor_t,
272 mode_c: *const i32,
273 op_c: ct_sys::cutensorOperator_t,
274 desc_d: ct_sys::cutensorTensorDescriptor_t,
275 mode_d: *const i32,
276 op_ab: ct_sys::cutensorOperator_t,
277 op_abc: ct_sys::cutensorOperator_t,
278 desc_compute: ct_sys::cutensorComputeDescriptor_t,
279) -> Result<ct_sys::cutensorOperationDescriptor_t, CutensorError> {
280 let mut desc = MaybeUninit::uninit();
281 check(ct_sys::cutensorCreateElementwiseTrinary(
282 handle,
283 desc.as_mut_ptr(),
284 desc_a,
285 mode_a,
286 op_a,
287 desc_b,
288 mode_b,
289 op_b,
290 desc_c,
291 mode_c,
292 op_c,
293 desc_d,
294 mode_d,
295 op_ab,
296 op_abc,
297 desc_compute,
298 ))?;
299 Ok(desc.assume_init())
300}
301
302pub unsafe fn elementwise_trinary_execute(
307 handle: ct_sys::cutensorHandle_t,
308 plan: ct_sys::cutensorPlan_t,
309 alpha: *const c_void,
310 a: *const c_void,
311 beta: *const c_void,
312 b: *const c_void,
313 gamma: *const c_void,
314 c: *const c_void,
315 d: *mut c_void,
316 stream: ct_sys::cudaStream_t,
317) -> Result<(), CutensorError> {
318 check(ct_sys::cutensorElementwiseTrinaryExecute(
319 handle, plan, alpha, a, beta, b, gamma, c, d, stream,
320 ))
321}
322
323pub unsafe fn create_permutation(
333 handle: ct_sys::cutensorHandle_t,
334 desc_a: ct_sys::cutensorTensorDescriptor_t,
335 mode_a: *const i32,
336 op_a: ct_sys::cutensorOperator_t,
337 desc_b: ct_sys::cutensorTensorDescriptor_t,
338 mode_b: *const i32,
339 desc_compute: ct_sys::cutensorComputeDescriptor_t,
340) -> Result<ct_sys::cutensorOperationDescriptor_t, CutensorError> {
341 let mut desc = MaybeUninit::uninit();
342 check(ct_sys::cutensorCreatePermutation(
343 handle,
344 desc.as_mut_ptr(),
345 desc_a,
346 mode_a,
347 op_a,
348 desc_b,
349 mode_b,
350 desc_compute,
351 ))?;
352 Ok(desc.assume_init())
353}
354
355pub unsafe fn permute(
360 handle: ct_sys::cutensorHandle_t,
361 plan: ct_sys::cutensorPlan_t,
362 alpha: *const c_void,
363 a: *const c_void,
364 b: *mut c_void,
365 stream: ct_sys::cudaStream_t,
366) -> Result<(), CutensorError> {
367 check(ct_sys::cutensorPermute(handle, plan, alpha, a, b, stream))
368}
369
370pub unsafe fn plan_preference_set_algo(
380 handle: ct_sys::cutensorHandle_t,
381 pref: ct_sys::cutensorPlanPreference_t,
382 algo: ct_sys::cutensorAlgo_t,
383) -> Result<(), CutensorError> {
384 let value = algo as i32;
385 check(ct_sys::cutensorPlanPreferenceSetAttribute(
386 handle,
387 pref,
388 ct_sys::cutensorPlanPreferenceAttribute_t::CUTENSOR_PLAN_PREFERENCE_ALGO,
389 &value as *const i32 as *const c_void,
390 std::mem::size_of::<i32>(),
391 ))
392}