1use std::sync::Arc;
9
10use cudarc::cutensor::result as ct_result;
11use cudarc::cutensor::sys as ct_sys;
12use cudarc::driver::{DevicePtr, DevicePtrMut};
13use parking_lot::Mutex;
14use tokio::sync::oneshot;
15
16use crate::dtype::TensorSupported;
17use crate::error::GpuError;
18use crate::kernel::dispatch::{TensorDispatch, TensorDispatchCtx};
19use crate::kernel::envelope;
20use crate::kernel::tensor::compute_desc::{compute_desc_tag, resolve_compute_desc, ComputeDesc};
21use crate::kernel::tensor::contract::OperandSpec;
22use crate::kernel::tensor::plan_cache::{
23 hash_i32s, hash_i64s, CachedPlan, OpKind, PlanCache, PlanKey,
24};
25use crate::kernel::tensor::SendHandle;
26use crate::sys::cutensor as ct_local;
27
/// Library tag used throughout this file: it is passed to `envelope::run_kernel`
/// and attached to every `GpuError` built from a cuTENSOR failure.
const LIB: &str = "cutensor";
29
/// Request for a cuTENSOR element-wise binary operation reading operands `a`
/// and `c` and writing operand `d`.
///
/// NOTE(review): cuTENSOR documents this operation as
/// `D = op_ac(alpha * op_a(A), gamma * op_c(C))`. That is assumed to apply here
/// because the fields are forwarded to `ct_local::create_elementwise_binary`
/// and `ct_local::elementwise_binary_execute` — confirm against those wrappers.
pub struct ElementwiseBinaryRequest<T: TensorSupported> {
    /// First input operand (buffer, extents, strides and mode labels).
    pub a: OperandSpec<T>,
    /// Second input operand.
    pub c: OperandSpec<T>,
    /// Output operand; its buffer is the one written by the launch.
    pub d: OperandSpec<T>,
    /// Host-side scalar passed to the execute call alongside `a`.
    pub alpha: T,
    /// Host-side scalar passed to the execute call alongside `c`.
    pub gamma: T,
    /// Operator supplied for `a` when the operation descriptor is created.
    pub op_a: ct_sys::cutensorOperator_t,
    /// Operator supplied for `c` when the operation descriptor is created.
    pub op_c: ct_sys::cutensorOperator_t,
    /// Operator supplied last to the operation descriptor (the `a`/`c` combiner).
    pub op_ac: ct_sys::cutensorOperator_t,
    /// Compute descriptor selector, resolved via `resolve_compute_desc`.
    pub compute: ComputeDesc,
    /// Alignment requirement forwarded to every tensor descriptor.
    pub alignment: u32,
    /// One-shot channel on which the outcome of the request is reported.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
}
48
49impl<T: TensorSupported> ElementwiseBinaryRequest<T> {
50 pub fn new(
51 a: OperandSpec<T>,
52 c: OperandSpec<T>,
53 d: OperandSpec<T>,
54 alpha: T,
55 gamma: T,
56 op_ac: ct_sys::cutensorOperator_t,
57 reply: oneshot::Sender<Result<(), GpuError>>,
58 ) -> Self {
59 Self {
60 a,
61 c,
62 d,
63 alpha,
64 gamma,
65 op_a: ct_sys::cutensorOperator_t::CUTENSOR_OP_IDENTITY,
66 op_c: ct_sys::cutensorOperator_t::CUTENSOR_OP_IDENTITY,
67 op_ac,
68 compute: super::contract::default_compute_for::<T>(),
69 alignment: 16,
70 reply,
71 }
72 }
73}
74
75impl<T: TensorSupported> TensorDispatch for ElementwiseBinaryRequest<T> {
76 fn op_tag(&self) -> &'static str {
77 "ewbin"
78 }
79 fn dtype_tag(&self) -> &'static str {
80 <T as atomr_accel::AccelDtype>::NAME
81 }
82 fn dispatch(self: Box<Self>, ctx: &TensorDispatchCtx) {
83 execute_binary(*self, ctx);
84 }
85 fn fail_mock(self: Box<Self>) {
86 let _ = self.reply.send(Err(GpuError::Unrecoverable(
87 "TensorActor in mock mode".into(),
88 )));
89 }
90}
91
92fn execute_binary<T: TensorSupported>(req: ElementwiseBinaryRequest<T>, ctx: &TensorDispatchCtx) {
93 let ElementwiseBinaryRequest {
94 a,
95 c,
96 d,
97 alpha,
98 gamma,
99 op_a,
100 op_c,
101 op_ac,
102 compute,
103 alignment,
104 reply,
105 } = req;
106
107 if a.extent.len() != a.modes.len()
108 || c.extent.len() != c.modes.len()
109 || d.extent.len() != d.modes.len()
110 {
111 let _ = reply.send(Err(GpuError::Unrecoverable(
112 "ElementwiseBinary: extent/modes length mismatch".into(),
113 )));
114 return;
115 }
116
117 let a_slice = match a.buf.access() {
118 Ok(s) => s.clone(),
119 Err(e) => {
120 let _ = reply.send(Err(e));
121 return;
122 }
123 };
124 let c_slice = match c.buf.access() {
125 Ok(s) => s.clone(),
126 Err(e) => {
127 let _ = reply.send(Err(e));
128 return;
129 }
130 };
131 let d_slice = match d.buf.access() {
132 Ok(s) => s.clone(),
133 Err(e) => {
134 let _ = reply.send(Err(e));
135 return;
136 }
137 };
138 let mut d_owned = match Arc::try_unwrap(d_slice) {
139 Ok(s) => s,
140 Err(_) => {
141 let _ = reply.send(Err(GpuError::Unrecoverable(
142 "ElementwiseBinary d has multiple live references".into(),
143 )));
144 return;
145 }
146 };
147
148 let key = build_binary_key_raw(
149 <T as atomr_accel::AccelDtype>::NAME,
150 &a.modes,
151 &c.modes,
152 &d.modes,
153 &a.extent,
154 &c.extent,
155 &d.extent,
156 alignment,
157 compute,
158 op_a,
159 op_c,
160 op_ac,
161 );
162 let cached = match get_or_build_binary::<T>(
163 &ctx.handle,
164 &ctx.plan_cache,
165 &key,
166 &a,
167 &c,
168 &d,
169 alignment,
170 compute,
171 op_a,
172 op_c,
173 op_ac,
174 ) {
175 Ok(p) => p,
176 Err(e) => {
177 let _ = reply.send(Err(e));
178 return;
179 }
180 };
181
182 d.buf.record_write(&ctx.stream);
183
184 let stream_for_check = ctx.stream.clone();
185 let handle_clone = ctx.handle.clone();
186 let plan_keepalive = cached.clone();
187
188 envelope::run_kernel(LIB, &ctx.stream, &ctx.completion, (), reply, move || {
189 let h = handle_clone.lock();
190 let (a_ptr, _ga) = a_slice.device_ptr(&stream_for_check);
191 let (c_ptr, _gc) = c_slice.device_ptr(&stream_for_check);
192 let (d_ptr, _gd) = d_owned.device_ptr_mut(&stream_for_check);
193 let alpha_h = alpha;
194 let gamma_h = gamma;
195 let res = unsafe {
196 ct_local::elementwise_binary_execute(
197 h.0,
198 plan_keepalive.plan,
199 &alpha_h as *const T as *const _,
200 a_ptr as *const _,
201 &gamma_h as *const T as *const _,
202 c_ptr as *const _,
203 d_ptr as *mut _,
204 stream_for_check.cu_stream() as *mut _,
205 )
206 };
207 drop((_ga, _gc, _gd));
208 match res {
209 Ok(()) => Ok((d_owned, a_slice, c_slice, plan_keepalive)),
210 Err(e) => Err(GpuError::LibraryError {
211 lib: LIB,
212 msg: format!("ElementwiseBinary: {e}"),
213 }),
214 }
215 });
216}
217
218#[allow(clippy::too_many_arguments)]
219pub fn build_binary_key_raw(
220 dtype_tag: &'static str,
221 modes_a: &[i32],
222 modes_c: &[i32],
223 modes_d: &[i32],
224 extent_a: &[i64],
225 extent_c: &[i64],
226 extent_d: &[i64],
227 alignment: u32,
228 compute: ComputeDesc,
229 op_a: ct_sys::cutensorOperator_t,
230 op_c: ct_sys::cutensorOperator_t,
231 op_ac: ct_sys::cutensorOperator_t,
232) -> PlanKey {
233 let mut modes = Vec::with_capacity(modes_a.len() + modes_c.len() + modes_d.len() + 3);
234 modes.extend_from_slice(modes_a);
235 modes.push(i32::MIN);
236 modes.extend_from_slice(modes_c);
237 modes.push(i32::MIN);
238 modes.extend_from_slice(modes_d);
239 let mut extents = Vec::with_capacity(extent_a.len() + extent_c.len() + extent_d.len() + 3);
240 extents.extend_from_slice(extent_a);
241 extents.push(i64::MIN);
242 extents.extend_from_slice(extent_c);
243 extents.push(i64::MIN);
244 extents.extend_from_slice(extent_d);
245 let op_mix = ((op_a as u32).wrapping_mul(0x85eb_ca77))
246 ^ ((op_c as u32).wrapping_mul(0xc2b2_ae3d))
247 ^ ((op_ac as u32).wrapping_mul(0x27d4_eb2f));
248 PlanKey {
249 op_kind: OpKind::ElementwiseBinary,
250 modes_hash: hash_i32s(&modes),
251 extents_hash: hash_i64s(&extents),
252 alignment,
253 compute_desc_tag: compute_desc_tag(compute) ^ op_mix,
254 dtype_tag,
255 algo: 0,
256 }
257}
258
259#[allow(clippy::too_many_arguments)]
260fn get_or_build_binary<T: TensorSupported>(
261 handle: &Arc<Mutex<SendHandle>>,
262 plan_cache: &Arc<PlanCache>,
263 key: &PlanKey,
264 a: &OperandSpec<T>,
265 c: &OperandSpec<T>,
266 d: &OperandSpec<T>,
267 alignment: u32,
268 compute: ComputeDesc,
269 op_a: ct_sys::cutensorOperator_t,
270 op_c: ct_sys::cutensorOperator_t,
271 op_ac: ct_sys::cutensorOperator_t,
272) -> Result<Arc<CachedPlan>, GpuError> {
273 if let Some(p) = plan_cache.get(key) {
274 return Ok(p);
275 }
276 let plan = build_binary_plan::<T>(handle, a, c, d, alignment, compute, op_a, op_c, op_ac)?;
277 let arc = Arc::new(plan);
278 plan_cache.put(*key, arc.clone());
279 Ok(arc)
280}
281
282#[allow(clippy::too_many_arguments)]
283fn build_binary_plan<T: TensorSupported>(
284 handle: &Arc<Mutex<SendHandle>>,
285 a: &OperandSpec<T>,
286 c: &OperandSpec<T>,
287 d: &OperandSpec<T>,
288 alignment: u32,
289 compute: ComputeDesc,
290 op_a: ct_sys::cutensorOperator_t,
291 op_c: ct_sys::cutensorOperator_t,
292 op_ac: ct_sys::cutensorOperator_t,
293) -> Result<CachedPlan, GpuError> {
294 let h = handle.lock();
295 let dt: cudarc::cutensor::sys::cudaDataType_t =
296 unsafe { std::mem::transmute(T::cuda_data_type() as u32) };
297 let cd = resolve_compute_desc(compute);
298 let stride_ptr = |v: &Vec<i64>| {
299 if v.is_empty() {
300 std::ptr::null()
301 } else {
302 v.as_ptr()
303 }
304 };
305
306 let desc_a = unsafe {
307 ct_result::create_tensor_descriptor(
308 h.0,
309 a.extent.len() as u32,
310 a.extent.as_ptr(),
311 stride_ptr(&a.stride),
312 dt,
313 alignment,
314 )
315 }
316 .map_err(|e| GpuError::lib(LIB, format!("CreateTensorDescriptor(A): {e}")))?;
317 let desc_c = unsafe {
318 ct_result::create_tensor_descriptor(
319 h.0,
320 c.extent.len() as u32,
321 c.extent.as_ptr(),
322 stride_ptr(&c.stride),
323 dt,
324 alignment,
325 )
326 }
327 .map_err(|e| {
328 unsafe {
329 let _ = ct_result::destroy_tensor_descriptor(desc_a);
330 }
331 GpuError::lib(LIB, format!("CreateTensorDescriptor(C): {e}"))
332 })?;
333 let desc_d = unsafe {
334 ct_result::create_tensor_descriptor(
335 h.0,
336 d.extent.len() as u32,
337 d.extent.as_ptr(),
338 stride_ptr(&d.stride),
339 dt,
340 alignment,
341 )
342 }
343 .map_err(|e| {
344 unsafe {
345 let _ = ct_result::destroy_tensor_descriptor(desc_c);
346 let _ = ct_result::destroy_tensor_descriptor(desc_a);
347 }
348 GpuError::lib(LIB, format!("CreateTensorDescriptor(D): {e}"))
349 })?;
350
351 let op = unsafe {
352 ct_local::create_elementwise_binary(
353 h.0,
354 desc_a,
355 a.modes.as_ptr(),
356 op_a,
357 desc_c,
358 c.modes.as_ptr(),
359 op_c,
360 desc_d,
361 d.modes.as_ptr(),
362 op_ac,
363 cd,
364 )
365 }
366 .map_err(|e| {
367 unsafe {
368 let _ = ct_result::destroy_tensor_descriptor(desc_d);
369 let _ = ct_result::destroy_tensor_descriptor(desc_c);
370 let _ = ct_result::destroy_tensor_descriptor(desc_a);
371 }
372 GpuError::lib(LIB, format!("CreateElementwiseBinary: {e}"))
373 })?;
374
375 let pref = unsafe {
376 ct_result::create_plan_preference(
377 h.0,
378 ct_sys::cutensorAlgo_t::CUTENSOR_ALGO_DEFAULT,
379 ct_sys::cutensorJitMode_t::CUTENSOR_JIT_MODE_NONE,
380 )
381 }
382 .map_err(|e| {
383 unsafe {
384 let _ = ct_result::destroy_operation_descriptor(op);
385 let _ = ct_result::destroy_tensor_descriptor(desc_d);
386 let _ = ct_result::destroy_tensor_descriptor(desc_c);
387 let _ = ct_result::destroy_tensor_descriptor(desc_a);
388 }
389 GpuError::lib(LIB, format!("CreatePlanPreference: {e}"))
390 })?;
391
392 let ws_size = unsafe {
393 ct_result::estimate_workspace_size(
394 h.0,
395 op,
396 pref,
397 ct_sys::cutensorWorksizePreference_t::CUTENSOR_WORKSPACE_DEFAULT,
398 )
399 }
400 .map_err(|e| {
401 unsafe {
402 let _ = ct_result::destroy_plan_preference(pref);
403 let _ = ct_result::destroy_operation_descriptor(op);
404 let _ = ct_result::destroy_tensor_descriptor(desc_d);
405 let _ = ct_result::destroy_tensor_descriptor(desc_c);
406 let _ = ct_result::destroy_tensor_descriptor(desc_a);
407 }
408 GpuError::lib(LIB, format!("EstimateWorkspaceSize: {e}"))
409 })?;
410
411 let plan = unsafe { ct_result::create_plan(h.0, op, pref, ws_size) }.map_err(|e| {
412 unsafe {
413 let _ = ct_result::destroy_plan_preference(pref);
414 let _ = ct_result::destroy_operation_descriptor(op);
415 let _ = ct_result::destroy_tensor_descriptor(desc_d);
416 let _ = ct_result::destroy_tensor_descriptor(desc_c);
417 let _ = ct_result::destroy_tensor_descriptor(desc_a);
418 }
419 GpuError::lib(LIB, format!("CreatePlan: {e}"))
420 })?;
421
422 Ok(CachedPlan {
423 plan,
424 pref,
425 op,
426 descs: vec![desc_a, desc_c, desc_d],
427 workspace_size: ws_size,
428 })
429}
430
/// Request for a cuTENSOR element-wise trinary operation reading operands `a`,
/// `b` and `c` and writing operand `d`.
///
/// NOTE(review): cuTENSOR documents this operation as
/// `D = op_abc(op_ab(alpha * op_a(A), beta * op_b(B)), gamma * op_c(C))`. That
/// is assumed to apply here because the fields are forwarded to
/// `ct_local::create_elementwise_trinary` and
/// `ct_local::elementwise_trinary_execute` — confirm against those wrappers.
pub struct ElementwiseTrinaryRequest<T: TensorSupported> {
    /// First input operand (buffer, extents, strides and mode labels).
    pub a: OperandSpec<T>,
    /// Second input operand.
    pub b: OperandSpec<T>,
    /// Third input operand.
    pub c: OperandSpec<T>,
    /// Output operand; its buffer is the one written by the launch.
    pub d: OperandSpec<T>,
    /// Host-side scalar passed to the execute call alongside `a`.
    pub alpha: T,
    /// Host-side scalar passed to the execute call alongside `b`.
    pub beta: T,
    /// Host-side scalar passed to the execute call alongside `c`.
    pub gamma: T,
    /// Operator supplied for `a` when the operation descriptor is created.
    pub op_a: ct_sys::cutensorOperator_t,
    /// Operator supplied for `b` when the operation descriptor is created.
    pub op_b: ct_sys::cutensorOperator_t,
    /// Operator supplied for `c` when the operation descriptor is created.
    pub op_c: ct_sys::cutensorOperator_t,
    /// First combining operator passed to the operation descriptor.
    pub op_ab: ct_sys::cutensorOperator_t,
    /// Second combining operator passed to the operation descriptor.
    pub op_abc: ct_sys::cutensorOperator_t,
    /// Compute descriptor selector, resolved via `resolve_compute_desc`.
    pub compute: ComputeDesc,
    /// Alignment requirement forwarded to every tensor descriptor.
    pub alignment: u32,
    /// One-shot channel on which the outcome of the request is reported.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
}
453
454impl<T: TensorSupported> ElementwiseTrinaryRequest<T> {
455 #[allow(clippy::too_many_arguments)]
456 pub fn new(
457 a: OperandSpec<T>,
458 b: OperandSpec<T>,
459 c: OperandSpec<T>,
460 d: OperandSpec<T>,
461 alpha: T,
462 beta: T,
463 gamma: T,
464 op_ab: ct_sys::cutensorOperator_t,
465 op_abc: ct_sys::cutensorOperator_t,
466 reply: oneshot::Sender<Result<(), GpuError>>,
467 ) -> Self {
468 let id = ct_sys::cutensorOperator_t::CUTENSOR_OP_IDENTITY;
469 Self {
470 a,
471 b,
472 c,
473 d,
474 alpha,
475 beta,
476 gamma,
477 op_a: id,
478 op_b: id,
479 op_c: id,
480 op_ab,
481 op_abc,
482 compute: super::contract::default_compute_for::<T>(),
483 alignment: 16,
484 reply,
485 }
486 }
487}
488
489impl<T: TensorSupported> TensorDispatch for ElementwiseTrinaryRequest<T> {
490 fn op_tag(&self) -> &'static str {
491 "ewtri"
492 }
493 fn dtype_tag(&self) -> &'static str {
494 <T as atomr_accel::AccelDtype>::NAME
495 }
496 fn dispatch(self: Box<Self>, ctx: &TensorDispatchCtx) {
497 execute_trinary(*self, ctx);
498 }
499 fn fail_mock(self: Box<Self>) {
500 let _ = self.reply.send(Err(GpuError::Unrecoverable(
501 "TensorActor in mock mode".into(),
502 )));
503 }
504}
505
506fn execute_trinary<T: TensorSupported>(req: ElementwiseTrinaryRequest<T>, ctx: &TensorDispatchCtx) {
507 let ElementwiseTrinaryRequest {
508 a,
509 b,
510 c,
511 d,
512 alpha,
513 beta,
514 gamma,
515 op_a,
516 op_b,
517 op_c,
518 op_ab,
519 op_abc,
520 compute,
521 alignment,
522 reply,
523 } = req;
524
525 if a.extent.len() != a.modes.len()
526 || b.extent.len() != b.modes.len()
527 || c.extent.len() != c.modes.len()
528 || d.extent.len() != d.modes.len()
529 {
530 let _ = reply.send(Err(GpuError::Unrecoverable(
531 "ElementwiseTrinary: extent/modes length mismatch".into(),
532 )));
533 return;
534 }
535
536 let a_slice = match a.buf.access() {
537 Ok(s) => s.clone(),
538 Err(e) => {
539 let _ = reply.send(Err(e));
540 return;
541 }
542 };
543 let b_slice = match b.buf.access() {
544 Ok(s) => s.clone(),
545 Err(e) => {
546 let _ = reply.send(Err(e));
547 return;
548 }
549 };
550 let c_slice = match c.buf.access() {
551 Ok(s) => s.clone(),
552 Err(e) => {
553 let _ = reply.send(Err(e));
554 return;
555 }
556 };
557 let d_slice = match d.buf.access() {
558 Ok(s) => s.clone(),
559 Err(e) => {
560 let _ = reply.send(Err(e));
561 return;
562 }
563 };
564 let mut d_owned = match Arc::try_unwrap(d_slice) {
565 Ok(s) => s,
566 Err(_) => {
567 let _ = reply.send(Err(GpuError::Unrecoverable(
568 "ElementwiseTrinary d has multiple live references".into(),
569 )));
570 return;
571 }
572 };
573
574 let key = build_trinary_key_raw(
575 <T as atomr_accel::AccelDtype>::NAME,
576 &a.modes,
577 &b.modes,
578 &c.modes,
579 &d.modes,
580 &a.extent,
581 &b.extent,
582 &c.extent,
583 &d.extent,
584 alignment,
585 compute,
586 op_a,
587 op_b,
588 op_c,
589 op_ab,
590 op_abc,
591 );
592 let cached = match get_or_build_trinary::<T>(
593 &ctx.handle,
594 &ctx.plan_cache,
595 &key,
596 &a,
597 &b,
598 &c,
599 &d,
600 alignment,
601 compute,
602 op_a,
603 op_b,
604 op_c,
605 op_ab,
606 op_abc,
607 ) {
608 Ok(p) => p,
609 Err(e) => {
610 let _ = reply.send(Err(e));
611 return;
612 }
613 };
614
615 d.buf.record_write(&ctx.stream);
616
617 let stream_for_check = ctx.stream.clone();
618 let handle_clone = ctx.handle.clone();
619 let plan_keepalive = cached.clone();
620
621 envelope::run_kernel(LIB, &ctx.stream, &ctx.completion, (), reply, move || {
622 let h = handle_clone.lock();
623 let (a_ptr, _ga) = a_slice.device_ptr(&stream_for_check);
624 let (b_ptr, _gb) = b_slice.device_ptr(&stream_for_check);
625 let (c_ptr, _gc) = c_slice.device_ptr(&stream_for_check);
626 let (d_ptr, _gd) = d_owned.device_ptr_mut(&stream_for_check);
627 let alpha_h = alpha;
628 let beta_h = beta;
629 let gamma_h = gamma;
630 let res = unsafe {
631 ct_local::elementwise_trinary_execute(
632 h.0,
633 plan_keepalive.plan,
634 &alpha_h as *const T as *const _,
635 a_ptr as *const _,
636 &beta_h as *const T as *const _,
637 b_ptr as *const _,
638 &gamma_h as *const T as *const _,
639 c_ptr as *const _,
640 d_ptr as *mut _,
641 stream_for_check.cu_stream() as *mut _,
642 )
643 };
644 drop((_ga, _gb, _gc, _gd));
645 match res {
646 Ok(()) => Ok((d_owned, a_slice, b_slice, c_slice, plan_keepalive)),
647 Err(e) => Err(GpuError::LibraryError {
648 lib: LIB,
649 msg: format!("ElementwiseTrinary: {e}"),
650 }),
651 }
652 });
653}
654
655#[allow(clippy::too_many_arguments)]
656pub fn build_trinary_key_raw(
657 dtype_tag: &'static str,
658 modes_a: &[i32],
659 modes_b: &[i32],
660 modes_c: &[i32],
661 modes_d: &[i32],
662 extent_a: &[i64],
663 extent_b: &[i64],
664 extent_c: &[i64],
665 extent_d: &[i64],
666 alignment: u32,
667 compute: ComputeDesc,
668 op_a: ct_sys::cutensorOperator_t,
669 op_b: ct_sys::cutensorOperator_t,
670 op_c: ct_sys::cutensorOperator_t,
671 op_ab: ct_sys::cutensorOperator_t,
672 op_abc: ct_sys::cutensorOperator_t,
673) -> PlanKey {
674 let mut modes =
675 Vec::with_capacity(modes_a.len() + modes_b.len() + modes_c.len() + modes_d.len() + 4);
676 modes.extend_from_slice(modes_a);
677 modes.push(i32::MIN);
678 modes.extend_from_slice(modes_b);
679 modes.push(i32::MIN);
680 modes.extend_from_slice(modes_c);
681 modes.push(i32::MIN);
682 modes.extend_from_slice(modes_d);
683 let mut extents =
684 Vec::with_capacity(extent_a.len() + extent_b.len() + extent_c.len() + extent_d.len() + 4);
685 extents.extend_from_slice(extent_a);
686 extents.push(i64::MIN);
687 extents.extend_from_slice(extent_b);
688 extents.push(i64::MIN);
689 extents.extend_from_slice(extent_c);
690 extents.push(i64::MIN);
691 extents.extend_from_slice(extent_d);
692 let op_mix = ((op_a as u32).wrapping_mul(0x85eb_ca77))
693 ^ ((op_b as u32).wrapping_mul(0xc2b2_ae3d))
694 ^ ((op_c as u32).wrapping_mul(0x27d4_eb2f))
695 ^ ((op_ab as u32).wrapping_mul(0x9e37_79b9))
696 ^ ((op_abc as u32).wrapping_mul(0x6a09_e667));
697 PlanKey {
698 op_kind: OpKind::ElementwiseTrinary,
699 modes_hash: hash_i32s(&modes),
700 extents_hash: hash_i64s(&extents),
701 alignment,
702 compute_desc_tag: compute_desc_tag(compute) ^ op_mix,
703 dtype_tag,
704 algo: 0,
705 }
706}
707
708#[allow(clippy::too_many_arguments)]
709fn get_or_build_trinary<T: TensorSupported>(
710 handle: &Arc<Mutex<SendHandle>>,
711 plan_cache: &Arc<PlanCache>,
712 key: &PlanKey,
713 a: &OperandSpec<T>,
714 b: &OperandSpec<T>,
715 c: &OperandSpec<T>,
716 d: &OperandSpec<T>,
717 alignment: u32,
718 compute: ComputeDesc,
719 op_a: ct_sys::cutensorOperator_t,
720 op_b: ct_sys::cutensorOperator_t,
721 op_c: ct_sys::cutensorOperator_t,
722 op_ab: ct_sys::cutensorOperator_t,
723 op_abc: ct_sys::cutensorOperator_t,
724) -> Result<Arc<CachedPlan>, GpuError> {
725 if let Some(p) = plan_cache.get(key) {
726 return Ok(p);
727 }
728 let plan = build_trinary_plan::<T>(
729 handle, a, b, c, d, alignment, compute, op_a, op_b, op_c, op_ab, op_abc,
730 )?;
731 let arc = Arc::new(plan);
732 plan_cache.put(*key, arc.clone());
733 Ok(arc)
734}
735
736#[allow(clippy::too_many_arguments)]
737fn build_trinary_plan<T: TensorSupported>(
738 handle: &Arc<Mutex<SendHandle>>,
739 a: &OperandSpec<T>,
740 b: &OperandSpec<T>,
741 c: &OperandSpec<T>,
742 d: &OperandSpec<T>,
743 alignment: u32,
744 compute: ComputeDesc,
745 op_a: ct_sys::cutensorOperator_t,
746 op_b: ct_sys::cutensorOperator_t,
747 op_c: ct_sys::cutensorOperator_t,
748 op_ab: ct_sys::cutensorOperator_t,
749 op_abc: ct_sys::cutensorOperator_t,
750) -> Result<CachedPlan, GpuError> {
751 let h = handle.lock();
752 let dt: cudarc::cutensor::sys::cudaDataType_t =
753 unsafe { std::mem::transmute(T::cuda_data_type() as u32) };
754 let cd = resolve_compute_desc(compute);
755 let stride_ptr = |v: &Vec<i64>| {
756 if v.is_empty() {
757 std::ptr::null()
758 } else {
759 v.as_ptr()
760 }
761 };
762 let desc_a = unsafe {
763 ct_result::create_tensor_descriptor(
764 h.0,
765 a.extent.len() as u32,
766 a.extent.as_ptr(),
767 stride_ptr(&a.stride),
768 dt,
769 alignment,
770 )
771 }
772 .map_err(|e| GpuError::lib(LIB, format!("CreateTensorDescriptor(A): {e}")))?;
773 let desc_b = unsafe {
774 ct_result::create_tensor_descriptor(
775 h.0,
776 b.extent.len() as u32,
777 b.extent.as_ptr(),
778 stride_ptr(&b.stride),
779 dt,
780 alignment,
781 )
782 }
783 .map_err(|e| {
784 unsafe {
785 let _ = ct_result::destroy_tensor_descriptor(desc_a);
786 }
787 GpuError::lib(LIB, format!("CreateTensorDescriptor(B): {e}"))
788 })?;
789 let desc_c = unsafe {
790 ct_result::create_tensor_descriptor(
791 h.0,
792 c.extent.len() as u32,
793 c.extent.as_ptr(),
794 stride_ptr(&c.stride),
795 dt,
796 alignment,
797 )
798 }
799 .map_err(|e| {
800 unsafe {
801 let _ = ct_result::destroy_tensor_descriptor(desc_b);
802 let _ = ct_result::destroy_tensor_descriptor(desc_a);
803 }
804 GpuError::lib(LIB, format!("CreateTensorDescriptor(C): {e}"))
805 })?;
806 let desc_d = unsafe {
807 ct_result::create_tensor_descriptor(
808 h.0,
809 d.extent.len() as u32,
810 d.extent.as_ptr(),
811 stride_ptr(&d.stride),
812 dt,
813 alignment,
814 )
815 }
816 .map_err(|e| {
817 unsafe {
818 let _ = ct_result::destroy_tensor_descriptor(desc_c);
819 let _ = ct_result::destroy_tensor_descriptor(desc_b);
820 let _ = ct_result::destroy_tensor_descriptor(desc_a);
821 }
822 GpuError::lib(LIB, format!("CreateTensorDescriptor(D): {e}"))
823 })?;
824
825 let op = unsafe {
826 ct_local::create_elementwise_trinary(
827 h.0,
828 desc_a,
829 a.modes.as_ptr(),
830 op_a,
831 desc_b,
832 b.modes.as_ptr(),
833 op_b,
834 desc_c,
835 c.modes.as_ptr(),
836 op_c,
837 desc_d,
838 d.modes.as_ptr(),
839 op_ab,
840 op_abc,
841 cd,
842 )
843 }
844 .map_err(|e| {
845 unsafe {
846 let _ = ct_result::destroy_tensor_descriptor(desc_d);
847 let _ = ct_result::destroy_tensor_descriptor(desc_c);
848 let _ = ct_result::destroy_tensor_descriptor(desc_b);
849 let _ = ct_result::destroy_tensor_descriptor(desc_a);
850 }
851 GpuError::lib(LIB, format!("CreateElementwiseTrinary: {e}"))
852 })?;
853
854 let pref = unsafe {
855 ct_result::create_plan_preference(
856 h.0,
857 ct_sys::cutensorAlgo_t::CUTENSOR_ALGO_DEFAULT,
858 ct_sys::cutensorJitMode_t::CUTENSOR_JIT_MODE_NONE,
859 )
860 }
861 .map_err(|e| {
862 unsafe {
863 let _ = ct_result::destroy_operation_descriptor(op);
864 let _ = ct_result::destroy_tensor_descriptor(desc_d);
865 let _ = ct_result::destroy_tensor_descriptor(desc_c);
866 let _ = ct_result::destroy_tensor_descriptor(desc_b);
867 let _ = ct_result::destroy_tensor_descriptor(desc_a);
868 }
869 GpuError::lib(LIB, format!("CreatePlanPreference: {e}"))
870 })?;
871
872 let ws_size = unsafe {
873 ct_result::estimate_workspace_size(
874 h.0,
875 op,
876 pref,
877 ct_sys::cutensorWorksizePreference_t::CUTENSOR_WORKSPACE_DEFAULT,
878 )
879 }
880 .map_err(|e| {
881 unsafe {
882 let _ = ct_result::destroy_plan_preference(pref);
883 let _ = ct_result::destroy_operation_descriptor(op);
884 let _ = ct_result::destroy_tensor_descriptor(desc_d);
885 let _ = ct_result::destroy_tensor_descriptor(desc_c);
886 let _ = ct_result::destroy_tensor_descriptor(desc_b);
887 let _ = ct_result::destroy_tensor_descriptor(desc_a);
888 }
889 GpuError::lib(LIB, format!("EstimateWorkspaceSize: {e}"))
890 })?;
891
892 let plan = unsafe { ct_result::create_plan(h.0, op, pref, ws_size) }.map_err(|e| {
893 unsafe {
894 let _ = ct_result::destroy_plan_preference(pref);
895 let _ = ct_result::destroy_operation_descriptor(op);
896 let _ = ct_result::destroy_tensor_descriptor(desc_d);
897 let _ = ct_result::destroy_tensor_descriptor(desc_c);
898 let _ = ct_result::destroy_tensor_descriptor(desc_b);
899 let _ = ct_result::destroy_tensor_descriptor(desc_a);
900 }
901 GpuError::lib(LIB, format!("CreatePlan: {e}"))
902 })?;
903
904 Ok(CachedPlan {
905 plan,
906 pref,
907 op,
908 descs: vec![desc_a, desc_b, desc_c, desc_d],
909 workspace_size: ws_size,
910 })
911}
912
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that plan keys react to the operator arguments and that binary
    /// and trinary keys never compare equal. Needs no GPU: only the key
    /// builders are exercised.
    #[test]
    fn trinary_binary_request_round_trip() {
        let id = ct_sys::cutensorOperator_t::CUTENSOR_OP_IDENTITY;
        let add = ct_sys::cutensorOperator_t::CUTENSOR_OP_ADD;
        let mul = ct_sys::cutensorOperator_t::CUTENSOR_OP_MUL;

        // Binary key over identical 8x16 f32 operands; only `op_ac` varies.
        let binary_key = |combine: ct_sys::cutensorOperator_t| {
            build_binary_key_raw(
                <f32 as atomr_accel::AccelDtype>::NAME,
                &[1, 2],
                &[1, 2],
                &[1, 2],
                &[8, 16],
                &[8, 16],
                &[8, 16],
                16,
                ComputeDesc::MinF32,
                id,
                id,
                combine,
            )
        };
        let key_b1 = binary_key(add);
        let key_b2 = binary_key(mul);
        assert_ne!(key_b1, key_b2);
        assert_eq!(key_b1.op_kind, OpKind::ElementwiseBinary);

        // Trinary key over identical 8x16 f64 operands; the two combining
        // operators vary.
        let trinary_key =
            |first: ct_sys::cutensorOperator_t, second: ct_sys::cutensorOperator_t| {
                build_trinary_key_raw(
                    <f64 as atomr_accel::AccelDtype>::NAME,
                    &[1, 2],
                    &[1, 2],
                    &[1, 2],
                    &[1, 2],
                    &[8, 16],
                    &[8, 16],
                    &[8, 16],
                    &[8, 16],
                    16,
                    ComputeDesc::MinF64,
                    id,
                    id,
                    id,
                    first,
                    second,
                )
            };
        let key_t1 = trinary_key(add, mul);
        // Swapping the two combining operators must change the key.
        let key_t2 = trinary_key(mul, add);
        assert_ne!(key_t1, key_t2);
        assert_eq!(key_t1.op_kind, OpKind::ElementwiseTrinary);
        assert_eq!(key_t1.dtype_tag, "f64");

        assert_ne!(key_b1, key_t1);
    }
}