Skip to main content

atomr_accel_cuda/kernel/tensor/
reduce.rs

1//! `ReductionRequest<T>` — dtype-generic tensor reduction.
2//!
3//! Wraps `cutensorCreateReduction` + `cutensorReduce`. Output `c`
4//! holds the reduced result (shape determined by `mode_c` ⊆
5//! `mode_a`). `op_reduce` is one of `CUTENSOR_OP_ADD`,
6//! `CUTENSOR_OP_MAX`, `CUTENSOR_OP_MIN`, `CUTENSOR_OP_MUL`.
7
8use std::sync::Arc;
9
10use cudarc::cutensor::result as ct_result;
11use cudarc::cutensor::sys as ct_sys;
12use cudarc::driver::{DevicePtr, DevicePtrMut};
13use parking_lot::Mutex;
14use tokio::sync::oneshot;
15
16use crate::dtype::TensorSupported;
17use crate::error::GpuError;
18use crate::kernel::dispatch::{TensorDispatch, TensorDispatchCtx};
19use crate::kernel::envelope;
20use crate::kernel::tensor::compute_desc::{compute_desc_tag, resolve_compute_desc, ComputeDesc};
21use crate::kernel::tensor::contract::OperandSpec;
22use crate::kernel::tensor::plan_cache::{hash_i32s, hash_i64s, CachedPlan, OpKind, PlanKey};
23use crate::kernel::tensor::SendHandle;
24use crate::sys::cutensor as ct_local;
25
26const LIB: &str = "cutensor";
27
/// Dtype-generic reduction request.
///
/// Bundles everything `execute` needs to reduce operand `a` into operand `c`
/// and the channel on which the outcome is reported.
pub struct ReductionRequest<T: TensorSupported> {
    /// Input operand (buffer plus its extents, strides and modes).
    pub a: OperandSpec<T>,
    /// Output operand; its device pointer is passed to the reduce call both as
    /// a const and as a mut pointer.
    pub c: OperandSpec<T>,
    /// Host scalar whose address is passed to the reduce call ahead of the `a`
    /// pointer (presumably the scale applied to the reduced input — confirm
    /// against the cuTENSOR documentation).
    pub alpha: T,
    /// Host scalar whose address is passed to the reduce call ahead of the `c`
    /// pointer (presumably the scale applied to the existing contents of `c` —
    /// confirm against the cuTENSOR documentation).
    pub beta: T,
    /// Reduction operator forwarded to `create_reduction`.
    pub op_reduce: ct_sys::cutensorOperator_t,
    /// Compute descriptor selector; `new` fills it from
    /// `default_compute_for::<T>()`.
    pub compute: ComputeDesc,
    /// Alignment handed to both tensor descriptors and folded into the plan
    /// key; `new` sets it to 16 (presumably bytes — confirm).
    pub alignment: u32,
    /// One-shot channel that receives the result of the request.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
}
39
40impl<T: TensorSupported> ReductionRequest<T> {
41    pub fn new(
42        a: OperandSpec<T>,
43        c: OperandSpec<T>,
44        alpha: T,
45        beta: T,
46        op_reduce: ct_sys::cutensorOperator_t,
47        reply: oneshot::Sender<Result<(), GpuError>>,
48    ) -> Self {
49        Self {
50            a,
51            c,
52            alpha,
53            beta,
54            op_reduce,
55            compute: super::contract::default_compute_for::<T>(),
56            alignment: 16,
57            reply,
58        }
59    }
60}
61
62impl<T: TensorSupported> TensorDispatch for ReductionRequest<T> {
63    fn op_tag(&self) -> &'static str {
64        "reduce"
65    }
66    fn dtype_tag(&self) -> &'static str {
67        <T as atomr_accel::AccelDtype>::NAME
68    }
69    fn dispatch(self: Box<Self>, ctx: &TensorDispatchCtx) {
70        execute(*self, ctx);
71    }
72    fn fail_mock(self: Box<Self>) {
73        let _ = self.reply.send(Err(GpuError::Unrecoverable(
74            "TensorActor in mock mode".into(),
75        )));
76    }
77}
78
/// Runs one reduction request.
///
/// Steps, in order: check that extents and modes agree for both operands,
/// obtain the device buffers, resolve (or build) the cached plan, make sure
/// the workspace can hold the plan's workspace size, then hand the actual
/// `ct_local::reduce` launch to `envelope::run_kernel`.
///
/// Every failure before the launch is sent on `reply` and the function
/// returns early. After that point `reply` is moved into `run_kernel`
/// (presumably it forwards the closure's result — confirm in `envelope`).
fn execute<T: TensorSupported>(req: ReductionRequest<T>, ctx: &TensorDispatchCtx) {
    let ReductionRequest {
        a,
        c,
        alpha,
        beta,
        op_reduce,
        compute,
        alignment,
        reply,
    } = req;

    // Each operand must carry exactly one mode label per extent.
    if a.extent.len() != a.modes.len() {
        let _ = reply.send(Err(GpuError::Unrecoverable(
            "Reduce: a.extent.len != a.modes.len".into(),
        )));
        return;
    }
    if c.extent.len() != c.modes.len() {
        let _ = reply.send(Err(GpuError::Unrecoverable(
            "Reduce: c.extent.len != c.modes.len".into(),
        )));
        return;
    }

    let a_slice = match a.buf.access() {
        Ok(s) => s.clone(),
        Err(e) => {
            let _ = reply.send(Err(e));
            return;
        }
    };
    let c_slice = match c.buf.access() {
        Ok(s) => s.clone(),
        Err(e) => {
            let _ = reply.send(Err(e));
            return;
        }
    };
    // The output needs a mutable device pointer, so the slice must be uniquely
    // owned; any other live reference makes the request fail.
    let mut c_owned = match Arc::try_unwrap(c_slice) {
        Ok(s) => s,
        Err(_) => {
            let _ = reply.send(Err(GpuError::Unrecoverable(
                "Reduce c has multiple live references".into(),
            )));
            return;
        }
    };

    // Look the plan up in the cache, building and inserting it on a miss.
    let key = build_reduction_key::<T>(&a, &c, alignment, compute, op_reduce);
    let cached = match get_or_build_plan::<T>(
        &ctx.handle,
        &ctx.plan_cache,
        &key,
        &a,
        &c,
        alignment,
        compute,
        op_reduce,
    ) {
        Ok(p) => p,
        Err(e) => {
            let _ = reply.send(Err(e));
            return;
        }
    };

    let ws_size = cached.workspace_size as usize;
    if let Err(e) = ctx.workspace.ensure(ws_size) {
        let _ = reply.send(Err(e));
        return;
    }

    // NOTE(review): the write is recorded before the kernel is enqueued below;
    // confirm that this ordering is what `record_write` expects.
    c.buf.record_write(&ctx.stream);

    // Owned clones so the `move` closure below does not borrow from `ctx`.
    let stream_for_check = ctx.stream.clone();
    let handle_clone = ctx.handle.clone();
    let workspace = ctx.workspace.clone();
    let plan_keepalive = cached.clone();

    envelope::run_kernel(LIB, &ctx.stream, &ctx.completion, (), reply, move || {
        // The handle stays locked for the duration of the launch call.
        let h = handle_clone.lock();
        let (a_ptr, _ga) = a_slice.device_ptr(&stream_for_check);
        let (c_ptr, _gc) = c_owned.device_ptr_mut(&stream_for_check);
        // Local copies whose addresses are passed to the reduce call.
        let alpha_h = alpha;
        let beta_h = beta;
        let res = workspace
            .with_bucket(ws_size, |ws_slice| {
                let (ws_ptr, _gws) = ws_slice.device_ptr_mut(&stream_for_check);
                let r = unsafe {
                    ct_local::reduce(
                        h.0,
                        plan_keepalive.plan,
                        &alpha_h as *const T as *const _,
                        a_ptr as *const _,
                        &beta_h as *const T as *const _,
                        // The same `c` pointer is passed twice, once const and
                        // once mut (presumably cuTENSOR's C input and D output,
                        // i.e. an in-place update — confirm in `ct_local`).
                        c_ptr as *const _,
                        c_ptr as *mut _,
                        ws_ptr as *mut _,
                        plan_keepalive.workspace_size,
                        stream_for_check.cu_stream() as *mut _,
                    )
                };
                drop(_gws);
                r
            })
            // Taken when `with_bucket` yields `None`: the same call, but with a
            // null workspace of size 0.
            // NOTE(review): if `ws_size` is non-zero here the plan was built for
            // a larger workspace than it is given — confirm when `with_bucket`
            // can return `None`.
            .unwrap_or_else(|| unsafe {
                ct_local::reduce(
                    h.0,
                    plan_keepalive.plan,
                    &alpha_h as *const T as *const _,
                    a_ptr as *const _,
                    &beta_h as *const T as *const _,
                    c_ptr as *const _,
                    c_ptr as *mut _,
                    std::ptr::null_mut(),
                    0,
                    stream_for_check.cu_stream() as *mut _,
                )
            });
        drop((_ga, _gc));
        match res {
            // Buffers and plan are returned to the caller of this closure
            // (presumably so `run_kernel` keeps them alive until the stream has
            // finished — confirm in `envelope`).
            Ok(()) => Ok((c_owned, a_slice, plan_keepalive)),
            Err(e) => Err(GpuError::LibraryError {
                lib: LIB,
                msg: format!("Reduce: {e}"),
            }),
        }
    });
}
209
210pub fn build_reduction_key_raw(
211    dtype_tag: &'static str,
212    modes_a: &[i32],
213    modes_c: &[i32],
214    extent_a: &[i64],
215    extent_c: &[i64],
216    alignment: u32,
217    compute: ComputeDesc,
218    op_reduce: ct_sys::cutensorOperator_t,
219) -> PlanKey {
220    let mut modes = Vec::with_capacity(modes_a.len() + modes_c.len() + 2);
221    modes.extend_from_slice(modes_a);
222    modes.push(i32::MIN);
223    modes.extend_from_slice(modes_c);
224    let mut extents = Vec::with_capacity(extent_a.len() + extent_c.len() + 2);
225    extents.extend_from_slice(extent_a);
226    extents.push(i64::MIN);
227    extents.extend_from_slice(extent_c);
228    PlanKey {
229        op_kind: OpKind::Reduce,
230        modes_hash: hash_i32s(&modes),
231        extents_hash: hash_i64s(&extents),
232        alignment,
233        // Fold the reduction operator into the compute-desc tag so two
234        // shapes-only-equal-but-op-different requests don't collide.
235        compute_desc_tag: compute_desc_tag(compute)
236            ^ ((op_reduce as u32).wrapping_mul(0x9E37_79B9)),
237        dtype_tag,
238        algo: 0,
239    }
240}
241
242fn build_reduction_key<T: TensorSupported>(
243    a: &OperandSpec<T>,
244    c: &OperandSpec<T>,
245    alignment: u32,
246    compute: ComputeDesc,
247    op_reduce: ct_sys::cutensorOperator_t,
248) -> PlanKey {
249    build_reduction_key_raw(
250        <T as atomr_accel::AccelDtype>::NAME,
251        &a.modes,
252        &c.modes,
253        &a.extent,
254        &c.extent,
255        alignment,
256        compute,
257        op_reduce,
258    )
259}
260
261#[allow(clippy::too_many_arguments)]
262fn get_or_build_plan<T: TensorSupported>(
263    handle: &Arc<Mutex<SendHandle>>,
264    plan_cache: &Arc<crate::kernel::tensor::plan_cache::PlanCache>,
265    key: &PlanKey,
266    a: &OperandSpec<T>,
267    c: &OperandSpec<T>,
268    alignment: u32,
269    compute: ComputeDesc,
270    op_reduce: ct_sys::cutensorOperator_t,
271) -> Result<Arc<CachedPlan>, GpuError> {
272    if let Some(p) = plan_cache.get(key) {
273        return Ok(p);
274    }
275    let plan = build_plan::<T>(handle, a, c, alignment, compute, op_reduce)?;
276    let arc = Arc::new(plan);
277    plan_cache.put(*key, arc.clone());
278    Ok(arc)
279}
280
281fn build_plan<T: TensorSupported>(
282    handle: &Arc<Mutex<SendHandle>>,
283    a: &OperandSpec<T>,
284    c: &OperandSpec<T>,
285    alignment: u32,
286    compute: ComputeDesc,
287    op_reduce: ct_sys::cutensorOperator_t,
288) -> Result<CachedPlan, GpuError> {
289    let h = handle.lock();
290    let dt: cudarc::cutensor::sys::cudaDataType_t =
291        unsafe { std::mem::transmute(T::cuda_data_type() as u32) };
292    let cd = resolve_compute_desc(compute);
293    let stride_ptr = |v: &Vec<i64>| {
294        if v.is_empty() {
295            std::ptr::null()
296        } else {
297            v.as_ptr()
298        }
299    };
300
301    let desc_a = unsafe {
302        ct_result::create_tensor_descriptor(
303            h.0,
304            a.extent.len() as u32,
305            a.extent.as_ptr(),
306            stride_ptr(&a.stride),
307            dt,
308            alignment,
309        )
310    }
311    .map_err(|e| GpuError::lib(LIB, format!("CreateTensorDescriptor(A): {e}")))?;
312    let desc_c = unsafe {
313        ct_result::create_tensor_descriptor(
314            h.0,
315            c.extent.len() as u32,
316            c.extent.as_ptr(),
317            stride_ptr(&c.stride),
318            dt,
319            alignment,
320        )
321    }
322    .map_err(|e| {
323        unsafe {
324            let _ = ct_result::destroy_tensor_descriptor(desc_a);
325        }
326        GpuError::lib(LIB, format!("CreateTensorDescriptor(C): {e}"))
327    })?;
328
329    let op = unsafe {
330        ct_result::create_reduction(
331            h.0,
332            desc_a,
333            a.modes.as_ptr(),
334            ct_sys::cutensorOperator_t::CUTENSOR_OP_IDENTITY,
335            desc_c,
336            c.modes.as_ptr(),
337            ct_sys::cutensorOperator_t::CUTENSOR_OP_IDENTITY,
338            desc_c,
339            c.modes.as_ptr(),
340            op_reduce,
341            cd,
342        )
343    }
344    .map_err(|e| {
345        unsafe {
346            let _ = ct_result::destroy_tensor_descriptor(desc_c);
347            let _ = ct_result::destroy_tensor_descriptor(desc_a);
348        }
349        GpuError::lib(LIB, format!("CreateReduction: {e}"))
350    })?;
351
352    let pref = unsafe {
353        ct_result::create_plan_preference(
354            h.0,
355            ct_sys::cutensorAlgo_t::CUTENSOR_ALGO_DEFAULT,
356            ct_sys::cutensorJitMode_t::CUTENSOR_JIT_MODE_NONE,
357        )
358    }
359    .map_err(|e| {
360        unsafe {
361            let _ = ct_result::destroy_operation_descriptor(op);
362            let _ = ct_result::destroy_tensor_descriptor(desc_c);
363            let _ = ct_result::destroy_tensor_descriptor(desc_a);
364        }
365        GpuError::lib(LIB, format!("CreatePlanPreference: {e}"))
366    })?;
367
368    let ws_size = unsafe {
369        ct_result::estimate_workspace_size(
370            h.0,
371            op,
372            pref,
373            ct_sys::cutensorWorksizePreference_t::CUTENSOR_WORKSPACE_DEFAULT,
374        )
375    }
376    .map_err(|e| {
377        unsafe {
378            let _ = ct_result::destroy_plan_preference(pref);
379            let _ = ct_result::destroy_operation_descriptor(op);
380            let _ = ct_result::destroy_tensor_descriptor(desc_c);
381            let _ = ct_result::destroy_tensor_descriptor(desc_a);
382        }
383        GpuError::lib(LIB, format!("EstimateWorkspaceSize: {e}"))
384    })?;
385
386    let plan = unsafe { ct_result::create_plan(h.0, op, pref, ws_size) }.map_err(|e| {
387        unsafe {
388            let _ = ct_result::destroy_plan_preference(pref);
389            let _ = ct_result::destroy_operation_descriptor(op);
390            let _ = ct_result::destroy_tensor_descriptor(desc_c);
391            let _ = ct_result::destroy_tensor_descriptor(desc_a);
392        }
393        GpuError::lib(LIB, format!("CreatePlan: {e}"))
394    })?;
395
396    Ok(CachedPlan {
397        plan,
398        pref,
399        op,
400        descs: vec![desc_a, desc_c],
401        workspace_size: ws_size,
402    })
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    /// Key for the fixed test shape (`a`: modes 1,2,3 / extents 8,16,32;
    /// `c`: mode 1 / extent 8; alignment 16) with the given dtype tag,
    /// compute descriptor and reduction operator.
    fn key_for(
        dtype_tag: &'static str,
        compute: ComputeDesc,
        op: ct_sys::cutensorOperator_t,
    ) -> PlanKey {
        build_reduction_key_raw(
            dtype_tag,
            &[1, 2, 3],
            &[1],
            &[8, 16, 32],
            &[8],
            16,
            compute,
            op,
        )
    }

    #[test]
    fn reduction_request_round_trip() {
        let f32_tag = <f32 as atomr_accel::AccelDtype>::NAME;
        let f64_tag = <f64 as atomr_accel::AccelDtype>::NAME;

        let key_add = key_for(
            f32_tag,
            ComputeDesc::MinF32,
            ct_sys::cutensorOperator_t::CUTENSOR_OP_ADD,
        );
        let key_max = key_for(
            f32_tag,
            ComputeDesc::MinF32,
            ct_sys::cutensorOperator_t::CUTENSOR_OP_MAX,
        );
        // Same shapes, different op_reduce → distinct cache slots.
        assert_ne!(key_add, key_max);
        assert_eq!(key_add.op_kind, OpKind::Reduce);
        assert_eq!(key_add.dtype_tag, "f32");

        // Same shapes and operator, different dtype → distinct cache slots.
        let key_f64 = key_for(
            f64_tag,
            ComputeDesc::MinF64,
            ct_sys::cutensorOperator_t::CUTENSOR_OP_ADD,
        );
        assert_ne!(key_add, key_f64);
        assert_eq!(key_f64.dtype_tag, "f64");
    }
}