Skip to main content

atomr_accel_cuda/kernel/tensor/
elementwise.rs

//! `ElementwiseBinaryRequest<T>` and `ElementwiseTrinaryRequest<T>`.
//!
//! Both wrap a cuTENSOR create/execute pair —
//! `cutensorCreateElementwise{Binary,Trinary}` followed by
//! `cutensorElementwise{Binary,Trinary}Execute` — through our local
//! `crate::sys::cutensor` wrappers, because cudarc's safe layer doesn't
//! expose these entry points.

8use std::sync::Arc;
9
10use cudarc::cutensor::result as ct_result;
11use cudarc::cutensor::sys as ct_sys;
12use cudarc::driver::{DevicePtr, DevicePtrMut};
13use parking_lot::Mutex;
14use tokio::sync::oneshot;
15
16use crate::dtype::TensorSupported;
17use crate::error::GpuError;
18use crate::kernel::dispatch::{TensorDispatch, TensorDispatchCtx};
19use crate::kernel::envelope;
20use crate::kernel::tensor::compute_desc::{compute_desc_tag, resolve_compute_desc, ComputeDesc};
21use crate::kernel::tensor::contract::OperandSpec;
22use crate::kernel::tensor::plan_cache::{
23    hash_i32s, hash_i64s, CachedPlan, OpKind, PlanCache, PlanKey,
24};
25use crate::kernel::tensor::SendHandle;
26use crate::sys::cutensor as ct_local;
27
/// Library tag handed to `envelope::run_kernel` and stored in the cuTENSOR
/// error values this module produces.
const LIB: &str = "cutensor";

