1use std::sync::Arc;
7
8use cudarc::cutensor::result as ct_result;
9use cudarc::cutensor::sys as ct_sys;
10use cudarc::driver::{DevicePtr, DevicePtrMut};
11use parking_lot::Mutex;
12use tokio::sync::oneshot;
13
14use crate::dtype::TensorSupported;
15use crate::error::GpuError;
16use crate::kernel::dispatch::{TensorDispatch, TensorDispatchCtx};
17use crate::kernel::envelope;
18use crate::kernel::tensor::compute_desc::{compute_desc_tag, resolve_compute_desc, ComputeDesc};
19use crate::kernel::tensor::contract::OperandSpec;
20use crate::kernel::tensor::plan_cache::{
21 hash_i32s, hash_i64s, CachedPlan, OpKind, PlanCache, PlanKey,
22};
23use crate::kernel::tensor::SendHandle;
24use crate::sys::cutensor as ct_local;
25
/// Library name attached to errors raised in this module and passed to
/// `envelope::run_kernel` when launching the permutation.
const LIB: &str = "cutensor";
27
/// Request to run a cuTENSOR permutation that reads operand `a` and writes
/// operand `b`.
///
/// The final outcome (success or a [`GpuError`]) is delivered through `reply`.
pub struct PermutationRequest<T: TensorSupported> {
    /// Source operand; its buffer is handed to the kernel as a read-only pointer.
    pub a: OperandSpec<T>,
    /// Destination operand; its buffer is handed to the kernel as a mutable pointer.
    pub b: OperandSpec<T>,
    /// Scalar passed by pointer to the permute call.
    /// NOTE(review): presumably the scaling factor applied to `a` — confirm
    /// against the cuTENSOR permutation documentation.
    pub alpha: T,
    /// Element-wise operator supplied for `a` when the permutation descriptor
    /// is created.
    pub op_a: ct_sys::cutensorOperator_t,
    /// Compute-descriptor selector; resolved via `resolve_compute_desc` when
    /// the plan is built.
    pub compute: ComputeDesc,
    /// Alignment value passed to both tensor descriptors (`new` defaults it to 16).
    pub alignment: u32,
    /// One-shot channel on which the result of the request is reported.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
}
37
38impl<T: TensorSupported> PermutationRequest<T> {
39 pub fn new(
40 a: OperandSpec<T>,
41 b: OperandSpec<T>,
42 alpha: T,
43 reply: oneshot::Sender<Result<(), GpuError>>,
44 ) -> Self {
45 Self {
46 a,
47 b,
48 alpha,
49 op_a: ct_sys::cutensorOperator_t::CUTENSOR_OP_IDENTITY,
50 compute: super::contract::default_compute_for::<T>(),
51 alignment: 16,
52 reply,
53 }
54 }
55}
56
57impl<T: TensorSupported> TensorDispatch for PermutationRequest<T> {
58 fn op_tag(&self) -> &'static str {
59 "permute"
60 }
61 fn dtype_tag(&self) -> &'static str {
62 <T as atomr_accel::AccelDtype>::NAME
63 }
64 fn dispatch(self: Box<Self>, ctx: &TensorDispatchCtx) {
65 execute(*self, ctx);
66 }
67 fn fail_mock(self: Box<Self>) {
68 let _ = self.reply.send(Err(GpuError::Unrecoverable(
69 "TensorActor in mock mode".into(),
70 )));
71 }
72}
73
74fn execute<T: TensorSupported>(req: PermutationRequest<T>, ctx: &TensorDispatchCtx) {
75 let PermutationRequest {
76 a,
77 b,
78 alpha,
79 op_a,
80 compute,
81 alignment,
82 reply,
83 } = req;
84
85 if a.extent.len() != a.modes.len() || b.extent.len() != b.modes.len() {
86 let _ = reply.send(Err(GpuError::Unrecoverable(
87 "Permutation: extent/modes length mismatch".into(),
88 )));
89 return;
90 }
91
92 let a_slice = match a.buf.access() {
93 Ok(s) => s.clone(),
94 Err(e) => {
95 let _ = reply.send(Err(e));
96 return;
97 }
98 };
99 let b_slice = match b.buf.access() {
100 Ok(s) => s.clone(),
101 Err(e) => {
102 let _ = reply.send(Err(e));
103 return;
104 }
105 };
106 let mut b_owned = match Arc::try_unwrap(b_slice) {
107 Ok(s) => s,
108 Err(_) => {
109 let _ = reply.send(Err(GpuError::Unrecoverable(
110 "Permutation b has multiple live references".into(),
111 )));
112 return;
113 }
114 };
115
116 let key = build_permutation_key_raw(
117 <T as atomr_accel::AccelDtype>::NAME,
118 &a.modes,
119 &b.modes,
120 &a.extent,
121 &b.extent,
122 alignment,
123 compute,
124 op_a,
125 );
126 let cached = match get_or_build_plan::<T>(
127 &ctx.handle,
128 &ctx.plan_cache,
129 &key,
130 &a,
131 &b,
132 alignment,
133 compute,
134 op_a,
135 ) {
136 Ok(p) => p,
137 Err(e) => {
138 let _ = reply.send(Err(e));
139 return;
140 }
141 };
142
143 b.buf.record_write(&ctx.stream);
144
145 let stream_for_check = ctx.stream.clone();
146 let handle_clone = ctx.handle.clone();
147 let plan_keepalive = cached.clone();
148
149 envelope::run_kernel(LIB, &ctx.stream, &ctx.completion, (), reply, move || {
150 let h = handle_clone.lock();
151 let (a_ptr, _ga) = a_slice.device_ptr(&stream_for_check);
152 let (b_ptr, _gb) = b_owned.device_ptr_mut(&stream_for_check);
153 let alpha_h = alpha;
154 let res = unsafe {
155 ct_local::permute(
156 h.0,
157 plan_keepalive.plan,
158 &alpha_h as *const T as *const _,
159 a_ptr as *const _,
160 b_ptr as *mut _,
161 stream_for_check.cu_stream() as *mut _,
162 )
163 };
164 drop((_ga, _gb));
165 match res {
166 Ok(()) => Ok((b_owned, a_slice, plan_keepalive)),
167 Err(e) => Err(GpuError::LibraryError {
168 lib: LIB,
169 msg: format!("Permute: {e}"),
170 }),
171 }
172 });
173}
174
175#[allow(clippy::too_many_arguments)]
176pub fn build_permutation_key_raw(
177 dtype_tag: &'static str,
178 modes_a: &[i32],
179 modes_b: &[i32],
180 extent_a: &[i64],
181 extent_b: &[i64],
182 alignment: u32,
183 compute: ComputeDesc,
184 op_a: ct_sys::cutensorOperator_t,
185) -> PlanKey {
186 let mut modes = Vec::with_capacity(modes_a.len() + modes_b.len() + 2);
187 modes.extend_from_slice(modes_a);
188 modes.push(i32::MIN);
189 modes.extend_from_slice(modes_b);
190 let mut extents = Vec::with_capacity(extent_a.len() + extent_b.len() + 2);
191 extents.extend_from_slice(extent_a);
192 extents.push(i64::MIN);
193 extents.extend_from_slice(extent_b);
194 PlanKey {
195 op_kind: OpKind::Permutation,
196 modes_hash: hash_i32s(&modes),
197 extents_hash: hash_i64s(&extents),
198 alignment,
199 compute_desc_tag: compute_desc_tag(compute) ^ ((op_a as u32).wrapping_mul(0x27d4_eb2f)),
200 dtype_tag,
201 algo: 0,
202 }
203}
204
205#[allow(clippy::too_many_arguments)]
206fn get_or_build_plan<T: TensorSupported>(
207 handle: &Arc<Mutex<SendHandle>>,
208 plan_cache: &Arc<PlanCache>,
209 key: &PlanKey,
210 a: &OperandSpec<T>,
211 b: &OperandSpec<T>,
212 alignment: u32,
213 compute: ComputeDesc,
214 op_a: ct_sys::cutensorOperator_t,
215) -> Result<Arc<CachedPlan>, GpuError> {
216 if let Some(p) = plan_cache.get(key) {
217 return Ok(p);
218 }
219 let plan = build_plan::<T>(handle, a, b, alignment, compute, op_a)?;
220 let arc = Arc::new(plan);
221 plan_cache.put(*key, arc.clone());
222 Ok(arc)
223}
224
225fn build_plan<T: TensorSupported>(
226 handle: &Arc<Mutex<SendHandle>>,
227 a: &OperandSpec<T>,
228 b: &OperandSpec<T>,
229 alignment: u32,
230 compute: ComputeDesc,
231 op_a: ct_sys::cutensorOperator_t,
232) -> Result<CachedPlan, GpuError> {
233 let h = handle.lock();
234 let dt: cudarc::cutensor::sys::cudaDataType_t =
235 unsafe { std::mem::transmute(T::cuda_data_type() as u32) };
236 let cd = resolve_compute_desc(compute);
237 let stride_ptr = |v: &Vec<i64>| {
238 if v.is_empty() {
239 std::ptr::null()
240 } else {
241 v.as_ptr()
242 }
243 };
244 let desc_a = unsafe {
245 ct_result::create_tensor_descriptor(
246 h.0,
247 a.extent.len() as u32,
248 a.extent.as_ptr(),
249 stride_ptr(&a.stride),
250 dt,
251 alignment,
252 )
253 }
254 .map_err(|e| GpuError::lib(LIB, format!("CreateTensorDescriptor(A): {e}")))?;
255 let desc_b = unsafe {
256 ct_result::create_tensor_descriptor(
257 h.0,
258 b.extent.len() as u32,
259 b.extent.as_ptr(),
260 stride_ptr(&b.stride),
261 dt,
262 alignment,
263 )
264 }
265 .map_err(|e| {
266 unsafe {
267 let _ = ct_result::destroy_tensor_descriptor(desc_a);
268 }
269 GpuError::lib(LIB, format!("CreateTensorDescriptor(B): {e}"))
270 })?;
271
272 let op = unsafe {
273 ct_local::create_permutation(
274 h.0,
275 desc_a,
276 a.modes.as_ptr(),
277 op_a,
278 desc_b,
279 b.modes.as_ptr(),
280 cd,
281 )
282 }
283 .map_err(|e| {
284 unsafe {
285 let _ = ct_result::destroy_tensor_descriptor(desc_b);
286 let _ = ct_result::destroy_tensor_descriptor(desc_a);
287 }
288 GpuError::lib(LIB, format!("CreatePermutation: {e}"))
289 })?;
290
291 let pref = unsafe {
292 ct_result::create_plan_preference(
293 h.0,
294 ct_sys::cutensorAlgo_t::CUTENSOR_ALGO_DEFAULT,
295 ct_sys::cutensorJitMode_t::CUTENSOR_JIT_MODE_NONE,
296 )
297 }
298 .map_err(|e| {
299 unsafe {
300 let _ = ct_result::destroy_operation_descriptor(op);
301 let _ = ct_result::destroy_tensor_descriptor(desc_b);
302 let _ = ct_result::destroy_tensor_descriptor(desc_a);
303 }
304 GpuError::lib(LIB, format!("CreatePlanPreference: {e}"))
305 })?;
306
307 let ws_size = unsafe {
308 ct_result::estimate_workspace_size(
309 h.0,
310 op,
311 pref,
312 ct_sys::cutensorWorksizePreference_t::CUTENSOR_WORKSPACE_DEFAULT,
313 )
314 }
315 .map_err(|e| {
316 unsafe {
317 let _ = ct_result::destroy_plan_preference(pref);
318 let _ = ct_result::destroy_operation_descriptor(op);
319 let _ = ct_result::destroy_tensor_descriptor(desc_b);
320 let _ = ct_result::destroy_tensor_descriptor(desc_a);
321 }
322 GpuError::lib(LIB, format!("EstimateWorkspaceSize: {e}"))
323 })?;
324
325 let plan = unsafe { ct_result::create_plan(h.0, op, pref, ws_size) }.map_err(|e| {
326 unsafe {
327 let _ = ct_result::destroy_plan_preference(pref);
328 let _ = ct_result::destroy_operation_descriptor(op);
329 let _ = ct_result::destroy_tensor_descriptor(desc_b);
330 let _ = ct_result::destroy_tensor_descriptor(desc_a);
331 }
332 GpuError::lib(LIB, format!("CreatePlan: {e}"))
333 })?;
334
335 Ok(CachedPlan {
336 plan,
337 pref,
338 op,
339 descs: vec![desc_a, desc_b],
340 workspace_size: ws_size,
341 })
342}
343
#[cfg(test)]
mod tests {
    use super::*;

