1use std::sync::Arc;
9
10use cudarc::cutensor::result as ct_result;
11use cudarc::cutensor::sys as ct_sys;
12use cudarc::driver::{DevicePtr, DevicePtrMut};
13use parking_lot::Mutex;
14use tokio::sync::oneshot;
15
16use crate::dtype::TensorSupported;
17use crate::error::GpuError;
18use crate::kernel::dispatch::{TensorDispatch, TensorDispatchCtx};
19use crate::kernel::envelope;
20use crate::kernel::tensor::compute_desc::{compute_desc_tag, resolve_compute_desc, ComputeDesc};
21use crate::kernel::tensor::contract::OperandSpec;
22use crate::kernel::tensor::plan_cache::{hash_i32s, hash_i64s, CachedPlan, OpKind, PlanKey};
23use crate::kernel::tensor::SendHandle;
24use crate::sys::cutensor as ct_local;
25
/// Library tag attached to errors produced in this module and passed to
/// `envelope::run_kernel` when launching the reduction.
const LIB: &str = "cutensor";
27
/// A single tensor-reduction job handed to the tensor dispatcher.
///
/// `a` is reduced into `c`; `c` is used both as the accumulator input and as
/// the destination of the cuTENSOR reduce call issued by this module. The
/// outcome of the job is delivered through `reply`.
pub struct ReductionRequest<T: TensorSupported> {
    /// Source operand (buffer, modes, extents and optional strides).
    pub a: OperandSpec<T>,
    /// Destination operand; also read as the accumulator input.
    pub c: OperandSpec<T>,
    /// Host-side scalar passed by pointer as the first scaling factor of the
    /// reduce call.
    pub alpha: T,
    /// Host-side scalar passed by pointer as the second scaling factor of the
    /// reduce call.
    pub beta: T,
    /// Reduction operator forwarded to the cuTENSOR reduction descriptor.
    pub op_reduce: ct_sys::cutensorOperator_t,
    /// Compute-descriptor selector, resolved when the plan is built.
    pub compute: ComputeDesc,
    /// Alignment value forwarded to both tensor descriptors
    /// (presumably bytes, per the cuTENSOR descriptor API; confirm).
    pub alignment: u32,
    /// One-shot channel on which the success or failure of the request is
    /// reported.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
}
39
40impl<T: TensorSupported> ReductionRequest<T> {
41 pub fn new(
42 a: OperandSpec<T>,
43 c: OperandSpec<T>,
44 alpha: T,
45 beta: T,
46 op_reduce: ct_sys::cutensorOperator_t,
47 reply: oneshot::Sender<Result<(), GpuError>>,
48 ) -> Self {
49 Self {
50 a,
51 c,
52 alpha,
53 beta,
54 op_reduce,
55 compute: super::contract::default_compute_for::<T>(),
56 alignment: 16,
57 reply,
58 }
59 }
60}
61
62impl<T: TensorSupported> TensorDispatch for ReductionRequest<T> {
63 fn op_tag(&self) -> &'static str {
64 "reduce"
65 }
66 fn dtype_tag(&self) -> &'static str {
67 <T as atomr_accel::AccelDtype>::NAME
68 }
69 fn dispatch(self: Box<Self>, ctx: &TensorDispatchCtx) {
70 execute(*self, ctx);
71 }
72 fn fail_mock(self: Box<Self>) {
73 let _ = self.reply.send(Err(GpuError::Unrecoverable(
74 "TensorActor in mock mode".into(),
75 )));
76 }
77}
78
79fn execute<T: TensorSupported>(req: ReductionRequest<T>, ctx: &TensorDispatchCtx) {
80 let ReductionRequest {
81 a,
82 c,
83 alpha,
84 beta,
85 op_reduce,
86 compute,
87 alignment,
88 reply,
89 } = req;
90
91 if a.extent.len() != a.modes.len() {
92 let _ = reply.send(Err(GpuError::Unrecoverable(
93 "Reduce: a.extent.len != a.modes.len".into(),
94 )));
95 return;
96 }
97 if c.extent.len() != c.modes.len() {
98 let _ = reply.send(Err(GpuError::Unrecoverable(
99 "Reduce: c.extent.len != c.modes.len".into(),
100 )));
101 return;
102 }
103
104 let a_slice = match a.buf.access() {
105 Ok(s) => s.clone(),
106 Err(e) => {
107 let _ = reply.send(Err(e));
108 return;
109 }
110 };
111 let c_slice = match c.buf.access() {
112 Ok(s) => s.clone(),
113 Err(e) => {
114 let _ = reply.send(Err(e));
115 return;
116 }
117 };
118 let mut c_owned = match Arc::try_unwrap(c_slice) {
119 Ok(s) => s,
120 Err(_) => {
121 let _ = reply.send(Err(GpuError::Unrecoverable(
122 "Reduce c has multiple live references".into(),
123 )));
124 return;
125 }
126 };
127
128 let key = build_reduction_key::<T>(&a, &c, alignment, compute, op_reduce);
129 let cached = match get_or_build_plan::<T>(
130 &ctx.handle,
131 &ctx.plan_cache,
132 &key,
133 &a,
134 &c,
135 alignment,
136 compute,
137 op_reduce,
138 ) {
139 Ok(p) => p,
140 Err(e) => {
141 let _ = reply.send(Err(e));
142 return;
143 }
144 };
145
146 let ws_size = cached.workspace_size as usize;
147 if let Err(e) = ctx.workspace.ensure(ws_size) {
148 let _ = reply.send(Err(e));
149 return;
150 }
151
152 c.buf.record_write(&ctx.stream);
153
154 let stream_for_check = ctx.stream.clone();
155 let handle_clone = ctx.handle.clone();
156 let workspace = ctx.workspace.clone();
157 let plan_keepalive = cached.clone();
158
159 envelope::run_kernel(LIB, &ctx.stream, &ctx.completion, (), reply, move || {
160 let h = handle_clone.lock();
161 let (a_ptr, _ga) = a_slice.device_ptr(&stream_for_check);
162 let (c_ptr, _gc) = c_owned.device_ptr_mut(&stream_for_check);
163 let alpha_h = alpha;
164 let beta_h = beta;
165 let res = workspace
166 .with_bucket(ws_size, |ws_slice| {
167 let (ws_ptr, _gws) = ws_slice.device_ptr_mut(&stream_for_check);
168 let r = unsafe {
169 ct_local::reduce(
170 h.0,
171 plan_keepalive.plan,
172 &alpha_h as *const T as *const _,
173 a_ptr as *const _,
174 &beta_h as *const T as *const _,
175 c_ptr as *const _,
176 c_ptr as *mut _,
177 ws_ptr as *mut _,
178 plan_keepalive.workspace_size,
179 stream_for_check.cu_stream() as *mut _,
180 )
181 };
182 drop(_gws);
183 r
184 })
185 .unwrap_or_else(|| unsafe {
186 ct_local::reduce(
187 h.0,
188 plan_keepalive.plan,
189 &alpha_h as *const T as *const _,
190 a_ptr as *const _,
191 &beta_h as *const T as *const _,
192 c_ptr as *const _,
193 c_ptr as *mut _,
194 std::ptr::null_mut(),
195 0,
196 stream_for_check.cu_stream() as *mut _,
197 )
198 });
199 drop((_ga, _gc));
200 match res {
201 Ok(()) => Ok((c_owned, a_slice, plan_keepalive)),
202 Err(e) => Err(GpuError::LibraryError {
203 lib: LIB,
204 msg: format!("Reduce: {e}"),
205 }),
206 }
207 });
208}
209
210pub fn build_reduction_key_raw(
211 dtype_tag: &'static str,
212 modes_a: &[i32],
213 modes_c: &[i32],
214 extent_a: &[i64],
215 extent_c: &[i64],
216 alignment: u32,
217 compute: ComputeDesc,
218 op_reduce: ct_sys::cutensorOperator_t,
219) -> PlanKey {
220 let mut modes = Vec::with_capacity(modes_a.len() + modes_c.len() + 2);
221 modes.extend_from_slice(modes_a);
222 modes.push(i32::MIN);
223 modes.extend_from_slice(modes_c);
224 let mut extents = Vec::with_capacity(extent_a.len() + extent_c.len() + 2);
225 extents.extend_from_slice(extent_a);
226 extents.push(i64::MIN);
227 extents.extend_from_slice(extent_c);
228 PlanKey {
229 op_kind: OpKind::Reduce,
230 modes_hash: hash_i32s(&modes),
231 extents_hash: hash_i64s(&extents),
232 alignment,
233 compute_desc_tag: compute_desc_tag(compute)
236 ^ ((op_reduce as u32).wrapping_mul(0x9E37_79B9)),
237 dtype_tag,
238 algo: 0,
239 }
240}
241
242fn build_reduction_key<T: TensorSupported>(
243 a: &OperandSpec<T>,
244 c: &OperandSpec<T>,
245 alignment: u32,
246 compute: ComputeDesc,
247 op_reduce: ct_sys::cutensorOperator_t,
248) -> PlanKey {
249 build_reduction_key_raw(
250 <T as atomr_accel::AccelDtype>::NAME,
251 &a.modes,
252 &c.modes,
253 &a.extent,
254 &c.extent,
255 alignment,
256 compute,
257 op_reduce,
258 )
259}
260
261#[allow(clippy::too_many_arguments)]
262fn get_or_build_plan<T: TensorSupported>(
263 handle: &Arc<Mutex<SendHandle>>,
264 plan_cache: &Arc<crate::kernel::tensor::plan_cache::PlanCache>,
265 key: &PlanKey,
266 a: &OperandSpec<T>,
267 c: &OperandSpec<T>,
268 alignment: u32,
269 compute: ComputeDesc,
270 op_reduce: ct_sys::cutensorOperator_t,
271) -> Result<Arc<CachedPlan>, GpuError> {
272 if let Some(p) = plan_cache.get(key) {
273 return Ok(p);
274 }
275 let plan = build_plan::<T>(handle, a, c, alignment, compute, op_reduce)?;
276 let arc = Arc::new(plan);
277 plan_cache.put(*key, arc.clone());
278 Ok(arc)
279}
280
281fn build_plan<T: TensorSupported>(
282 handle: &Arc<Mutex<SendHandle>>,
283 a: &OperandSpec<T>,
284 c: &OperandSpec<T>,
285 alignment: u32,
286 compute: ComputeDesc,
287 op_reduce: ct_sys::cutensorOperator_t,
288) -> Result<CachedPlan, GpuError> {
289 let h = handle.lock();
290 let dt: cudarc::cutensor::sys::cudaDataType_t =
291 unsafe { std::mem::transmute(T::cuda_data_type() as u32) };
292 let cd = resolve_compute_desc(compute);
293 let stride_ptr = |v: &Vec<i64>| {
294 if v.is_empty() {
295 std::ptr::null()
296 } else {
297 v.as_ptr()
298 }
299 };
300
301 let desc_a = unsafe {
302 ct_result::create_tensor_descriptor(
303 h.0,
304 a.extent.len() as u32,
305 a.extent.as_ptr(),
306 stride_ptr(&a.stride),
307 dt,
308 alignment,
309 )
310 }
311 .map_err(|e| GpuError::lib(LIB, format!("CreateTensorDescriptor(A): {e}")))?;
312 let desc_c = unsafe {
313 ct_result::create_tensor_descriptor(
314 h.0,
315 c.extent.len() as u32,
316 c.extent.as_ptr(),
317 stride_ptr(&c.stride),
318 dt,
319 alignment,
320 )
321 }
322 .map_err(|e| {
323 unsafe {
324 let _ = ct_result::destroy_tensor_descriptor(desc_a);
325 }
326 GpuError::lib(LIB, format!("CreateTensorDescriptor(C): {e}"))
327 })?;
328
329 let op = unsafe {
330 ct_result::create_reduction(
331 h.0,
332 desc_a,
333 a.modes.as_ptr(),
334 ct_sys::cutensorOperator_t::CUTENSOR_OP_IDENTITY,
335 desc_c,
336 c.modes.as_ptr(),
337 ct_sys::cutensorOperator_t::CUTENSOR_OP_IDENTITY,
338 desc_c,
339 c.modes.as_ptr(),
340 op_reduce,
341 cd,
342 )
343 }
344 .map_err(|e| {
345 unsafe {
346 let _ = ct_result::destroy_tensor_descriptor(desc_c);
347 let _ = ct_result::destroy_tensor_descriptor(desc_a);
348 }
349 GpuError::lib(LIB, format!("CreateReduction: {e}"))
350 })?;
351
352 let pref = unsafe {
353 ct_result::create_plan_preference(
354 h.0,
355 ct_sys::cutensorAlgo_t::CUTENSOR_ALGO_DEFAULT,
356 ct_sys::cutensorJitMode_t::CUTENSOR_JIT_MODE_NONE,
357 )
358 }
359 .map_err(|e| {
360 unsafe {
361 let _ = ct_result::destroy_operation_descriptor(op);
362 let _ = ct_result::destroy_tensor_descriptor(desc_c);
363 let _ = ct_result::destroy_tensor_descriptor(desc_a);
364 }
365 GpuError::lib(LIB, format!("CreatePlanPreference: {e}"))
366 })?;
367
368 let ws_size = unsafe {
369 ct_result::estimate_workspace_size(
370 h.0,
371 op,
372 pref,
373 ct_sys::cutensorWorksizePreference_t::CUTENSOR_WORKSPACE_DEFAULT,
374 )
375 }
376 .map_err(|e| {
377 unsafe {
378 let _ = ct_result::destroy_plan_preference(pref);
379 let _ = ct_result::destroy_operation_descriptor(op);
380 let _ = ct_result::destroy_tensor_descriptor(desc_c);
381 let _ = ct_result::destroy_tensor_descriptor(desc_a);
382 }
383 GpuError::lib(LIB, format!("EstimateWorkspaceSize: {e}"))
384 })?;
385
386 let plan = unsafe { ct_result::create_plan(h.0, op, pref, ws_size) }.map_err(|e| {
387 unsafe {
388 let _ = ct_result::destroy_plan_preference(pref);
389 let _ = ct_result::destroy_operation_descriptor(op);
390 let _ = ct_result::destroy_tensor_descriptor(desc_c);
391 let _ = ct_result::destroy_tensor_descriptor(desc_a);
392 }
393 GpuError::lib(LIB, format!("CreatePlan: {e}"))
394 })?;
395
396 Ok(CachedPlan {
397 plan,
398 pref,
399 op,
400 descs: vec![desc_a, desc_c],
401 workspace_size: ws_size,
402 })
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    /// Key for the fixed `[1, 2, 3] -> [1]` reduction shape used by the test,
    /// varying only dtype, compute descriptor and reduction operator.
    fn sample_key(
        dtype: &'static str,
        compute: ComputeDesc,
        op: ct_sys::cutensorOperator_t,
    ) -> PlanKey {
        build_reduction_key_raw(dtype, &[1, 2, 3], &[1], &[8, 16, 32], &[8], 16, compute, op)
    }

    #[test]
    fn reduction_request_round_trip() {
        let f32_tag = <f32 as atomr_accel::AccelDtype>::NAME;
        let f64_tag = <f64 as atomr_accel::AccelDtype>::NAME;

        let key_add = sample_key(
            f32_tag,
            ComputeDesc::MinF32,
            ct_sys::cutensorOperator_t::CUTENSOR_OP_ADD,
        );
        let key_max = sample_key(
            f32_tag,
            ComputeDesc::MinF32,
            ct_sys::cutensorOperator_t::CUTENSOR_OP_MAX,
        );

        // The reduction operator alone must be enough to separate two keys.
        assert_ne!(key_add, key_max);
        assert_eq!(key_add.op_kind, OpKind::Reduce);
        assert_eq!(key_add.dtype_tag, "f32");

        let key_f64 = sample_key(
            f64_tag,
            ComputeDesc::MinF64,
            ct_sys::cutensorOperator_t::CUTENSOR_OP_ADD,
        );

        // Same shape and operator, different element type.
        assert_ne!(key_add, key_f64);
        assert_eq!(key_f64.dtype_tag, "f64");
    }
}