30// ---------------------------------------------------------------------
31// Binary
32// ---------------------------------------------------------------------
33
34/// D = op_AC( alpha * op_A(A), gamma * op_C(C) ).
35pub struct ElementwiseBinaryRequest<T: TensorSupported> {
36    pub a: OperandSpec<T>,
37    pub c: OperandSpec<T>,
38    pub d: OperandSpec<T>,
39    pub alpha: T,
40    pub gamma: T,
41    pub op_a: ct_sys::cutensorOperator_t,
42    pub op_c: ct_sys::cutensorOperator_t,
43    pub op_ac: ct_sys::cutensorOperator_t,
44    pub compute: ComputeDesc,
45    pub alignment: u32,
46    pub reply: oneshot::Sender<Result<(), GpuError>>,
47}
48
49impl<T: TensorSupported> ElementwiseBinaryRequest<T> {
50    pub fn new(
51        a: OperandSpec<T>,
52        c: OperandSpec<T>,
53        d: OperandSpec<T>,
54        alpha: T,
55        gamma: T,
56        op_ac: ct_sys::cutensorOperator_t,
57        reply: oneshot::Sender<Result<(), GpuError>>,
58    ) -> Self {
59        Self {
60            a,
61            c,
62            d,
63            alpha,
64            gamma,
65            op_a: ct_sys::cutensorOperator_t::CUTENSOR_OP_IDENTITY,
66            op_c: ct_sys::cutensorOperator_t::CUTENSOR_OP_IDENTITY,
67            op_ac,
68            compute: super::contract::default_compute_for::<T>(),
69            alignment: 16,
70            reply,
71        }
72    }
73}
74
75impl<T: TensorSupported> TensorDispatch for ElementwiseBinaryRequest<T> {
76    fn op_tag(&self) -> &'static str {
77        "ewbin"
78    }
79    fn dtype_tag(&self) -> &'static str {
80        <T as atomr_accel::AccelDtype>::NAME
81    }
82    fn dispatch(self: Box<Self>, ctx: &TensorDispatchCtx) {
83        execute_binary(*self, ctx);
84    }
85    fn fail_mock(self: Box<Self>) {
86        let _ = self.reply.send(Err(GpuError::Unrecoverable(
87            "TensorActor in mock mode".into(),
88        )));
89    }
90}
91
92fn execute_binary<T: TensorSupported>(req: ElementwiseBinaryRequest<T>, ctx: &TensorDispatchCtx) {
93    let ElementwiseBinaryRequest {
94        a,
95        c,
96        d,
97        alpha,
98        gamma,
99        op_a,
100        op_c,
101        op_ac,
102        compute,
103        alignment,
104        reply,
105    } = req;
106
107    if a.extent.len() != a.modes.len()
108        || c.extent.len() != c.modes.len()
109        || d.extent.len() != d.modes.len()
110    {
111        let _ = reply.send(Err(GpuError::Unrecoverable(
112            "ElementwiseBinary: extent/modes length mismatch".into(),
113        )));
114        return;
115    }
116
117    let a_slice = match a.buf.access() {
118        Ok(s) => s.clone(),
119        Err(e) => {
120            let _ = reply.send(Err(e));
121            return;
122        }
123    };
124    let c_slice = match c.buf.access() {
125        Ok(s) => s.clone(),
126        Err(e) => {
127            let _ = reply.send(Err(e));
128            return;
129        }
130    };
131    let d_slice = match d.buf.access() {
132        Ok(s) => s.clone(),
133        Err(e) => {
134            let _ = reply.send(Err(e));
135            return;
136        }
137    };
138    let mut d_owned = match Arc::try_unwrap(d_slice) {
139        Ok(s) => s,
140        Err(_) => {
141            let _ = reply.send(Err(GpuError::Unrecoverable(
142                "ElementwiseBinary d has multiple live references".into(),
143            )));
144            return;
145        }
146    };
147
148    let key = build_binary_key_raw(
149        <T as atomr_accel::AccelDtype>::NAME,
150        &a.modes,
151        &c.modes,
152        &d.modes,
153        &a.extent,
154        &c.extent,
155        &d.extent,
156        alignment,
157        compute,
158        op_a,
159        op_c,
160        op_ac,
161    );
162    let cached = match get_or_build_binary::<T>(
163        &ctx.handle,
164        &ctx.plan_cache,
165        &key,
166        &a,
167        &c,
168        &d,
169        alignment,
170        compute,
171        op_a,
172        op_c,
173        op_ac,
174    ) {
175        Ok(p) => p,
176        Err(e) => {
177            let _ = reply.send(Err(e));
178            return;
179        }
180    };
181
182    d.buf.record_write(&ctx.stream);
183
184    let stream_for_check = ctx.stream.clone();
185    let handle_clone = ctx.handle.clone();
186    let plan_keepalive = cached.clone();
187
188    envelope::run_kernel(LIB, &ctx.stream, &ctx.completion, (), reply, move || {
189        let h = handle_clone.lock();
190        let (a_ptr, _ga) = a_slice.device_ptr(&stream_for_check);
191        let (c_ptr, _gc) = c_slice.device_ptr(&stream_for_check);
192        let (d_ptr, _gd) = d_owned.device_ptr_mut(&stream_for_check);
193        let alpha_h = alpha;
194        let gamma_h = gamma;
195        let res = unsafe {
196            ct_local::elementwise_binary_execute(
197                h.0,
198                plan_keepalive.plan,
199                &alpha_h as *const T as *const _,
200                a_ptr as *const _,
201                &gamma_h as *const T as *const _,
202                c_ptr as *const _,
203                d_ptr as *mut _,
204                stream_for_check.cu_stream() as *mut _,
205            )
206        };
207        drop((_ga, _gc, _gd));
208        match res {
209            Ok(()) => Ok((d_owned, a_slice, c_slice, plan_keepalive)),
210            Err(e) => Err(GpuError::LibraryError {
211                lib: LIB,
212                msg: format!("ElementwiseBinary: {e}"),
213            }),
214        }
215    });
216}
217
218#[allow(clippy::too_many_arguments)]
219pub fn build_binary_key_raw(
220    dtype_tag: &'static str,
221    modes_a: &[i32],
222    modes_c: &[i32],
223    modes_d: &[i32],
224    extent_a: &[i64],
225    extent_c: &[i64],
226    extent_d: &[i64],
227    alignment: u32,
228    compute: ComputeDesc,
229    op_a: ct_sys::cutensorOperator_t,
230    op_c: ct_sys::cutensorOperator_t,
231    op_ac: ct_sys::cutensorOperator_t,
232) -> PlanKey {
233    let mut modes = Vec::with_capacity(modes_a.len() + modes_c.len() + modes_d.len() + 3);
234    modes.extend_from_slice(modes_a);
235    modes.push(i32::MIN);
236    modes.extend_from_slice(modes_c);
237    modes.push(i32::MIN);
238    modes.extend_from_slice(modes_d);
239    let mut extents = Vec::with_capacity(extent_a.len() + extent_c.len() + extent_d.len() + 3);
240    extents.extend_from_slice(extent_a);
241    extents.push(i64::MIN);
242    extents.extend_from_slice(extent_c);
243    extents.push(i64::MIN);
244    extents.extend_from_slice(extent_d);
245    let op_mix = ((op_a as u32).wrapping_mul(0x85eb_ca77))
246        ^ ((op_c as u32).wrapping_mul(0xc2b2_ae3d))
247        ^ ((op_ac as u32).wrapping_mul(0x27d4_eb2f));
248    PlanKey {
249        op_kind: OpKind::ElementwiseBinary,
250        modes_hash: hash_i32s(&modes),
251        extents_hash: hash_i64s(&extents),
252        alignment,
253        compute_desc_tag: compute_desc_tag(compute) ^ op_mix,
254        dtype_tag,
255        algo: 0,
256    }
257}
258
259#[allow(clippy::too_many_arguments)]
260fn get_or_build_binary<T: TensorSupported>(
261    handle: &Arc<Mutex<SendHandle>>,
262    plan_cache: &Arc<PlanCache>,
263    key: &PlanKey,
264    a: &OperandSpec<T>,
265    c: &OperandSpec<T>,
266    d: &OperandSpec<T>,
267    alignment: u32,
268    compute: ComputeDesc,
269    op_a: ct_sys::cutensorOperator_t,
270    op_c: ct_sys::cutensorOperator_t,
271    op_ac: ct_sys::cutensorOperator_t,
272) -> Result<Arc<CachedPlan>, GpuError> {
273    if let Some(p) = plan_cache.get(key) {
274        return Ok(p);
275    }
276    let plan = build_binary_plan::<T>(handle, a, c, d, alignment, compute, op_a, op_c, op_ac)?;
277    let arc = Arc::new(plan);
278    plan_cache.put(*key, arc.clone());
279    Ok(arc)
280}
281
282#[allow(clippy::too_many_arguments)]
283fn build_binary_plan<T: TensorSupported>(
284    handle: &Arc<Mutex<SendHandle>>,
285    a: &OperandSpec<T>,
286    c: &OperandSpec<T>,
287    d: &OperandSpec<T>,
288    alignment: u32,
289    compute: ComputeDesc,
290    op_a: ct_sys::cutensorOperator_t,
291    op_c: ct_sys::cutensorOperator_t,
292    op_ac: ct_sys::cutensorOperator_t,
293) -> Result<CachedPlan, GpuError> {
294    let h = handle.lock();
295    let dt: cudarc::cutensor::sys::cudaDataType_t =
296        unsafe { std::mem::transmute(T::cuda_data_type() as u32) };
297    let cd = resolve_compute_desc(compute);
298    let stride_ptr = |v: &Vec<i64>| {
299        if v.is_empty() {
300            std::ptr::null()
301        } else {
302            v.as_ptr()
303        }
304    };
305
306    let desc_a = unsafe {
307        ct_result::create_tensor_descriptor(
308            h.0,
309            a.extent.len() as u32,
310            a.extent.as_ptr(),
311            stride_ptr(&a.stride),
312            dt,
313            alignment,
314        )
315    }
316    .map_err(|e| GpuError::lib(LIB, format!("CreateTensorDescriptor(A): {e}")))?;
317    let desc_c = unsafe {
318        ct_result::create_tensor_descriptor(
319            h.0,
320            c.extent.len() as u32,
321            c.extent.as_ptr(),
322            stride_ptr(&c.stride),
323            dt,
324            alignment,
325        )
326    }
327    .map_err(|e| {
328        unsafe {
329            let _ = ct_result::destroy_tensor_descriptor(desc_a);
330        }
331        GpuError::lib(LIB, format!("CreateTensorDescriptor(C): {e}"))
332    })?;
333    let desc_d = unsafe {
334        ct_result::create_tensor_descriptor(
335            h.0,
336            d.extent.len() as u32,
337            d.extent.as_ptr(),
338            stride_ptr(&d.stride),
339            dt,
340            alignment,
341        )
342    }
343    .map_err(|e| {
344        unsafe {
345            let _ = ct_result::destroy_tensor_descriptor(desc_c);
346            let _ = ct_result::destroy_tensor_descriptor(desc_a);
347        }
348        GpuError::lib(LIB, format!("CreateTensorDescriptor(D): {e}"))
349    })?;
350
351    let op = unsafe {
352        ct_local::create_elementwise_binary(
353            h.0,
354            desc_a,
355            a.modes.as_ptr(),
356            op_a,
357            desc_c,
358            c.modes.as_ptr(),
359            op_c,
360            desc_d,
361            d.modes.as_ptr(),
362            op_ac,
363            cd,
364        )
365    }
366    .map_err(|e| {
367        unsafe {
368            let _ = ct_result::destroy_tensor_descriptor(desc_d);
369            let _ = ct_result::destroy_tensor_descriptor(desc_c);
370            let _ = ct_result::destroy_tensor_descriptor(desc_a);
371        }
372        GpuError::lib(LIB, format!("CreateElementwiseBinary: {e}"))
373    })?;
374
375    let pref = unsafe {
376        ct_result::create_plan_preference(
377            h.0,
378            ct_sys::cutensorAlgo_t::CUTENSOR_ALGO_DEFAULT,
379            ct_sys::cutensorJitMode_t::CUTENSOR_JIT_MODE_NONE,
380        )
381    }
382    .map_err(|e| {
383        unsafe {
384            let _ = ct_result::destroy_operation_descriptor(op);
385            let _ = ct_result::destroy_tensor_descriptor(desc_d);
386            let _ = ct_result::destroy_tensor_descriptor(desc_c);
387            let _ = ct_result::destroy_tensor_descriptor(desc_a);
388        }
389        GpuError::lib(LIB, format!("CreatePlanPreference: {e}"))
390    })?;
391
392    let ws_size = unsafe {
393        ct_result::estimate_workspace_size(
394            h.0,
395            op,
396            pref,
397            ct_sys::cutensorWorksizePreference_t::CUTENSOR_WORKSPACE_DEFAULT,
398        )
399    }
400    .map_err(|e| {
401        unsafe {
402            let _ = ct_result::destroy_plan_preference(pref);
403            let _ = ct_result::destroy_operation_descriptor(op);
404            let _ = ct_result::destroy_tensor_descriptor(desc_d);
405            let _ = ct_result::destroy_tensor_descriptor(desc_c);
406            let _ = ct_result::destroy_tensor_descriptor(desc_a);
407        }
408        GpuError::lib(LIB, format!("EstimateWorkspaceSize: {e}"))
409    })?;
410
411    let plan = unsafe { ct_result::create_plan(h.0, op, pref, ws_size) }.map_err(|e| {
412        unsafe {
413            let _ = ct_result::destroy_plan_preference(pref);
414            let _ = ct_result::destroy_operation_descriptor(op);
415            let _ = ct_result::destroy_tensor_descriptor(desc_d);
416            let _ = ct_result::destroy_tensor_descriptor(desc_c);
417            let _ = ct_result::destroy_tensor_descriptor(desc_a);
418        }
419        GpuError::lib(LIB, format!("CreatePlan: {e}"))
420    })?;
421
422    Ok(CachedPlan {
423        plan,
424        pref,
425        op,
426        descs: vec![desc_a, desc_c, desc_d],
427        workspace_size: ws_size,
428    })
429}
430
431// ---------------------------------------------------------------------
432// Trinary
433// ---------------------------------------------------------------------
434
435/// D = op_ABC( op_AB(alpha * op_A(A), beta * op_B(B)), gamma * op_C(C) ).
436pub struct ElementwiseTrinaryRequest<T: TensorSupported> {
437    pub a: OperandSpec<T>,
438    pub b: OperandSpec<T>,
439    pub c: OperandSpec<T>,
440    pub d: OperandSpec<T>,
441    pub alpha: T,
442    pub beta: T,
443    pub gamma: T,
444    pub op_a: ct_sys::cutensorOperator_t,
445    pub op_b: ct_sys::cutensorOperator_t,
446    pub op_c: ct_sys::cutensorOperator_t,
447    pub op_ab: ct_sys::cutensorOperator_t,
448    pub op_abc: ct_sys::cutensorOperator_t,
449    pub compute: ComputeDesc,
450    pub alignment: u32,
451    pub reply: oneshot::Sender<Result<(), GpuError>>,
452}
453
454impl<T: TensorSupported> ElementwiseTrinaryRequest<T> {
455    #[allow(clippy::too_many_arguments)]
456    pub fn new(
457        a: OperandSpec<T>,
458        b: OperandSpec<T>,
459        c: OperandSpec<T>,
460        d: OperandSpec<T>,
461        alpha: T,
462        beta: T,
463        gamma: T,
464        op_ab: ct_sys::cutensorOperator_t,
465        op_abc: ct_sys::cutensorOperator_t,
466        reply: oneshot::Sender<Result<(), GpuError>>,
467    ) -> Self {
468        let id = ct_sys::cutensorOperator_t::CUTENSOR_OP_IDENTITY;
469        Self {
470            a,
471            b,
472            c,
473            d,
474            alpha,
475            beta,
476            gamma,
477            op_a: id,
478            op_b: id,
479            op_c: id,
480            op_ab,
481            op_abc,
482            compute: super::contract::default_compute_for::<T>(),
483            alignment: 16,
484            reply,
485        }
486    }
487}
488
489impl<T: TensorSupported> TensorDispatch for ElementwiseTrinaryRequest<T> {
490    fn op_tag(&self) -> &'static str {
491        "ewtri"
492    }
493    fn dtype_tag(&self) -> &'static str {
494        <T as atomr_accel::AccelDtype>::NAME
495    }
496    fn dispatch(self: Box<Self>, ctx: &TensorDispatchCtx) {
497        execute_trinary(*self, ctx);
498    }
499    fn fail_mock(self: Box<Self>) {
500        let _ = self.reply.send(Err(GpuError::Unrecoverable(
501            "TensorActor in mock mode".into(),
502        )));
503    }
504}
505
506fn execute_trinary<T: TensorSupported>(req: ElementwiseTrinaryRequest<T>, ctx: &TensorDispatchCtx) {
507    let ElementwiseTrinaryRequest {
508        a,
509        b,
510        c,
511        d,
512        alpha,
513        beta,
514        gamma,
515        op_a,
516        op_b,
517        op_c,
518        op_ab,
519        op_abc,
520        compute,
521        alignment,
522        reply,
523    } = req;
524
525    if a.extent.len() != a.modes.len()
526        || b.extent.len() != b.modes.len()
527        || c.extent.len() != c.modes.len()
528        || d.extent.len() != d.modes.len()
529    {
530        let _ = reply.send(Err(GpuError::Unrecoverable(
531            "ElementwiseTrinary: extent/modes length mismatch".into(),
532        )));
533        return;
534    }
535
536    let a_slice = match a.buf.access() {
537        Ok(s) => s.clone(),
538        Err(e) => {
539            let _ = reply.send(Err(e));
540            return;
541        }
542    };
543    let b_slice = match b.buf.access() {
544        Ok(s) => s.clone(),
545        Err(e) => {
546            let _ = reply.send(Err(e));
547            return;
548        }
549    };
550    let c_slice = match c.buf.access() {
551        Ok(s) => s.clone(),
552        Err(e) => {
553            let _ = reply.send(Err(e));
554            return;
555        }
556    };
557    let d_slice = match d.buf.access() {
558        Ok(s) => s.clone(),
559        Err(e) => {
560            let _ = reply.send(Err(e));
561            return;
562        }
563    };
564    let mut d_owned = match Arc::try_unwrap(d_slice) {
565        Ok(s) => s,
566        Err(_) => {
567            let _ = reply.send(Err(GpuError::Unrecoverable(
568                "ElementwiseTrinary d has multiple live references".into(),
569            )));
570            return;
571        }
572    };
573
574    let key = build_trinary_key_raw(
575        <T as atomr_accel::AccelDtype>::NAME,
576        &a.modes,
577        &b.modes,
578        &c.modes,
579        &d.modes,
580        &a.extent,
581        &b.extent,
582        &c.extent,
583        &d.extent,
584        alignment,
585        compute,
586        op_a,
587        op_b,
588        op_c,
589        op_ab,
590        op_abc,
591    );
592    let cached = match get_or_build_trinary::<T>(
593        &ctx.handle,
594        &ctx.plan_cache,
595        &key,
596        &a,
597        &b,
598        &c,
599        &d,
600        alignment,
601        compute,
602        op_a,
603        op_b,
604        op_c,
605        op_ab,
606        op_abc,
607    ) {
608        Ok(p) => p,
609        Err(e) => {
610            let _ = reply.send(Err(e));
611            return;
612        }
613    };
614
615    d.buf.record_write(&ctx.stream);
616
617    let stream_for_check = ctx.stream.clone();
618    let handle_clone = ctx.handle.clone();
619    let plan_keepalive = cached.clone();
620
621    envelope::run_kernel(LIB, &ctx.stream, &ctx.completion, (), reply, move || {
622        let h = handle_clone.lock();
623        let (a_ptr, _ga) = a_slice.device_ptr(&stream_for_check);
624        let (b_ptr, _gb) = b_slice.device_ptr(&stream_for_check);
625        let (c_ptr, _gc) = c_slice.device_ptr(&stream_for_check);
626        let (d_ptr, _gd) = d_owned.device_ptr_mut(&stream_for_check);
627        let alpha_h = alpha;
628        let beta_h = beta;
629        let gamma_h = gamma;
630        let res = unsafe {
631            ct_local::elementwise_trinary_execute(
632                h.0,
633                plan_keepalive.plan,
634                &alpha_h as *const T as *const _,
635                a_ptr as *const _,
636                &beta_h as *const T as *const _,
637                b_ptr as *const _,
638                &gamma_h as *const T as *const _,
639                c_ptr as *const _,
640                d_ptr as *mut _,
641                stream_for_check.cu_stream() as *mut _,
642            )
643        };
644        drop((_ga, _gb, _gc, _gd));
645        match res {
646            Ok(()) => Ok((d_owned, a_slice, b_slice, c_slice, plan_keepalive)),
647            Err(e) => Err(GpuError::LibraryError {
648                lib: LIB,
649                msg: format!("ElementwiseTrinary: {e}"),
650            }),
651        }
652    });
653}
654
655#[allow(clippy::too_many_arguments)]
656pub fn build_trinary_key_raw(
657    dtype_tag: &'static str,
658    modes_a: &[i32],
659    modes_b: &[i32],
660    modes_c: &[i32],
661    modes_d: &[i32],
662    extent_a: &[i64],
663    extent_b: &[i64],
664    extent_c: &[i64],
665    extent_d: &[i64],
666    alignment: u32,
667    compute: ComputeDesc,
668    op_a: ct_sys::cutensorOperator_t,
669    op_b: ct_sys::cutensorOperator_t,
670    op_c: ct_sys::cutensorOperator_t,
671    op_ab: ct_sys::cutensorOperator_t,
672    op_abc: ct_sys::cutensorOperator_t,
673) -> PlanKey {
674    let mut modes =
675        Vec::with_capacity(modes_a.len() + modes_b.len() + modes_c.len() + modes_d.len() + 4);
676    modes.extend_from_slice(modes_a);
677    modes.push(i32::MIN);
678    modes.extend_from_slice(modes_b);
679    modes.push(i32::MIN);
680    modes.extend_from_slice(modes_c);
681    modes.push(i32::MIN);
682    modes.extend_from_slice(modes_d);
683    let mut extents =
684        Vec::with_capacity(extent_a.len() + extent_b.len() + extent_c.len() + extent_d.len() + 4);
685    extents.extend_from_slice(extent_a);
686    extents.push(i64::MIN);
687    extents.extend_from_slice(extent_b);
688    extents.push(i64::MIN);
689    extents.extend_from_slice(extent_c);
690    extents.push(i64::MIN);
691    extents.extend_from_slice(extent_d);
692    let op_mix = ((op_a as u32).wrapping_mul(0x85eb_ca77))
693        ^ ((op_b as u32).wrapping_mul(0xc2b2_ae3d))
694        ^ ((op_c as u32).wrapping_mul(0x27d4_eb2f))
695        ^ ((op_ab as u32).wrapping_mul(0x9e37_79b9))
696        ^ ((op_abc as u32).wrapping_mul(0x6a09_e667));
697    PlanKey {
698        op_kind: OpKind::ElementwiseTrinary,
699        modes_hash: hash_i32s(&modes),
700        extents_hash: hash_i64s(&extents),
701        alignment,
702        compute_desc_tag: compute_desc_tag(compute) ^ op_mix,
703        dtype_tag,
704        algo: 0,
705    }
706}
707
708#[allow(clippy::too_many_arguments)]
709fn get_or_build_trinary<T: TensorSupported>(
710    handle: &Arc<Mutex<SendHandle>>,
711    plan_cache: &Arc<PlanCache>,
712    key: &PlanKey,
713    a: &OperandSpec<T>,
714    b: &OperandSpec<T>,
715    c: &OperandSpec<T>,
716    d: &OperandSpec<T>,
717    alignment: u32,
718    compute: ComputeDesc,
719    op_a: ct_sys::cutensorOperator_t,
720    op_b: ct_sys::cutensorOperator_t,
721    op_c: ct_sys::cutensorOperator_t,
722    op_ab: ct_sys::cutensorOperator_t,
723    op_abc: ct_sys::cutensorOperator_t,
724) -> Result<Arc<CachedPlan>, GpuError> {
725    if let Some(p) = plan_cache.get(key) {
726        return Ok(p);
727    }
728    let plan = build_trinary_plan::<T>(
729        handle, a, b, c, d, alignment, compute, op_a, op_b, op_c, op_ab, op_abc,
730    )?;
731    let arc = Arc::new(plan);
732    plan_cache.put(*key, arc.clone());
733    Ok(arc)
734}
735
736#[allow(clippy::too_many_arguments)]
737fn build_trinary_plan<T: TensorSupported>(
738    handle: &Arc<Mutex<SendHandle>>,
739    a: &OperandSpec<T>,
740    b: &OperandSpec<T>,
741    c: &OperandSpec<T>,
742    d: &OperandSpec<T>,
743    alignment: u32,
744    compute: ComputeDesc,
745    op_a: ct_sys::cutensorOperator_t,
746    op_b: ct_sys::cutensorOperator_t,
747    op_c: ct_sys::cutensorOperator_t,
748    op_ab: ct_sys::cutensorOperator_t,
749    op_abc: ct_sys::cutensorOperator_t,
750) -> Result<CachedPlan, GpuError> {
751    let h = handle.lock();
752    let dt: cudarc::cutensor::sys::cudaDataType_t =
753        unsafe { std::mem::transmute(T::cuda_data_type() as u32) };
754    let cd = resolve_compute_desc(compute);
755    let stride_ptr = |v: &Vec<i64>| {
756        if v.is_empty() {
757            std::ptr::null()
758        } else {
759            v.as_ptr()
760        }
761    };
762    let desc_a = unsafe {
763        ct_result::create_tensor_descriptor(
764            h.0,
765            a.extent.len() as u32,
766            a.extent.as_ptr(),
767            stride_ptr(&a.stride),
768            dt,
769            alignment,
770        )
771    }
772    .map_err(|e| GpuError::lib(LIB, format!("CreateTensorDescriptor(A): {e}")))?;
773    let desc_b = unsafe {
774        ct_result::create_tensor_descriptor(
775            h.0,
776            b.extent.len() as u32,
777            b.extent.as_ptr(),
778            stride_ptr(&b.stride),
779            dt,
780            alignment,
781        )
782    }
783    .map_err(|e| {
784        unsafe {
785            let _ = ct_result::destroy_tensor_descriptor(desc_a);
786        }
787        GpuError::lib(LIB, format!("CreateTensorDescriptor(B): {e}"))
788    })?;
789    let desc_c = unsafe {
790        ct_result::create_tensor_descriptor(
791            h.0,
792            c.extent.len() as u32,
793            c.extent.as_ptr(),
794            stride_ptr(&c.stride),
795            dt,
796            alignment,
797        )
798    }
799    .map_err(|e| {
800        unsafe {
801            let _ = ct_result::destroy_tensor_descriptor(desc_b);
802            let _ = ct_result::destroy_tensor_descriptor(desc_a);
803        }
804        GpuError::lib(LIB, format!("CreateTensorDescriptor(C): {e}"))
805    })?;
806    let desc_d = unsafe {
807        ct_result::create_tensor_descriptor(
808            h.0,
809            d.extent.len() as u32,
810            d.extent.as_ptr(),
811            stride_ptr(&d.stride),
812            dt,
813            alignment,
814        )
815    }
816    .map_err(|e| {
817        unsafe {
818            let _ = ct_result::destroy_tensor_descriptor(desc_c);
819            let _ = ct_result::destroy_tensor_descriptor(desc_b);
820            let _ = ct_result::destroy_tensor_descriptor(desc_a);
821        }
822        GpuError::lib(LIB, format!("CreateTensorDescriptor(D): {e}"))
823    })?;
824
825    let op = unsafe {
826        ct_local::create_elementwise_trinary(
827            h.0,
828            desc_a,
829            a.modes.as_ptr(),
830            op_a,
831            desc_b,
832            b.modes.as_ptr(),
833            op_b,
834            desc_c,
835            c.modes.as_ptr(),
836            op_c,
837            desc_d,
838            d.modes.as_ptr(),
839            op_ab,
840            op_abc,
841            cd,
842        )
843    }
844    .map_err(|e| {
845        unsafe {
846            let _ = ct_result::destroy_tensor_descriptor(desc_d);
847            let _ = ct_result::destroy_tensor_descriptor(desc_c);
848            let _ = ct_result::destroy_tensor_descriptor(desc_b);
849            let _ = ct_result::destroy_tensor_descriptor(desc_a);
850        }
851        GpuError::lib(LIB, format!("CreateElementwiseTrinary: {e}"))
852    })?;
853
854    let pref = unsafe {
855        ct_result::create_plan_preference(
856            h.0,
857            ct_sys::cutensorAlgo_t::CUTENSOR_ALGO_DEFAULT,
858            ct_sys::cutensorJitMode_t::CUTENSOR_JIT_MODE_NONE,
859        )
860    }
861    .map_err(|e| {
862        unsafe {
863            let _ = ct_result::destroy_operation_descriptor(op);
864            let _ = ct_result::destroy_tensor_descriptor(desc_d);
865            let _ = ct_result::destroy_tensor_descriptor(desc_c);
866            let _ = ct_result::destroy_tensor_descriptor(desc_b);
867            let _ = ct_result::destroy_tensor_descriptor(desc_a);
868        }
869        GpuError::lib(LIB, format!("CreatePlanPreference: {e}"))
870    })?;
871
872    let ws_size = unsafe {
873        ct_result::estimate_workspace_size(
874            h.0,
875            op,
876            pref,
877            ct_sys::cutensorWorksizePreference_t::CUTENSOR_WORKSPACE_DEFAULT,
878        )
879    }
880    .map_err(|e| {
881        unsafe {
882            let _ = ct_result::destroy_plan_preference(pref);
883            let _ = ct_result::destroy_operation_descriptor(op);
884            let _ = ct_result::destroy_tensor_descriptor(desc_d);
885            let _ = ct_result::destroy_tensor_descriptor(desc_c);
886            let _ = ct_result::destroy_tensor_descriptor(desc_b);
887            let _ = ct_result::destroy_tensor_descriptor(desc_a);
888        }
889        GpuError::lib(LIB, format!("EstimateWorkspaceSize: {e}"))
890    })?;
891
892    let plan = unsafe { ct_result::create_plan(h.0, op, pref, ws_size) }.map_err(|e| {
893        unsafe {
894            let _ = ct_result::destroy_plan_preference(pref);
895            let _ = ct_result::destroy_operation_descriptor(op);
896            let _ = ct_result::destroy_tensor_descriptor(desc_d);
897            let _ = ct_result::destroy_tensor_descriptor(desc_c);
898            let _ = ct_result::destroy_tensor_descriptor(desc_b);
899            let _ = ct_result::destroy_tensor_descriptor(desc_a);
900        }
901        GpuError::lib(LIB, format!("CreatePlan: {e}"))
902    })?;
903
904    Ok(CachedPlan {
905        plan,
906        pref,
907        op,
908        descs: vec![desc_a, desc_b, desc_c, desc_d],
909        workspace_size: ws_size,
910    })
911}
912
#[cfg(test)]
mod tests {
    use super::*;

