1use std::sync::Arc;
8
9use cudarc::cutensor::result as ct_result;
10use cudarc::cutensor::sys as ct_sys;
11use cudarc::driver::{DevicePtr, DevicePtrMut};
12use tokio::sync::oneshot;
13
14use crate::dtype::TensorSupported;
15use crate::error::GpuError;
16use crate::gpu_ref::GpuRef;
17use crate::kernel::dispatch::{TensorDispatch, TensorDispatchCtx};
18use crate::kernel::envelope;
19use crate::kernel::tensor::compute_desc::{compute_desc_tag, resolve_compute_desc, ComputeDesc};
20use crate::kernel::tensor::plan_cache::{hash_i32s, hash_i64s, CachedPlan, OpKind, PlanKey};
21
/// Library tag used in this file: it is passed to `envelope::run_kernel` and
/// stored in the `lib` field of every `GpuError::LibraryError` built here.
const LIB: &str = "cutensor";
23
/// Describes one tensor operand (A, B or C) of a contraction.
///
/// `execute` rejects a request unless `extent.len() == modes.len()` for every
/// operand.
#[derive(Clone)]
pub struct OperandSpec<T: TensorSupported> {
    /// Device buffer backing the operand.
    pub buf: GpuRef<T>,
    /// Per-mode extents; `extent.len()` is used as the tensor rank when the
    /// cuTENSOR tensor descriptor is created.
    pub extent: Vec<i64>,
    /// Per-mode strides. When empty, a null stride pointer is handed to
    /// cuTENSOR (presumably selecting the library's default packed layout —
    /// confirm against the cuTENSOR documentation).
    pub stride: Vec<i64>,
    /// Mode labels, one per entry of `extent`, passed to the contraction
    /// descriptor.
    pub modes: Vec<i32>,
}
34
/// A tensor-contraction request handled through `TensorDispatch`.
///
/// `c` is used in place: `execute` passes the same device pointer as both the
/// C input and the output of the cuTENSOR contraction.
pub struct ContractRequest<T: TensorSupported> {
    /// First input operand.
    pub a: OperandSpec<T>,
    /// Second input operand.
    pub b: OperandSpec<T>,
    /// Accumulator / output operand (read and written).
    pub c: OperandSpec<T>,
    /// Host-side scalar passed by pointer as cuTENSOR's `alpha`.
    pub alpha: T,
    /// Host-side scalar passed by pointer as cuTENSOR's `beta`.
    pub beta: T,
    /// Compute descriptor; `new` fills this from `default_compute_for::<T>()`.
    pub compute: ComputeDesc,
    /// Alignment value forwarded to every tensor descriptor; `new` sets 16.
    pub alignment: u32,
    /// Channel on which the outcome of the request is reported.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
}
47
48impl<T: TensorSupported> ContractRequest<T> {
49 pub fn new(
50 a: OperandSpec<T>,
51 b: OperandSpec<T>,
52 c: OperandSpec<T>,
53 alpha: T,
54 beta: T,
55 reply: oneshot::Sender<Result<(), GpuError>>,
56 ) -> Self {
57 Self {
58 a,
59 b,
60 c,
61 alpha,
62 beta,
63 compute: default_compute_for::<T>(),
64 alignment: 16,
65 reply,
66 }
67 }
68
69 pub fn with_compute(mut self, compute: ComputeDesc) -> Self {
70 self.compute = compute;
71 self
72 }
73}
74
75pub fn default_compute_for<T: TensorSupported>() -> ComputeDesc {
79 match <T as atomr_accel::AccelDtype>::NAME {
80 "f32" => ComputeDesc::MinF32,
81 "f64" => ComputeDesc::MinF64,
82 "f16" => ComputeDesc::MinF32,
83 "bf16" => ComputeDesc::MinF32,
84 _ => ComputeDesc::MinF32,
85 }
86}
87
88impl<T: TensorSupported> TensorDispatch for ContractRequest<T> {
89 fn op_tag(&self) -> &'static str {
90 "contract"
91 }
92
93 fn dtype_tag(&self) -> &'static str {
94 <T as atomr_accel::AccelDtype>::NAME
95 }
96
97 fn dispatch(self: Box<Self>, ctx: &TensorDispatchCtx) {
98 execute(*self, ctx);
99 }
100
101 fn fail_mock(self: Box<Self>) {
102 let _ = self.reply.send(Err(GpuError::Unrecoverable(
103 "TensorActor in mock mode".into(),
104 )));
105 }
106}
107
108fn execute<T: TensorSupported>(req: ContractRequest<T>, ctx: &TensorDispatchCtx) {
109 let ContractRequest {
110 a,
111 b,
112 c,
113 alpha,
114 beta,
115 compute,
116 alignment,
117 reply,
118 } = req;
119
120 if a.extent.len() != a.modes.len() {
121 let _ = reply.send(Err(GpuError::Unrecoverable(
122 "Contract: a.extent.len != a.modes.len".into(),
123 )));
124 return;
125 }
126 if b.extent.len() != b.modes.len() {
127 let _ = reply.send(Err(GpuError::Unrecoverable(
128 "Contract: b.extent.len != b.modes.len".into(),
129 )));
130 return;
131 }
132 if c.extent.len() != c.modes.len() {
133 let _ = reply.send(Err(GpuError::Unrecoverable(
134 "Contract: c.extent.len != c.modes.len".into(),
135 )));
136 return;
137 }
138
139 let a_slice = match a.buf.access() {
140 Ok(s) => s.clone(),
141 Err(e) => {
142 let _ = reply.send(Err(e));
143 return;
144 }
145 };
146 let b_slice = match b.buf.access() {
147 Ok(s) => s.clone(),
148 Err(e) => {
149 let _ = reply.send(Err(e));
150 return;
151 }
152 };
153 let c_slice = match c.buf.access() {
154 Ok(s) => s.clone(),
155 Err(e) => {
156 let _ = reply.send(Err(e));
157 return;
158 }
159 };
160 let mut c_owned = match Arc::try_unwrap(c_slice) {
161 Ok(s) => s,
162 Err(_) => {
163 let _ = reply.send(Err(GpuError::Unrecoverable(
164 "Contract c has multiple live references".into(),
165 )));
166 return;
167 }
168 };
169
170 let key = build_key_for::<T>(&a, &b, &c, alignment, compute, 0);
171 let cached = match get_or_build_plan::<T>(ctx, &key, &a, &b, &c, alignment, compute, None) {
172 Ok(p) => p,
173 Err(e) => {
174 let _ = reply.send(Err(e));
175 return;
176 }
177 };
178
179 let ws_size = cached.workspace_size as usize;
180 if let Err(e) = ctx.workspace.ensure(ws_size) {
181 let _ = reply.send(Err(e));
182 return;
183 }
184
185 c.buf.record_write(&ctx.stream);
186
187 let stream_for_check = ctx.stream.clone();
188 let handle_clone = ctx.handle.clone();
189 let workspace = ctx.workspace.clone();
190 let plan_keepalive = cached.clone();
191
192 envelope::run_kernel(LIB, &ctx.stream, &ctx.completion, (), reply, move || {
193 let h = handle_clone.lock();
194 let (a_ptr, _ga) = a_slice.device_ptr(&stream_for_check);
195 let (b_ptr, _gb) = b_slice.device_ptr(&stream_for_check);
196 let (c_ptr, _gc) = c_owned.device_ptr_mut(&stream_for_check);
197 let alpha_h = alpha;
198 let beta_h = beta;
199 let res = workspace
200 .with_bucket(ws_size, |ws_slice| {
201 let (ws_ptr, _gws) = ws_slice.device_ptr_mut(&stream_for_check);
202 let r = unsafe {
203 ct_result::contract(
204 h.0,
205 plan_keepalive.plan,
206 &alpha_h as *const T as *const _,
207 a_ptr as *const _,
208 b_ptr as *const _,
209 &beta_h as *const T as *const _,
210 c_ptr as *const _,
211 c_ptr as *mut _,
212 ws_ptr as *mut _,
213 plan_keepalive.workspace_size,
214 stream_for_check.cu_stream() as *mut _,
215 )
216 };
217 drop(_gws);
218 r
219 })
220 .unwrap_or_else(|| unsafe {
221 ct_result::contract(
222 h.0,
223 plan_keepalive.plan,
224 &alpha_h as *const T as *const _,
225 a_ptr as *const _,
226 b_ptr as *const _,
227 &beta_h as *const T as *const _,
228 c_ptr as *const _,
229 c_ptr as *mut _,
230 std::ptr::null_mut(),
231 0,
232 stream_for_check.cu_stream() as *mut _,
233 )
234 });
235 drop((_ga, _gb, _gc));
236 match res {
237 Ok(()) => Ok((c_owned, a_slice, b_slice, plan_keepalive)),
238 Err(e) => Err(GpuError::LibraryError {
239 lib: LIB,
240 msg: format!("Contract: {e}"),
241 }),
242 }
243 });
244}
245
246pub fn build_contract_key(
250 dtype_tag: &'static str,
251 modes_a: &[i32],
252 modes_b: &[i32],
253 modes_c: &[i32],
254 extent_a: &[i64],
255 extent_b: &[i64],
256 extent_c: &[i64],
257 alignment: u32,
258 compute: ComputeDesc,
259 algo: i32,
260) -> PlanKey {
261 let mut modes = Vec::with_capacity(modes_a.len() + modes_b.len() + modes_c.len() + 3);
262 modes.extend_from_slice(modes_a);
263 modes.push(i32::MIN);
264 modes.extend_from_slice(modes_b);
265 modes.push(i32::MIN);
266 modes.extend_from_slice(modes_c);
267 let mut extents = Vec::with_capacity(extent_a.len() + extent_b.len() + extent_c.len() + 3);
268 extents.extend_from_slice(extent_a);
269 extents.push(i64::MIN);
270 extents.extend_from_slice(extent_b);
271 extents.push(i64::MIN);
272 extents.extend_from_slice(extent_c);
273 PlanKey {
274 op_kind: OpKind::Contract,
275 modes_hash: hash_i32s(&modes),
276 extents_hash: hash_i64s(&extents),
277 alignment,
278 compute_desc_tag: compute_desc_tag(compute),
279 dtype_tag,
280 algo,
281 }
282}
283
284pub(crate) fn build_key_for<T: TensorSupported>(
287 a: &OperandSpec<T>,
288 b: &OperandSpec<T>,
289 c: &OperandSpec<T>,
290 alignment: u32,
291 compute: ComputeDesc,
292 algo: i32,
293) -> PlanKey {
294 build_contract_key(
295 <T as atomr_accel::AccelDtype>::NAME,
296 &a.modes,
297 &b.modes,
298 &c.modes,
299 &a.extent,
300 &b.extent,
301 &c.extent,
302 alignment,
303 compute,
304 algo,
305 )
306}
307
308pub(crate) fn get_or_build_plan<T: TensorSupported>(
311 ctx: &TensorDispatchCtx,
312 key: &PlanKey,
313 a: &OperandSpec<T>,
314 b: &OperandSpec<T>,
315 c: &OperandSpec<T>,
316 alignment: u32,
317 compute: ComputeDesc,
318 algo: Option<ct_sys::cutensorAlgo_t>,
319) -> Result<Arc<CachedPlan>, GpuError> {
320 if let Some(p) = ctx.plan_cache.get(key) {
321 return Ok(p);
322 }
323 let plan = build_plan::<T>(&ctx.handle, a, b, c, alignment, compute, algo)?;
324 let arc = Arc::new(plan);
325 ctx.plan_cache.put(*key, arc.clone());
326 Ok(arc)
327}
328
329#[allow(clippy::too_many_arguments)]
330pub(crate) fn build_plan<T: TensorSupported>(
331 handle: &Arc<parking_lot::Mutex<crate::kernel::tensor::SendHandle>>,
332 a: &OperandSpec<T>,
333 b: &OperandSpec<T>,
334 c: &OperandSpec<T>,
335 alignment: u32,
336 compute: ComputeDesc,
337 algo: Option<ct_sys::cutensorAlgo_t>,
338) -> Result<CachedPlan, GpuError> {
339 let h = handle.lock();
340 let dt: cudarc::cutensor::sys::cudaDataType_t =
341 unsafe { std::mem::transmute(T::cuda_data_type() as u32) };
342 let cd = resolve_compute_desc(compute);
343
344 let stride_ptr = |v: &Vec<i64>| {
345 if v.is_empty() {
346 std::ptr::null()
347 } else {
348 v.as_ptr()
349 }
350 };
351
352 let desc_a = unsafe {
353 ct_result::create_tensor_descriptor(
354 h.0,
355 a.extent.len() as u32,
356 a.extent.as_ptr(),
357 stride_ptr(&a.stride),
358 dt,
359 alignment,
360 )
361 }
362 .map_err(|e| GpuError::LibraryError {
363 lib: LIB,
364 msg: format!("CreateTensorDescriptor(A): {e}"),
365 })?;
366 let desc_b = unsafe {
367 ct_result::create_tensor_descriptor(
368 h.0,
369 b.extent.len() as u32,
370 b.extent.as_ptr(),
371 stride_ptr(&b.stride),
372 dt,
373 alignment,
374 )
375 }
376 .map_err(|e| {
377 unsafe {
378 let _ = ct_result::destroy_tensor_descriptor(desc_a);
379 }
380 GpuError::LibraryError {
381 lib: LIB,
382 msg: format!("CreateTensorDescriptor(B): {e}"),
383 }
384 })?;
385 let desc_c = unsafe {
386 ct_result::create_tensor_descriptor(
387 h.0,
388 c.extent.len() as u32,
389 c.extent.as_ptr(),
390 stride_ptr(&c.stride),
391 dt,
392 alignment,
393 )
394 }
395 .map_err(|e| {
396 unsafe {
397 let _ = ct_result::destroy_tensor_descriptor(desc_b);
398 let _ = ct_result::destroy_tensor_descriptor(desc_a);
399 }
400 GpuError::LibraryError {
401 lib: LIB,
402 msg: format!("CreateTensorDescriptor(C): {e}"),
403 }
404 })?;
405
406 let op = unsafe {
407 ct_result::create_contraction(
408 h.0,
409 desc_a,
410 a.modes.as_ptr(),
411 ct_sys::cutensorOperator_t::CUTENSOR_OP_IDENTITY,
412 desc_b,
413 b.modes.as_ptr(),
414 ct_sys::cutensorOperator_t::CUTENSOR_OP_IDENTITY,
415 desc_c,
416 c.modes.as_ptr(),
417 ct_sys::cutensorOperator_t::CUTENSOR_OP_IDENTITY,
418 desc_c,
419 c.modes.as_ptr(),
420 cd,
421 )
422 }
423 .map_err(|e| {
424 unsafe {
425 let _ = ct_result::destroy_tensor_descriptor(desc_c);
426 let _ = ct_result::destroy_tensor_descriptor(desc_b);
427 let _ = ct_result::destroy_tensor_descriptor(desc_a);
428 }
429 GpuError::LibraryError {
430 lib: LIB,
431 msg: format!("CreateContraction: {e}"),
432 }
433 })?;
434
435 let chosen_algo = algo.unwrap_or(ct_sys::cutensorAlgo_t::CUTENSOR_ALGO_DEFAULT);
436 let pref = unsafe {
437 ct_result::create_plan_preference(
438 h.0,
439 chosen_algo,
440 ct_sys::cutensorJitMode_t::CUTENSOR_JIT_MODE_NONE,
441 )
442 }
443 .map_err(|e| {
444 unsafe {
445 let _ = ct_result::destroy_operation_descriptor(op);
446 let _ = ct_result::destroy_tensor_descriptor(desc_c);
447 let _ = ct_result::destroy_tensor_descriptor(desc_b);
448 let _ = ct_result::destroy_tensor_descriptor(desc_a);
449 }
450 GpuError::LibraryError {
451 lib: LIB,
452 msg: format!("CreatePlanPreference: {e}"),
453 }
454 })?;
455
456 let ws_size = unsafe {
457 ct_result::estimate_workspace_size(
458 h.0,
459 op,
460 pref,
461 ct_sys::cutensorWorksizePreference_t::CUTENSOR_WORKSPACE_DEFAULT,
462 )
463 }
464 .map_err(|e| {
465 unsafe {
466 let _ = ct_result::destroy_plan_preference(pref);
467 let _ = ct_result::destroy_operation_descriptor(op);
468 let _ = ct_result::destroy_tensor_descriptor(desc_c);
469 let _ = ct_result::destroy_tensor_descriptor(desc_b);
470 let _ = ct_result::destroy_tensor_descriptor(desc_a);
471 }
472 GpuError::LibraryError {
473 lib: LIB,
474 msg: format!("EstimateWorkspaceSize: {e}"),
475 }
476 })?;
477
478 let plan = unsafe { ct_result::create_plan(h.0, op, pref, ws_size) }.map_err(|e| {
479 unsafe {
480 let _ = ct_result::destroy_plan_preference(pref);
481 let _ = ct_result::destroy_operation_descriptor(op);
482 let _ = ct_result::destroy_tensor_descriptor(desc_c);
483 let _ = ct_result::destroy_tensor_descriptor(desc_b);
484 let _ = ct_result::destroy_tensor_descriptor(desc_a);
485 }
486 GpuError::LibraryError {
487 lib: LIB,
488 msg: format!("CreatePlan: {e}"),
489 }
490 })?;
491
492 Ok(CachedPlan {
493 plan,
494 pref,
495 op,
496 descs: vec![desc_a, desc_b, desc_c],
497 workspace_size: ws_size,
498 })
499}
500
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the key for one fixed contraction (A: modes `[1, 2]`, extents
    /// `[2, 3]`; B: modes `[2, 3]`, extents `[3, 4]`; C: modes `[1, 3]`,
    /// extents `[2, 4]`; alignment 16; algo 0), varying only the dtype tag
    /// and the compute descriptor.
    fn fixed_shape_key(dtype_tag: &'static str, compute: ComputeDesc) -> PlanKey {
        build_contract_key(
            dtype_tag,
            &[1, 2],
            &[2, 3],
            &[1, 3],
            &[2, 3],
            &[3, 4],
            &[2, 4],
            16,
            compute,
            0,
        )
    }

    /// Keys carry the dtype tag, differ across dtypes, and each dtype gets
    /// the expected default compute descriptor.
    #[test]
    fn contract_request_round_trip_f32_f64_f16_bf16() {
        let key32 = fixed_shape_key(
            <f32 as atomr_accel::AccelDtype>::NAME,
            ComputeDesc::MinF32,
        );
        assert_eq!(key32.dtype_tag, "f32");
        assert_eq!(key32.op_kind, OpKind::Contract);
        assert_eq!(
            key32.compute_desc_tag,
            compute_desc_tag(ComputeDesc::MinF32)
        );

        let key64 = fixed_shape_key(
            <f64 as atomr_accel::AccelDtype>::NAME,
            ComputeDesc::MinF64,
        );
        assert_eq!(key64.dtype_tag, "f64");
        assert_ne!(key32, key64);

        assert_eq!(
            default_compute_for::<f32>().tag(),
            ComputeDesc::MinF32.tag()
        );
        assert_eq!(
            default_compute_for::<f64>().tag(),
            ComputeDesc::MinF64.tag()
        );

        #[cfg(feature = "f16")]
        {
            let key_f16 = fixed_shape_key(
                <half::f16 as atomr_accel::AccelDtype>::NAME,
                ComputeDesc::MinF32,
            );
            assert_eq!(key_f16.dtype_tag, "f16");
            assert_ne!(key32, key_f16);

            let key_bf16 = fixed_shape_key(
                <half::bf16 as atomr_accel::AccelDtype>::NAME,
                ComputeDesc::MinF32,
            );
            assert_eq!(key_bf16.dtype_tag, "bf16");
            assert_ne!(key_f16, key_bf16);

            assert_eq!(
                default_compute_for::<half::f16>().tag(),
                ComputeDesc::MinF32.tag()
            );
            assert_eq!(
                default_compute_for::<half::bf16>().tag(),
                ComputeDesc::MinF32.tag()
            );
        }
    }
}