    /// Key for a fixed 3-mode permutation (`[1, 2, 3]` -> `[3, 1, 2]`, alignment
    /// 16), varying only the dtype tag, compute descriptor, and operator.
    fn sample_key(
        dtype_tag: &'static str,
        compute: ComputeDesc,
        op_a: ct_sys::cutensorOperator_t,
    ) -> PlanKey {
        build_permutation_key_raw(
            dtype_tag,
            &[1, 2, 3],
            &[3, 1, 2],
            &[8, 16, 32],
            &[32, 8, 16],
            16,
            compute,
            op_a,
        )
    }

    #[test]
    fn permutation_request_round_trip() {
        let f32_tag = <f32 as atomr_accel::AccelDtype>::NAME;
        let f64_tag = <f64 as atomr_accel::AccelDtype>::NAME;

        let key32 = sample_key(
            f32_tag,
            ComputeDesc::MinF32,
            ct_sys::cutensorOperator_t::CUTENSOR_OP_IDENTITY,
        );
        let key32_neg = sample_key(
            f32_tag,
            ComputeDesc::MinF32,
            ct_sys::cutensorOperator_t::CUTENSOR_OP_NEG,
        );
        // The operator must be part of the key.
        assert_ne!(key32, key32_neg);
        assert_eq!(key32.op_kind, OpKind::Permutation);
        assert_eq!(key32.dtype_tag, "f32");

        let key64 = sample_key(
            f64_tag,
            ComputeDesc::MinF64,
            ct_sys::cutensorOperator_t::CUTENSOR_OP_IDENTITY,
        );
        // Different dtype and compute descriptor must give a different key.
        assert_ne!(key32, key64);
        assert_eq!(key64.dtype_tag, "f64");
    }
}