Skip to main content

atomr_accel_cuda/kernel/tensor/
permute.rs

//! `PermutationRequest<T>` — dtype-generic mode permutation.
//!
//! Computes `B(modes_b) = alpha * op_A(A(modes_a))`. Mirrors
//! `cutensorCreatePermutation` + `cutensorPermute`.
5
6use std::sync::Arc;
7
8use cudarc::cutensor::result as ct_result;
9use cudarc::cutensor::sys as ct_sys;
10use cudarc::driver::{DevicePtr, DevicePtrMut};
11use parking_lot::Mutex;
12use tokio::sync::oneshot;
13
14use crate::dtype::TensorSupported;
15use crate::error::GpuError;
16use crate::kernel::dispatch::{TensorDispatch, TensorDispatchCtx};
17use crate::kernel::envelope;
18use crate::kernel::tensor::compute_desc::{compute_desc_tag, resolve_compute_desc, ComputeDesc};
19use crate::kernel::tensor::contract::OperandSpec;
20use crate::kernel::tensor::plan_cache::{
21    hash_i32s, hash_i64s, CachedPlan, OpKind, PlanCache, PlanKey,
22};
23use crate::kernel::tensor::SendHandle;
24use crate::sys::cutensor as ct_local;
25
26const LIB: &str = "cutensor";
27
28pub struct PermutationRequest<T: TensorSupported> {
29    pub a: OperandSpec<T>,
30    pub b: OperandSpec<T>,
31    pub alpha: T,
32    pub op_a: ct_sys::cutensorOperator_t,
33    pub compute: ComputeDesc,
34    pub alignment: u32,
35    pub reply: oneshot::Sender<Result<(), GpuError>>,
36}
37
38impl<T: TensorSupported> PermutationRequest<T> {
39    pub fn new(
40        a: OperandSpec<T>,
41        b: OperandSpec<T>,
42        alpha: T,
43        reply: oneshot::Sender<Result<(), GpuError>>,
44    ) -> Self {
45        Self {
46            a,
47            b,
48            alpha,
49            op_a: ct_sys::cutensorOperator_t::CUTENSOR_OP_IDENTITY,
50            compute: super::contract::default_compute_for::<T>(),
51            alignment: 16,
52            reply,
53        }
54    }
55}
56
57impl<T: TensorSupported> TensorDispatch for PermutationRequest<T> {
58    fn op_tag(&self) -> &'static str {
59        "permute"
60    }
61    fn dtype_tag(&self) -> &'static str {
62        <T as atomr_accel::AccelDtype>::NAME
63    }
64    fn dispatch(self: Box<Self>, ctx: &TensorDispatchCtx) {
65        execute(*self, ctx);
66    }
67    fn fail_mock(self: Box<Self>) {
68        let _ = self.reply.send(Err(GpuError::Unrecoverable(
69            "TensorActor in mock mode".into(),
70        )));
71    }
72}
73
74fn execute<T: TensorSupported>(req: PermutationRequest<T>, ctx: &TensorDispatchCtx) {
75    let PermutationRequest {
76        a,
77        b,
78        alpha,
79        op_a,
80        compute,
81        alignment,
82        reply,
83    } = req;
84
85    if a.extent.len() != a.modes.len() || b.extent.len() != b.modes.len() {
86        let _ = reply.send(Err(GpuError::Unrecoverable(
87            "Permutation: extent/modes length mismatch".into(),
88        )));
89        return;
90    }
91
92    let a_slice = match a.buf.access() {
93        Ok(s) => s.clone(),
94        Err(e) => {
95            let _ = reply.send(Err(e));
96            return;
97        }
98    };
99    let b_slice = match b.buf.access() {
100        Ok(s) => s.clone(),
101        Err(e) => {
102            let _ = reply.send(Err(e));
103            return;
104        }
105    };
106    let mut b_owned = match Arc::try_unwrap(b_slice) {
107        Ok(s) => s,
108        Err(_) => {
109            let _ = reply.send(Err(GpuError::Unrecoverable(
110                "Permutation b has multiple live references".into(),
111            )));
112            return;
113        }
114    };
115
116    let key = build_permutation_key_raw(
117        <T as atomr_accel::AccelDtype>::NAME,
118        &a.modes,
119        &b.modes,
120        &a.extent,
121        &b.extent,
122        alignment,
123        compute,
124        op_a,
125    );
126    let cached = match get_or_build_plan::<T>(
127        &ctx.handle,
128        &ctx.plan_cache,
129        &key,
130        &a,
131        &b,
132        alignment,
133        compute,
134        op_a,
135    ) {
136        Ok(p) => p,
137        Err(e) => {
138            let _ = reply.send(Err(e));
139            return;
140        }
141    };
142
143    b.buf.record_write(&ctx.stream);
144
145    let stream_for_check = ctx.stream.clone();
146    let handle_clone = ctx.handle.clone();
147    let plan_keepalive = cached.clone();
148
149    envelope::run_kernel(LIB, &ctx.stream, &ctx.completion, (), reply, move || {
150        let h = handle_clone.lock();
151        let (a_ptr, _ga) = a_slice.device_ptr(&stream_for_check);
152        let (b_ptr, _gb) = b_owned.device_ptr_mut(&stream_for_check);
153        let alpha_h = alpha;
154        let res = unsafe {
155            ct_local::permute(
156                h.0,
157                plan_keepalive.plan,
158                &alpha_h as *const T as *const _,
159                a_ptr as *const _,
160                b_ptr as *mut _,
161                stream_for_check.cu_stream() as *mut _,
162            )
163        };
164        drop((_ga, _gb));
165        match res {
166            Ok(()) => Ok((b_owned, a_slice, plan_keepalive)),
167            Err(e) => Err(GpuError::LibraryError {
168                lib: LIB,
169                msg: format!("Permute: {e}"),
170            }),
171        }
172    });
173}
174
175#[allow(clippy::too_many_arguments)]
176pub fn build_permutation_key_raw(
177    dtype_tag: &'static str,
178    modes_a: &[i32],
179    modes_b: &[i32],
180    extent_a: &[i64],
181    extent_b: &[i64],
182    alignment: u32,
183    compute: ComputeDesc,
184    op_a: ct_sys::cutensorOperator_t,
185) -> PlanKey {
186    let mut modes = Vec::with_capacity(modes_a.len() + modes_b.len() + 2);
187    modes.extend_from_slice(modes_a);
188    modes.push(i32::MIN);
189    modes.extend_from_slice(modes_b);
190    let mut extents = Vec::with_capacity(extent_a.len() + extent_b.len() + 2);
191    extents.extend_from_slice(extent_a);
192    extents.push(i64::MIN);
193    extents.extend_from_slice(extent_b);
194    PlanKey {
195        op_kind: OpKind::Permutation,
196        modes_hash: hash_i32s(&modes),
197        extents_hash: hash_i64s(&extents),
198        alignment,
199        compute_desc_tag: compute_desc_tag(compute) ^ ((op_a as u32).wrapping_mul(0x27d4_eb2f)),
200        dtype_tag,
201        algo: 0,
202    }
203}
204
205#[allow(clippy::too_many_arguments)]
206fn get_or_build_plan<T: TensorSupported>(
207    handle: &Arc<Mutex<SendHandle>>,
208    plan_cache: &Arc<PlanCache>,
209    key: &PlanKey,
210    a: &OperandSpec<T>,
211    b: &OperandSpec<T>,
212    alignment: u32,
213    compute: ComputeDesc,
214    op_a: ct_sys::cutensorOperator_t,
215) -> Result<Arc<CachedPlan>, GpuError> {
216    if let Some(p) = plan_cache.get(key) {
217        return Ok(p);
218    }
219    let plan = build_plan::<T>(handle, a, b, alignment, compute, op_a)?;
220    let arc = Arc::new(plan);
221    plan_cache.put(*key, arc.clone());
222    Ok(arc)
223}
224
225fn build_plan<T: TensorSupported>(
226    handle: &Arc<Mutex<SendHandle>>,
227    a: &OperandSpec<T>,
228    b: &OperandSpec<T>,
229    alignment: u32,
230    compute: ComputeDesc,
231    op_a: ct_sys::cutensorOperator_t,
232) -> Result<CachedPlan, GpuError> {
233    let h = handle.lock();
234    let dt: cudarc::cutensor::sys::cudaDataType_t =
235        unsafe { std::mem::transmute(T::cuda_data_type() as u32) };
236    let cd = resolve_compute_desc(compute);
237    let stride_ptr = |v: &Vec<i64>| {
238        if v.is_empty() {
239            std::ptr::null()
240        } else {
241            v.as_ptr()
242        }
243    };
244    let desc_a = unsafe {
245        ct_result::create_tensor_descriptor(
246            h.0,
247            a.extent.len() as u32,
248            a.extent.as_ptr(),
249            stride_ptr(&a.stride),
250            dt,
251            alignment,
252        )
253    }
254    .map_err(|e| GpuError::lib(LIB, format!("CreateTensorDescriptor(A): {e}")))?;
255    let desc_b = unsafe {
256        ct_result::create_tensor_descriptor(
257            h.0,
258            b.extent.len() as u32,
259            b.extent.as_ptr(),
260            stride_ptr(&b.stride),
261            dt,
262            alignment,
263        )
264    }
265    .map_err(|e| {
266        unsafe {
267            let _ = ct_result::destroy_tensor_descriptor(desc_a);
268        }
269        GpuError::lib(LIB, format!("CreateTensorDescriptor(B): {e}"))
270    })?;
271
272    let op = unsafe {
273        ct_local::create_permutation(
274            h.0,
275            desc_a,
276            a.modes.as_ptr(),
277            op_a,
278            desc_b,
279            b.modes.as_ptr(),
280            cd,
281        )
282    }
283    .map_err(|e| {
284        unsafe {
285            let _ = ct_result::destroy_tensor_descriptor(desc_b);
286            let _ = ct_result::destroy_tensor_descriptor(desc_a);
287        }
288        GpuError::lib(LIB, format!("CreatePermutation: {e}"))
289    })?;
290
291    let pref = unsafe {
292        ct_result::create_plan_preference(
293            h.0,
294            ct_sys::cutensorAlgo_t::CUTENSOR_ALGO_DEFAULT,
295            ct_sys::cutensorJitMode_t::CUTENSOR_JIT_MODE_NONE,
296        )
297    }
298    .map_err(|e| {
299        unsafe {
300            let _ = ct_result::destroy_operation_descriptor(op);
301            let _ = ct_result::destroy_tensor_descriptor(desc_b);
302            let _ = ct_result::destroy_tensor_descriptor(desc_a);
303        }
304        GpuError::lib(LIB, format!("CreatePlanPreference: {e}"))
305    })?;
306
307    let ws_size = unsafe {
308        ct_result::estimate_workspace_size(
309            h.0,
310            op,
311            pref,
312            ct_sys::cutensorWorksizePreference_t::CUTENSOR_WORKSPACE_DEFAULT,
313        )
314    }
315    .map_err(|e| {
316        unsafe {
317            let _ = ct_result::destroy_plan_preference(pref);
318            let _ = ct_result::destroy_operation_descriptor(op);
319            let _ = ct_result::destroy_tensor_descriptor(desc_b);
320            let _ = ct_result::destroy_tensor_descriptor(desc_a);
321        }
322        GpuError::lib(LIB, format!("EstimateWorkspaceSize: {e}"))
323    })?;
324
325    let plan = unsafe { ct_result::create_plan(h.0, op, pref, ws_size) }.map_err(|e| {
326        unsafe {
327            let _ = ct_result::destroy_plan_preference(pref);
328            let _ = ct_result::destroy_operation_descriptor(op);
329            let _ = ct_result::destroy_tensor_descriptor(desc_b);
330            let _ = ct_result::destroy_tensor_descriptor(desc_a);
331        }
332        GpuError::lib(LIB, format!("CreatePlan: {e}"))
333    })?;
334
335    Ok(CachedPlan {
336        plan,
337        pref,
338        op,
339        descs: vec![desc_a, desc_b],
340        workspace_size: ws_size,
341    })
342}
343
#[cfg(test)]
mod tests {
    use super::*;