    /// Plan keys must be deterministic, because the cache relies on equal
    /// inputs giving equal keys. They must also separate requests that differ
    /// only in operators, extents, or op kind.
    #[test]
    fn trinary_binary_request_round_trip() {
        let id = ct_sys::cutensorOperator_t::CUTENSOR_OP_IDENTITY;
        let add = ct_sys::cutensorOperator_t::CUTENSOR_OP_ADD;
        let mul = ct_sys::cutensorOperator_t::CUTENSOR_OP_MUL;
        let modes: &[i32] = &[1, 2];
        let extent: &[i64] = &[8, 16];
        let wide_extent: &[i64] = &[8, 32];

        // Binary
        let binary_key = |ext: &[i64], op_ac| {
            build_binary_key_raw(
                <f32 as atomr_accel::AccelDtype>::NAME,
                modes,
                modes,
                modes,
                ext,
                ext,
                ext,
                16,
                ComputeDesc::MinF32,
                id,
                id,
                op_ac,
            )
        };
        let key_b1 = binary_key(extent, add);
        let key_b2 = binary_key(extent, mul);
        // Identical inputs → identical key.
        assert_eq!(key_b1, binary_key(extent, add));
        // Same shapes, different op_ac → distinct.
        assert_ne!(key_b1, key_b2);
        // Same ops, different extents → distinct.
        assert_ne!(key_b1, binary_key(wide_extent, add));
        assert_eq!(key_b1.op_kind, OpKind::ElementwiseBinary);

        // Trinary
        let trinary_key = |op_ab, op_abc| {
            build_trinary_key_raw(
                <f64 as atomr_accel::AccelDtype>::NAME,
                modes,
                modes,
                modes,
                modes,
                extent,
                extent,
                extent,
                extent,
                16,
                ComputeDesc::MinF64,
                id,
                id,
                id,
                op_ab,
                op_abc,
            )
        };
        let key_t1 = trinary_key(add, mul);
        let key_t2 = trinary_key(mul, add);
        // Identical inputs → identical key.
        assert_eq!(key_t1, trinary_key(add, mul));
        // Swapping op_ab / op_abc → distinct.
        assert_ne!(key_t1, key_t2);
        assert_eq!(key_t1.op_kind, OpKind::ElementwiseTrinary);
        assert_eq!(key_t1.dtype_tag, "f64");

        // Cross-op different op-kinds never collide.
        assert_ne!(key_b1, key_t1);
    }
}