Skip to main content

atomr_accel_cuda/kernel/tensor/
contract.rs

1//! `ContractRequest<T>` — dtype-generic Einstein-summation contraction.
2//!
3//! D = alpha * A^modes_a · B^modes_b + beta * C^modes_c (D in-place
4//! into the C buffer). Mirrors the cuTENSOR `cutensorContract` entry
5//! point.
6
7use std::sync::Arc;
8
9use cudarc::cutensor::result as ct_result;
10use cudarc::cutensor::sys as ct_sys;
11use cudarc::driver::{DevicePtr, DevicePtrMut};
12use tokio::sync::oneshot;
13
14use crate::dtype::TensorSupported;
15use crate::error::GpuError;
16use crate::gpu_ref::GpuRef;
17use crate::kernel::dispatch::{TensorDispatch, TensorDispatchCtx};
18use crate::kernel::envelope;
19use crate::kernel::tensor::compute_desc::{compute_desc_tag, resolve_compute_desc, ComputeDesc};
20use crate::kernel::tensor::plan_cache::{hash_i32s, hash_i64s, CachedPlan, OpKind, PlanKey};
21
/// Library tag for this module: used as `lib` in every
/// `GpuError::LibraryError` built here and passed to `envelope::run_kernel`.
const LIB: &str = "cutensor";
23
24/// One operand specification: device buffer + per-mode extents +
25/// optional strides + Einstein-summation labels.
26#[derive(Clone)]
27pub struct OperandSpec<T: TensorSupported> {
28    pub buf: GpuRef<T>,
29    pub extent: Vec<i64>,
30    /// Empty == dense column-major.
31    pub stride: Vec<i64>,
32    pub modes: Vec<i32>,
33}
34
35/// Dtype-generic contraction request.
36pub struct ContractRequest<T: TensorSupported> {
37    pub a: OperandSpec<T>,
38    pub b: OperandSpec<T>,
39    pub c: OperandSpec<T>,
40    pub alpha: T,
41    pub beta: T,
42    pub compute: ComputeDesc,
43    /// Required tensor alignment in bytes.
44    pub alignment: u32,
45    pub reply: oneshot::Sender<Result<(), GpuError>>,
46}
47
48impl<T: TensorSupported> ContractRequest<T> {
49    pub fn new(
50        a: OperandSpec<T>,
51        b: OperandSpec<T>,
52        c: OperandSpec<T>,
53        alpha: T,
54        beta: T,
55        reply: oneshot::Sender<Result<(), GpuError>>,
56    ) -> Self {
57        Self {
58            a,
59            b,
60            c,
61            alpha,
62            beta,
63            compute: default_compute_for::<T>(),
64            alignment: 16,
65            reply,
66        }
67    }
68
69    pub fn with_compute(mut self, compute: ComputeDesc) -> Self {
70        self.compute = compute;
71        self
72    }
73}
74
75/// Pick the canonical compute descriptor for `T`. Mirrors NVIDIA's
76/// guidance: f32 inputs default to MIN_32F, f64 to MIN_64F, half/bf16
77/// accumulate in f32 (MIN_32F).
78pub fn default_compute_for<T: TensorSupported>() -> ComputeDesc {
79    match <T as atomr_accel::AccelDtype>::NAME {
80        "f32" => ComputeDesc::MinF32,
81        "f64" => ComputeDesc::MinF64,
82        "f16" => ComputeDesc::MinF32,
83        "bf16" => ComputeDesc::MinF32,
84        _ => ComputeDesc::MinF32,
85    }
86}
87
88impl<T: TensorSupported> TensorDispatch for ContractRequest<T> {
89    fn op_tag(&self) -> &'static str {
90        "contract"
91    }
92
93    fn dtype_tag(&self) -> &'static str {
94        <T as atomr_accel::AccelDtype>::NAME
95    }
96
97    fn dispatch(self: Box<Self>, ctx: &TensorDispatchCtx) {
98        execute(*self, ctx);
99    }
100
101    fn fail_mock(self: Box<Self>) {
102        let _ = self.reply.send(Err(GpuError::Unrecoverable(
103            "TensorActor in mock mode".into(),
104        )));
105    }
106}
107
108fn execute<T: TensorSupported>(req: ContractRequest<T>, ctx: &TensorDispatchCtx) {
109    let ContractRequest {
110        a,
111        b,
112        c,
113        alpha,
114        beta,
115        compute,
116        alignment,
117        reply,
118    } = req;
119
120    if a.extent.len() != a.modes.len() {
121        let _ = reply.send(Err(GpuError::Unrecoverable(
122            "Contract: a.extent.len != a.modes.len".into(),
123        )));
124        return;
125    }
126    if b.extent.len() != b.modes.len() {
127        let _ = reply.send(Err(GpuError::Unrecoverable(
128            "Contract: b.extent.len != b.modes.len".into(),
129        )));
130        return;
131    }
132    if c.extent.len() != c.modes.len() {
133        let _ = reply.send(Err(GpuError::Unrecoverable(
134            "Contract: c.extent.len != c.modes.len".into(),
135        )));
136        return;
137    }
138
139    let a_slice = match a.buf.access() {
140        Ok(s) => s.clone(),
141        Err(e) => {
142            let _ = reply.send(Err(e));
143            return;
144        }
145    };
146    let b_slice = match b.buf.access() {
147        Ok(s) => s.clone(),
148        Err(e) => {
149            let _ = reply.send(Err(e));
150            return;
151        }
152    };
153    let c_slice = match c.buf.access() {
154        Ok(s) => s.clone(),
155        Err(e) => {
156            let _ = reply.send(Err(e));
157            return;
158        }
159    };
160    let mut c_owned = match Arc::try_unwrap(c_slice) {
161        Ok(s) => s,
162        Err(_) => {
163            let _ = reply.send(Err(GpuError::Unrecoverable(
164                "Contract c has multiple live references".into(),
165            )));
166            return;
167        }
168    };
169
170    let key = build_key_for::<T>(&a, &b, &c, alignment, compute, /*algo*/ 0);
171    let cached = match get_or_build_plan::<T>(ctx, &key, &a, &b, &c, alignment, compute, None) {
172        Ok(p) => p,
173        Err(e) => {
174            let _ = reply.send(Err(e));
175            return;
176        }
177    };
178
179    let ws_size = cached.workspace_size as usize;
180    if let Err(e) = ctx.workspace.ensure(ws_size) {
181        let _ = reply.send(Err(e));
182        return;
183    }
184
185    c.buf.record_write(&ctx.stream);
186
187    let stream_for_check = ctx.stream.clone();
188    let handle_clone = ctx.handle.clone();
189    let workspace = ctx.workspace.clone();
190    let plan_keepalive = cached.clone();
191
192    envelope::run_kernel(LIB, &ctx.stream, &ctx.completion, (), reply, move || {
193        let h = handle_clone.lock();
194        let (a_ptr, _ga) = a_slice.device_ptr(&stream_for_check);
195        let (b_ptr, _gb) = b_slice.device_ptr(&stream_for_check);
196        let (c_ptr, _gc) = c_owned.device_ptr_mut(&stream_for_check);
197        let alpha_h = alpha;
198        let beta_h = beta;
199        let res = workspace
200            .with_bucket(ws_size, |ws_slice| {
201                let (ws_ptr, _gws) = ws_slice.device_ptr_mut(&stream_for_check);
202                let r = unsafe {
203                    ct_result::contract(
204                        h.0,
205                        plan_keepalive.plan,
206                        &alpha_h as *const T as *const _,
207                        a_ptr as *const _,
208                        b_ptr as *const _,
209                        &beta_h as *const T as *const _,
210                        c_ptr as *const _,
211                        c_ptr as *mut _,
212                        ws_ptr as *mut _,
213                        plan_keepalive.workspace_size,
214                        stream_for_check.cu_stream() as *mut _,
215                    )
216                };
217                drop(_gws);
218                r
219            })
220            .unwrap_or_else(|| unsafe {
221                ct_result::contract(
222                    h.0,
223                    plan_keepalive.plan,
224                    &alpha_h as *const T as *const _,
225                    a_ptr as *const _,
226                    b_ptr as *const _,
227                    &beta_h as *const T as *const _,
228                    c_ptr as *const _,
229                    c_ptr as *mut _,
230                    std::ptr::null_mut(),
231                    0,
232                    stream_for_check.cu_stream() as *mut _,
233                )
234            });
235        drop((_ga, _gb, _gc));
236        match res {
237            Ok(()) => Ok((c_owned, a_slice, b_slice, plan_keepalive)),
238            Err(e) => Err(GpuError::LibraryError {
239                lib: LIB,
240                msg: format!("Contract: {e}"),
241            }),
242        }
243    });
244}
245
246/// Build a cache key from raw mode/extent slices. Tests that don't
247/// need a live `GpuRef` use this directly; the dispatch path goes
248/// through [`build_key_for`].
249pub fn build_contract_key(
250    dtype_tag: &'static str,
251    modes_a: &[i32],
252    modes_b: &[i32],
253    modes_c: &[i32],
254    extent_a: &[i64],
255    extent_b: &[i64],
256    extent_c: &[i64],
257    alignment: u32,
258    compute: ComputeDesc,
259    algo: i32,
260) -> PlanKey {
261    let mut modes = Vec::with_capacity(modes_a.len() + modes_b.len() + modes_c.len() + 3);
262    modes.extend_from_slice(modes_a);
263    modes.push(i32::MIN);
264    modes.extend_from_slice(modes_b);
265    modes.push(i32::MIN);
266    modes.extend_from_slice(modes_c);
267    let mut extents = Vec::with_capacity(extent_a.len() + extent_b.len() + extent_c.len() + 3);
268    extents.extend_from_slice(extent_a);
269    extents.push(i64::MIN);
270    extents.extend_from_slice(extent_b);
271    extents.push(i64::MIN);
272    extents.extend_from_slice(extent_c);
273    PlanKey {
274        op_kind: OpKind::Contract,
275        modes_hash: hash_i32s(&modes),
276        extents_hash: hash_i64s(&extents),
277        alignment,
278        compute_desc_tag: compute_desc_tag(compute),
279        dtype_tag,
280        algo,
281    }
282}
283
284/// Wrapper around [`build_contract_key`] that pulls extents/modes
285/// from typed `OperandSpec`s.
286pub(crate) fn build_key_for<T: TensorSupported>(
287    a: &OperandSpec<T>,
288    b: &OperandSpec<T>,
289    c: &OperandSpec<T>,
290    alignment: u32,
291    compute: ComputeDesc,
292    algo: i32,
293) -> PlanKey {
294    build_contract_key(
295        <T as atomr_accel::AccelDtype>::NAME,
296        &a.modes,
297        &b.modes,
298        &c.modes,
299        &a.extent,
300        &b.extent,
301        &c.extent,
302        alignment,
303        compute,
304        algo,
305    )
306}
307
308/// Look up `key` in `cache`; on miss, build a fresh plan with the
309/// supplied algo (or default) and insert before returning.
310pub(crate) fn get_or_build_plan<T: TensorSupported>(
311    ctx: &TensorDispatchCtx,
312    key: &PlanKey,
313    a: &OperandSpec<T>,
314    b: &OperandSpec<T>,
315    c: &OperandSpec<T>,
316    alignment: u32,
317    compute: ComputeDesc,
318    algo: Option<ct_sys::cutensorAlgo_t>,
319) -> Result<Arc<CachedPlan>, GpuError> {
320    if let Some(p) = ctx.plan_cache.get(key) {
321        return Ok(p);
322    }
323    let plan = build_plan::<T>(&ctx.handle, a, b, c, alignment, compute, algo)?;
324    let arc = Arc::new(plan);
325    ctx.plan_cache.put(*key, arc.clone());
326    Ok(arc)
327}
328
329#[allow(clippy::too_many_arguments)]
330pub(crate) fn build_plan<T: TensorSupported>(
331    handle: &Arc<parking_lot::Mutex<crate::kernel::tensor::SendHandle>>,
332    a: &OperandSpec<T>,
333    b: &OperandSpec<T>,
334    c: &OperandSpec<T>,
335    alignment: u32,
336    compute: ComputeDesc,
337    algo: Option<ct_sys::cutensorAlgo_t>,
338) -> Result<CachedPlan, GpuError> {
339    let h = handle.lock();
340    let dt: cudarc::cutensor::sys::cudaDataType_t =
341        unsafe { std::mem::transmute(T::cuda_data_type() as u32) };
342    let cd = resolve_compute_desc(compute);
343
344    let stride_ptr = |v: &Vec<i64>| {
345        if v.is_empty() {
346            std::ptr::null()
347        } else {
348            v.as_ptr()
349        }
350    };
351
352    let desc_a = unsafe {
353        ct_result::create_tensor_descriptor(
354            h.0,
355            a.extent.len() as u32,
356            a.extent.as_ptr(),
357            stride_ptr(&a.stride),
358            dt,
359            alignment,
360        )
361    }
362    .map_err(|e| GpuError::LibraryError {
363        lib: LIB,
364        msg: format!("CreateTensorDescriptor(A): {e}"),
365    })?;
366    let desc_b = unsafe {
367        ct_result::create_tensor_descriptor(
368            h.0,
369            b.extent.len() as u32,
370            b.extent.as_ptr(),
371            stride_ptr(&b.stride),
372            dt,
373            alignment,
374        )
375    }
376    .map_err(|e| {
377        unsafe {
378            let _ = ct_result::destroy_tensor_descriptor(desc_a);
379        }
380        GpuError::LibraryError {
381            lib: LIB,
382            msg: format!("CreateTensorDescriptor(B): {e}"),
383        }
384    })?;
385    let desc_c = unsafe {
386        ct_result::create_tensor_descriptor(
387            h.0,
388            c.extent.len() as u32,
389            c.extent.as_ptr(),
390            stride_ptr(&c.stride),
391            dt,
392            alignment,
393        )
394    }
395    .map_err(|e| {
396        unsafe {
397            let _ = ct_result::destroy_tensor_descriptor(desc_b);
398            let _ = ct_result::destroy_tensor_descriptor(desc_a);
399        }
400        GpuError::LibraryError {
401            lib: LIB,
402            msg: format!("CreateTensorDescriptor(C): {e}"),
403        }
404    })?;
405
406    let op = unsafe {
407        ct_result::create_contraction(
408            h.0,
409            desc_a,
410            a.modes.as_ptr(),
411            ct_sys::cutensorOperator_t::CUTENSOR_OP_IDENTITY,
412            desc_b,
413            b.modes.as_ptr(),
414            ct_sys::cutensorOperator_t::CUTENSOR_OP_IDENTITY,
415            desc_c,
416            c.modes.as_ptr(),
417            ct_sys::cutensorOperator_t::CUTENSOR_OP_IDENTITY,
418            desc_c,
419            c.modes.as_ptr(),
420            cd,
421        )
422    }
423    .map_err(|e| {
424        unsafe {
425            let _ = ct_result::destroy_tensor_descriptor(desc_c);
426            let _ = ct_result::destroy_tensor_descriptor(desc_b);
427            let _ = ct_result::destroy_tensor_descriptor(desc_a);
428        }
429        GpuError::LibraryError {
430            lib: LIB,
431            msg: format!("CreateContraction: {e}"),
432        }
433    })?;
434
435    let chosen_algo = algo.unwrap_or(ct_sys::cutensorAlgo_t::CUTENSOR_ALGO_DEFAULT);
436    let pref = unsafe {
437        ct_result::create_plan_preference(
438            h.0,
439            chosen_algo,
440            ct_sys::cutensorJitMode_t::CUTENSOR_JIT_MODE_NONE,
441        )
442    }
443    .map_err(|e| {
444        unsafe {
445            let _ = ct_result::destroy_operation_descriptor(op);
446            let _ = ct_result::destroy_tensor_descriptor(desc_c);
447            let _ = ct_result::destroy_tensor_descriptor(desc_b);
448            let _ = ct_result::destroy_tensor_descriptor(desc_a);
449        }
450        GpuError::LibraryError {
451            lib: LIB,
452            msg: format!("CreatePlanPreference: {e}"),
453        }
454    })?;
455
456    let ws_size = unsafe {
457        ct_result::estimate_workspace_size(
458            h.0,
459            op,
460            pref,
461            ct_sys::cutensorWorksizePreference_t::CUTENSOR_WORKSPACE_DEFAULT,
462        )
463    }
464    .map_err(|e| {
465        unsafe {
466            let _ = ct_result::destroy_plan_preference(pref);
467            let _ = ct_result::destroy_operation_descriptor(op);
468            let _ = ct_result::destroy_tensor_descriptor(desc_c);
469            let _ = ct_result::destroy_tensor_descriptor(desc_b);
470            let _ = ct_result::destroy_tensor_descriptor(desc_a);
471        }
472        GpuError::LibraryError {
473            lib: LIB,
474            msg: format!("EstimateWorkspaceSize: {e}"),
475        }
476    })?;
477
478    let plan = unsafe { ct_result::create_plan(h.0, op, pref, ws_size) }.map_err(|e| {
479        unsafe {
480            let _ = ct_result::destroy_plan_preference(pref);
481            let _ = ct_result::destroy_operation_descriptor(op);
482            let _ = ct_result::destroy_tensor_descriptor(desc_c);
483            let _ = ct_result::destroy_tensor_descriptor(desc_b);
484            let _ = ct_result::destroy_tensor_descriptor(desc_a);
485        }
486        GpuError::LibraryError {
487            lib: LIB,
488            msg: format!("CreatePlan: {e}"),
489        }
490    })?;
491
492    Ok(CachedPlan {
493        plan,
494        pref,
495        op,
496        descs: vec![desc_a, desc_b, desc_c],
497        workspace_size: ws_size,
498    })
499}
500
#[cfg(test)]
mod tests {
    use super::*;