    /// Key for the fixed `[1,2,3] -> [3,1,2]` permutation of an 8x16x32 tensor,
    /// varying only the parameters the assertions below care about.
    fn key_for(
        dtype_tag: &'static str,
        compute: ComputeDesc,
        op_a: ct_sys::cutensorOperator_t,
        alignment: u32,
    ) -> PlanKey {
        build_permutation_key_raw(
            dtype_tag,
            &[1, 2, 3],
            &[3, 1, 2],
            &[8, 16, 32],
            &[32, 8, 16],
            alignment,
            compute,
            op_a,
        )
    }

    #[test]
    fn permutation_request_round_trip() {
        let f32_tag = <f32 as atomr_accel::AccelDtype>::NAME;
        let f64_tag = <f64 as atomr_accel::AccelDtype>::NAME;
        let identity = ct_sys::cutensorOperator_t::CUTENSOR_OP_IDENTITY;
        let neg = ct_sys::cutensorOperator_t::CUTENSOR_OP_NEG;

        let key32 = key_for(f32_tag, ComputeDesc::MinF32, identity, 16);
        let key32_neg = key_for(f32_tag, ComputeDesc::MinF32, neg, 16);
        // Same shapes, different op_a → distinct.
        assert_ne!(key32, key32_neg);
        assert_eq!(key32.op_kind, OpKind::Permutation);
        assert_eq!(key32.dtype_tag, "f32");

        let key64 = key_for(f64_tag, ComputeDesc::MinF64, identity, 16);
        assert_ne!(key32, key64);
        assert_eq!(key64.dtype_tag, "f64");
    }

    #[test]
    fn permutation_key_is_deterministic_and_alignment_sensitive() {
        let tag = <f32 as atomr_accel::AccelDtype>::NAME;
        let identity = ct_sys::cutensorOperator_t::CUTENSOR_OP_IDENTITY;

        // Identical requests must map to the same key, or the cache never hits.
        assert_eq!(
            key_for(tag, ComputeDesc::MinF32, identity, 16),
            key_for(tag, ComputeDesc::MinF32, identity, 16),
        );
        // Alignment is stored verbatim in the key, so it must split plans.
        assert_ne!(
            key_for(tag, ComputeDesc::MinF32, identity, 16),
            key_for(tag, ComputeDesc::MinF32, identity, 256),
        );
    }
}
388}