Skip to main content

atomr_accel_cuda/sys/
cutensor.rs

1//! Local sys-level wrappers around `cudarc::cutensor::sys` for the
2//! cuTENSOR entry points the safe `cudarc::cutensor::result` module
3//! does not expose (Reduce/ElementwiseBinary/ElementwiseTrinary
4//! create+execute, Permutation create+execute, predefined compute
5//! descriptors).
6//!
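//! Every FFI wrapper here is an `unsafe fn` taking the raw cuTENSOR
//! handle, descriptor and enum types from `cudarc::cutensor::sys`. The
//! compute-descriptor accessors (`r_min_32f()` etc.) are safe functions.
//! The crate's actor layer drives all of these through `kernel/tensor/`.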
10//!
11//! All wrappers convert `cutensorStatus_t` into a thin
12//! [`CutensorError`] that mirrors what `cudarc::cutensor::result`
13//! emits; callers that already use `cudarc::cutensor::result` can
14//! interleave both freely.
15
16use core::ffi::c_void;
17use core::mem::MaybeUninit;
18
19use cudarc::cutensor::sys as ct_sys;
20
21/// Error wrapper around a `cutensorStatus_t`. Mirrors
22/// `cudarc::cutensor::result::CutensorError` so error messages are
23/// consistent across the safe/sys boundary.
24#[derive(Clone, Copy, Debug, PartialEq, Eq)]
25pub struct CutensorError(pub ct_sys::cutensorStatus_t);
26
27impl std::fmt::Display for CutensorError {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        write!(f, "cutensor status: {:?}", self.0)
30    }
31}
32
33impl std::error::Error for CutensorError {}
34
35#[inline]
36fn check(status: ct_sys::cutensorStatus_t) -> Result<(), CutensorError> {
37    match status {
38        ct_sys::cutensorStatus_t::CUTENSOR_STATUS_SUCCESS => Ok(()),
39        e => Err(CutensorError(e)),
40    }
41}
42
// ---------------------------------------------------------------------
// Predefined compute descriptors.
//
// libcutensor exports its predefined `cutensorComputeDescriptor_t`
// values (f32, f64, f16, bf16, tf32, ...) as global constants, which
// cudarc 0.19 doesn't surface. Linking them as `extern "C" { static ... }`
// would require libcutensor.so at link time, which fails on hosts without
// cuTENSOR installed (e.g. no-GPU build machines). Instead, the library is
// dlopen'ed through `libloading` on first use and every descriptor is read
// at once; the resulting table is cached in a single `OnceLock`
// (`DESCRIPTORS`) for the rest of the process.
// ---------------------------------------------------------------------
53
54use std::sync::OnceLock;
55
56struct ComputeDescriptors {
57    r_min_32f: ct_sys::cutensorComputeDescriptor_t,
58    r_min_64f: ct_sys::cutensorComputeDescriptor_t,
59    r_min_16f: ct_sys::cutensorComputeDescriptor_t,
60    r_min_16bf: ct_sys::cutensorComputeDescriptor_t,
61    r_min_tf32: ct_sys::cutensorComputeDescriptor_t,
62    r_32f: ct_sys::cutensorComputeDescriptor_t,
63    r_64f: ct_sys::cutensorComputeDescriptor_t,
64    c_32f: ct_sys::cutensorComputeDescriptor_t,
65}
66
67unsafe impl Send for ComputeDescriptors {}
68unsafe impl Sync for ComputeDescriptors {}
69
70static DESCRIPTORS: OnceLock<ComputeDescriptors> = OnceLock::new();
71
72fn load_descriptors() -> ComputeDescriptors {
73    // libcutensor candidates. Mirrors cudarc's lookup but kept local so
74    // we don't reach into cudarc's private `culib()`.
75    let candidates = [
76        "libcutensor.so.2",
77        "libcutensor.so.1",
78        "libcutensor.so",
79        "cutensor.dll",
80    ];
81    for cand in candidates.iter() {
82        let lib = unsafe { libloading::Library::new(*cand) };
83        let Ok(lib) = lib else { continue };
84        let read = |name: &[u8]| -> Option<ct_sys::cutensorComputeDescriptor_t> {
85            unsafe {
86                // The exported symbol is the variable itself; libloading
87                // returns a `Symbol<T>` whose `Deref` yields `T`. We
88                // ask for `T = *const cutensorComputeDescriptor_t` so
89                // `*s` is the variable's address, and one further deref
90                // (`**s`) reads the descriptor pointer value out of it.
91                let s: libloading::Symbol<*const ct_sys::cutensorComputeDescriptor_t> =
92                    lib.get(name).ok()?;
93                Some(**s)
94            }
95        };
96        let r_min_32f = read(b"CUTENSOR_R_MIN_32F\0");
97        let r_min_64f = read(b"CUTENSOR_R_MIN_64F\0");
98        let r_min_16f = read(b"CUTENSOR_R_MIN_16F\0");
99        let r_min_16bf = read(b"CUTENSOR_R_MIN_16BF\0");
100        let r_min_tf32 = read(b"CUTENSOR_R_MIN_TF32\0");
101        let r_32f = read(b"CUTENSOR_R_32F\0");
102        let r_64f = read(b"CUTENSOR_R_64F\0");
103        let c_32f = read(b"CUTENSOR_C_32F\0");
104        // Require at least the f32/f64 min descriptors — every supported
105        // dtype routes through one of those by default.
106        if let (Some(a), Some(b)) = (r_min_32f, r_min_64f) {
107            // Forget the library so its destructor doesn't unload while
108            // we're still holding pointers into it.
109            std::mem::forget(lib);
110            return ComputeDescriptors {
111                r_min_32f: a,
112                r_min_64f: b,
113                r_min_16f: r_min_16f.unwrap_or(a),
114                r_min_16bf: r_min_16bf.unwrap_or(a),
115                r_min_tf32: r_min_tf32.unwrap_or(a),
116                r_32f: r_32f.unwrap_or(a),
117                r_64f: r_64f.unwrap_or(b),
118                c_32f: c_32f.unwrap_or(a),
119            };
120        }
121    }
122    panic!(
123        "ContextPoisoned: failed to dlopen libcutensor.so / locate \
124         CUTENSOR_R_MIN_32F (compute descriptor symbol). cuTENSOR \
125         must be installed on the host for cutensor-feature builds."
126    );
127}
128
129#[inline]
130fn descriptors() -> &'static ComputeDescriptors {
131    DESCRIPTORS.get_or_init(load_descriptors)
132}
133
134pub fn r_min_32f() -> ct_sys::cutensorComputeDescriptor_t {
135    descriptors().r_min_32f
136}
137pub fn r_min_64f() -> ct_sys::cutensorComputeDescriptor_t {
138    descriptors().r_min_64f
139}
140pub fn r_min_16f() -> ct_sys::cutensorComputeDescriptor_t {
141    descriptors().r_min_16f
142}
143pub fn r_min_16bf() -> ct_sys::cutensorComputeDescriptor_t {
144    descriptors().r_min_16bf
145}
146pub fn r_min_tf32() -> ct_sys::cutensorComputeDescriptor_t {
147    descriptors().r_min_tf32
148}
149pub fn r_32f() -> ct_sys::cutensorComputeDescriptor_t {
150    descriptors().r_32f
151}
152pub fn r_64f() -> ct_sys::cutensorComputeDescriptor_t {
153    descriptors().r_64f
154}
155pub fn c_32f() -> ct_sys::cutensorComputeDescriptor_t {
156    descriptors().c_32f
157}
158
159// ---------------------------------------------------------------------
160// Reduction
161// ---------------------------------------------------------------------
162
163/// Wraps `cutensorReduce` (post-plan execution).
164///
165/// # Safety
166/// All pointers must be valid. `workspace` must hold at least
167/// `workspace_size` bytes. `stream` must outlive the call.
168pub unsafe fn reduce(
169    handle: ct_sys::cutensorHandle_t,
170    plan: ct_sys::cutensorPlan_t,
171    alpha: *const c_void,
172    a: *const c_void,
173    beta: *const c_void,
174    c: *const c_void,
175    d: *mut c_void,
176    workspace: *mut c_void,
177    workspace_size: u64,
178    stream: ct_sys::cudaStream_t,
179) -> Result<(), CutensorError> {
180    check(ct_sys::cutensorReduce(
181        handle,
182        plan,
183        alpha,
184        a,
185        beta,
186        c,
187        d,
188        workspace,
189        workspace_size,
190        stream,
191    ))
192}
193
194// ---------------------------------------------------------------------
195// Elementwise binary
196// ---------------------------------------------------------------------
197
198/// Create a binary elementwise operation descriptor. Wraps
199/// `cutensorCreateElementwiseBinary`.
200///
201/// # Safety
202/// All handles/descriptors must be valid; mode arrays must align with
203/// each tensor descriptor's rank.
204pub unsafe fn create_elementwise_binary(
205    handle: ct_sys::cutensorHandle_t,
206    desc_a: ct_sys::cutensorTensorDescriptor_t,
207    mode_a: *const i32,
208    op_a: ct_sys::cutensorOperator_t,
209    desc_c: ct_sys::cutensorTensorDescriptor_t,
210    mode_c: *const i32,
211    op_c: ct_sys::cutensorOperator_t,
212    desc_d: ct_sys::cutensorTensorDescriptor_t,
213    mode_d: *const i32,
214    op_ac: ct_sys::cutensorOperator_t,
215    desc_compute: ct_sys::cutensorComputeDescriptor_t,
216) -> Result<ct_sys::cutensorOperationDescriptor_t, CutensorError> {
217    let mut desc = MaybeUninit::uninit();
218    check(ct_sys::cutensorCreateElementwiseBinary(
219        handle,
220        desc.as_mut_ptr(),
221        desc_a,
222        mode_a,
223        op_a,
224        desc_c,
225        mode_c,
226        op_c,
227        desc_d,
228        mode_d,
229        op_ac,
230        desc_compute,
231    ))?;
232    Ok(desc.assume_init())
233}
234
235/// Execute a previously-planned binary elementwise op.
236///
237/// # Safety
238/// `handle`, `plan`, all data pointers, and `stream` must be valid.
239pub unsafe fn elementwise_binary_execute(
240    handle: ct_sys::cutensorHandle_t,
241    plan: ct_sys::cutensorPlan_t,
242    alpha: *const c_void,
243    a: *const c_void,
244    gamma: *const c_void,
245    c: *const c_void,
246    d: *mut c_void,
247    stream: ct_sys::cudaStream_t,
248) -> Result<(), CutensorError> {
249    check(ct_sys::cutensorElementwiseBinaryExecute(
250        handle, plan, alpha, a, gamma, c, d, stream,
251    ))
252}
253
254// ---------------------------------------------------------------------
255// Elementwise trinary
256// ---------------------------------------------------------------------
257
258/// Create a trinary elementwise operation descriptor. Wraps
259/// `cutensorCreateElementwiseTrinary`.
260///
261/// # Safety
262/// As [`create_elementwise_binary`].
263pub unsafe fn create_elementwise_trinary(
264    handle: ct_sys::cutensorHandle_t,
265    desc_a: ct_sys::cutensorTensorDescriptor_t,
266    mode_a: *const i32,
267    op_a: ct_sys::cutensorOperator_t,
268    desc_b: ct_sys::cutensorTensorDescriptor_t,
269    mode_b: *const i32,
270    op_b: ct_sys::cutensorOperator_t,
271    desc_c: ct_sys::cutensorTensorDescriptor_t,
272    mode_c: *const i32,
273    op_c: ct_sys::cutensorOperator_t,
274    desc_d: ct_sys::cutensorTensorDescriptor_t,
275    mode_d: *const i32,
276    op_ab: ct_sys::cutensorOperator_t,
277    op_abc: ct_sys::cutensorOperator_t,
278    desc_compute: ct_sys::cutensorComputeDescriptor_t,
279) -> Result<ct_sys::cutensorOperationDescriptor_t, CutensorError> {
280    let mut desc = MaybeUninit::uninit();
281    check(ct_sys::cutensorCreateElementwiseTrinary(
282        handle,
283        desc.as_mut_ptr(),
284        desc_a,
285        mode_a,
286        op_a,
287        desc_b,
288        mode_b,
289        op_b,
290        desc_c,
291        mode_c,
292        op_c,
293        desc_d,
294        mode_d,
295        op_ab,
296        op_abc,
297        desc_compute,
298    ))?;
299    Ok(desc.assume_init())
300}
301
302/// Execute a previously-planned trinary elementwise op.
303///
304/// # Safety
305/// As [`elementwise_binary_execute`].
306pub unsafe fn elementwise_trinary_execute(
307    handle: ct_sys::cutensorHandle_t,
308    plan: ct_sys::cutensorPlan_t,
309    alpha: *const c_void,
310    a: *const c_void,
311    beta: *const c_void,
312    b: *const c_void,
313    gamma: *const c_void,
314    c: *const c_void,
315    d: *mut c_void,
316    stream: ct_sys::cudaStream_t,
317) -> Result<(), CutensorError> {
318    check(ct_sys::cutensorElementwiseTrinaryExecute(
319        handle, plan, alpha, a, beta, b, gamma, c, d, stream,
320    ))
321}
322
323// ---------------------------------------------------------------------
324// Permutation
325// ---------------------------------------------------------------------
326
327/// Create a permutation operation descriptor. Wraps
328/// `cutensorCreatePermutation`.
329///
330/// # Safety
331/// As [`create_elementwise_binary`].
332pub unsafe fn create_permutation(
333    handle: ct_sys::cutensorHandle_t,
334    desc_a: ct_sys::cutensorTensorDescriptor_t,
335    mode_a: *const i32,
336    op_a: ct_sys::cutensorOperator_t,
337    desc_b: ct_sys::cutensorTensorDescriptor_t,
338    mode_b: *const i32,
339    desc_compute: ct_sys::cutensorComputeDescriptor_t,
340) -> Result<ct_sys::cutensorOperationDescriptor_t, CutensorError> {
341    let mut desc = MaybeUninit::uninit();
342    check(ct_sys::cutensorCreatePermutation(
343        handle,
344        desc.as_mut_ptr(),
345        desc_a,
346        mode_a,
347        op_a,
348        desc_b,
349        mode_b,
350        desc_compute,
351    ))?;
352    Ok(desc.assume_init())
353}
354
355/// Execute a previously-planned permutation.
356///
357/// # Safety
358/// As [`elementwise_binary_execute`].
359pub unsafe fn permute(
360    handle: ct_sys::cutensorHandle_t,
361    plan: ct_sys::cutensorPlan_t,
362    alpha: *const c_void,
363    a: *const c_void,
364    b: *mut c_void,
365    stream: ct_sys::cudaStream_t,
366) -> Result<(), CutensorError> {
367    check(ct_sys::cutensorPermute(handle, plan, alpha, a, b, stream))
368}
369
370// ---------------------------------------------------------------------
371// Plan-preference attribute setters used by autotune.
372// ---------------------------------------------------------------------
373
374/// Set the pinned algorithm on a plan-preference object. Used by the
375/// contraction autotune to probe a specific `cutensorAlgo_t` value.
376///
377/// # Safety
378/// `handle` and `pref` must be valid.
379pub unsafe fn plan_preference_set_algo(
380    handle: ct_sys::cutensorHandle_t,
381    pref: ct_sys::cutensorPlanPreference_t,
382    algo: ct_sys::cutensorAlgo_t,
383) -> Result<(), CutensorError> {
384    let value = algo as i32;
385    check(ct_sys::cutensorPlanPreferenceSetAttribute(
386        handle,
387        pref,
388        ct_sys::cutensorPlanPreferenceAttribute_t::CUTENSOR_PLAN_PREFERENCE_ALGO,
389        &value as *const i32 as *const c_void,
390        std::mem::size_of::<i32>(),
391    ))
392}