    /// Key for the fixed `[1,2] x [2,3] -> [1,3]` contraction used by every
    /// case below (extents 2x3, 3x4, 2x4; 16-byte alignment; algo 0).
    /// Only the dtype tag and compute descriptor vary.
    fn shared_shape_key(dtype_tag: &'static str, compute: ComputeDesc) -> PlanKey {
        build_contract_key(
            dtype_tag,
            &[1, 2],
            &[2, 3],
            &[1, 3],
            &[2, 3],
            &[3, 4],
            &[2, 4],
            16,
            compute,
            0,
        )
    }

    #[test]
    fn contract_request_round_trip_f32_f64_f16_bf16() {
        // Keys come from plain mode/extent slices, so no GPU is needed.
        // Check that dtype tag, op kind and compute descriptor are wired
        // through for every supported dtype.
        let key32 = shared_shape_key(<f32 as atomr_accel::AccelDtype>::NAME, ComputeDesc::MinF32);
        assert_eq!(key32.dtype_tag, "f32");
        assert_eq!(key32.op_kind, OpKind::Contract);
        assert_eq!(
            key32.compute_desc_tag,
            compute_desc_tag(ComputeDesc::MinF32)
        );

        let key64 = shared_shape_key(<f64 as atomr_accel::AccelDtype>::NAME, ComputeDesc::MinF64);
        assert_eq!(key64.dtype_tag, "f64");
        // Identical shapes with a different dtype must not share a key.
        assert_ne!(key32, key64);

        // Per-dtype default compute descriptors.
        assert_eq!(
            default_compute_for::<f32>().tag(),
            ComputeDesc::MinF32.tag()
        );
        assert_eq!(
            default_compute_for::<f64>().tag(),
            ComputeDesc::MinF64.tag()
        );

        #[cfg(feature = "f16")]
        {
            let key_f16 =
                shared_shape_key(<half::f16 as atomr_accel::AccelDtype>::NAME, ComputeDesc::MinF32);
            assert_eq!(key_f16.dtype_tag, "f16");
            assert_ne!(key32, key_f16);

            let key_bf16 =
                shared_shape_key(<half::bf16 as atomr_accel::AccelDtype>::NAME, ComputeDesc::MinF32);
            assert_eq!(key_bf16.dtype_tag, "bf16");
            assert_ne!(key_f16, key_bf16);

            assert_eq!(
                default_compute_for::<half::f16>().tag(),
                ComputeDesc::MinF32.tag()
            );
            assert_eq!(
                default_compute_for::<half::bf16>().tag(),
                ComputeDesc::MinF32.tag()
            );
        }
    }
}