Skip to main content

atomr_accel_cuda/kernel/tensor/
plan_cache.rs

1//! LRU plan cache for cuTENSOR operations.
2//!
3//! `cutensorCreatePlan` is expensive enough that real workloads
4//! amortise it across many calls with identical shape signatures.
5//! [`PlanCache`] holds an `LruCache<PlanKey, CachedPlan>` so the actor
6//! can hash a description once, look up an existing plan, and only
7//! pay for the descriptor + plan + workspace-estimate triplet on a
8//! miss.
9//!
10//! # Key
11//!
//! Keyed by the fields of [`PlanKey`] — `(op_kind, modes_hash,
//! extents_hash, alignment, compute_desc_tag, dtype_tag, algo)` —
//! which together cover everything that influences cuTENSOR's choice of
//! internal kernel and workspace size. The autotune-picked algorithm
//! (`algo`) is folded into the key so an autotuned plan never collides
//! with a default-algorithm plan.
17
18use std::sync::Arc;
19
20use cudarc::cutensor::result as ct_result;
21use cudarc::cutensor::sys as ct_sys;
22use lru::LruCache;
23use parking_lot::Mutex;
24
/// Number of plans a plan cache keeps, by default, before it starts
/// evicting the least-recently-used entry.
///
/// Every entry owns a cuTENSOR plan, an operation descriptor and that
/// operation's tensor descriptors. Estimated at a few KiB of host memory
/// per entry, so a full cache costs on the order of a MiB — 256 entries is
/// deliberately generous.
pub const DEFAULT_PLAN_CACHE_SIZE: usize = 256;
30
/// Family of cuTENSOR operation a cached plan was built for.
///
/// Stored in the plan cache key so plans for different operation kinds
/// never share a slot, even when every other key field matches.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum OpKind {
    /// Tensor contraction.
    Contract,
    /// Tensor reduction.
    Reduce,
    /// Binary elementwise operation.
    ElementwiseBinary,
    /// Trinary elementwise operation.
    ElementwiseTrinary,
    /// Mode permutation.
    Permutation,
}

impl OpKind {
    /// Returns a short, fixed string label for this operation kind.
    pub fn tag(self) -> &'static str {
        let label = match self {
            Self::Contract => "contract",
            Self::Reduce => "reduce",
            Self::ElementwiseBinary => "ewbin",
            Self::ElementwiseTrinary => "ewtri",
            Self::Permutation => "permute",
        };
        label
    }
}
52
53/// Hashable plan key. Modes / extents arrive pre-hashed (u64) so the
54/// key remains `Copy + Eq`.
55#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
56pub struct PlanKey {
57    pub op_kind: OpKind,
58    pub modes_hash: u64,
59    pub extents_hash: u64,
60    pub alignment: u32,
61    pub compute_desc_tag: u32,
62    pub dtype_tag: &'static str,
63    /// `0` means "default algorithm". Autotune writes the chosen
64    /// `cutensorAlgo_t as i32` here so autotuned plans get their own
65    /// cache slot.
66    pub algo: i32,
67}
68
69/// Newtype around the cuTENSOR descriptor pointers so we can `unsafe
70/// impl Send`. Each `CachedPlan` owns its descriptors and the plan
71/// itself; on `Drop` we tear them down in reverse construction order.
72pub struct CachedPlan {
73    pub plan: ct_sys::cutensorPlan_t,
74    pub pref: ct_sys::cutensorPlanPreference_t,
75    pub op: ct_sys::cutensorOperationDescriptor_t,
76    pub descs: Vec<ct_sys::cutensorTensorDescriptor_t>,
77    pub workspace_size: u64,
78}
79
80unsafe impl Send for CachedPlan {}
81unsafe impl Sync for CachedPlan {}
82
83impl Drop for CachedPlan {
84    fn drop(&mut self) {
85        unsafe {
86            let _ = ct_result::destroy_plan(self.plan);
87            let _ = ct_result::destroy_plan_preference(self.pref);
88            let _ = ct_result::destroy_operation_descriptor(self.op);
89            for d in self.descs.drain(..) {
90                let _ = ct_result::destroy_tensor_descriptor(d);
91            }
92        }
93    }
94}
95
96/// Thread-safe wrapper around `LruCache<PlanKey, Arc<CachedPlan>>`.
97/// `Arc` lets the actor hand the plan out to a kernel-launch closure
98/// that may outlive a subsequent cache eviction.
99pub struct PlanCache {
100    cache: Mutex<LruCache<PlanKey, Arc<CachedPlan>>>,
101}
102
103impl PlanCache {
104    pub fn new(cap: usize) -> Self {
105        let cap = std::num::NonZeroUsize::new(cap.max(1)).expect("non-zero cap");
106        Self {
107            cache: Mutex::new(LruCache::new(cap)),
108        }
109    }
110
111    pub fn with_default_capacity() -> Self {
112        Self::new(DEFAULT_PLAN_CACHE_SIZE)
113    }
114
115    pub fn get(&self, key: &PlanKey) -> Option<Arc<CachedPlan>> {
116        self.cache.lock().get(key).cloned()
117    }
118
119    pub fn put(&self, key: PlanKey, plan: Arc<CachedPlan>) {
120        self.cache.lock().put(key, plan);
121    }
122
123    /// Cache size for tests / observability.
124    pub fn len(&self) -> usize {
125        self.cache.lock().len()
126    }
127
128    pub fn is_empty(&self) -> bool {
129        self.cache.lock().is_empty()
130    }
131}
132
/// Hashes a slice with std's `DefaultHasher`. This is the shared
/// implementation behind [`hash_i64s`] and [`hash_i32s`].
fn hash_slice<T: std::hash::Hash>(values: &[T]) -> u64 {
    use std::hash::{Hash, Hasher};
    let mut h = std::collections::hash_map::DefaultHasher::new();
    values.hash(&mut h);
    h.finish()
}

/// Hash a slice of `i64` (extents or strides) into a `u64`.
///
/// Uses std's `DefaultHasher` (SipHash-based, *not* FxHash). The result
/// is order-sensitive and deterministic within a process, which is all
/// plan-cache lookups need. Do not persist these hashes: std does not
/// guarantee the algorithm across Rust releases.
pub fn hash_i64s(values: &[i64]) -> u64 {
    hash_slice(values)
}

/// Hash a slice of `i32` modes into a `u64`.
///
/// Same hasher and guarantees as [`hash_i64s`].
pub fn hash_i32s(values: &[i32]) -> u64 {
    hash_slice(values)
}
150
#[cfg(test)]
mod tests {
    use super::*;

    fn make_key(op_kind: OpKind, dtype_tag: &'static str, modes: u64) -> PlanKey {
        PlanKey {
            op_kind,
            modes_hash: modes,
            extents_hash: 0,
            alignment: 16,
            compute_desc_tag: 1,
            dtype_tag,
            algo: 0,
        }
    }

    /// Builds a placeholder plan whose cuTENSOR handles are all null.
    ///
    /// `CachedPlan::drop` would hand these handles to cuTENSOR, so callers
    /// must keep an extra `Arc` to every placeholder alive for the whole
    /// process (see the `ManuallyDrop` pin in `cache_lru_hit_miss`).
    fn placeholder_plan(workspace_size: u64) -> Arc<CachedPlan> {
        Arc::new(CachedPlan {
            plan: std::ptr::null_mut(),
            pref: std::ptr::null_mut(),
            op: std::ptr::null_mut(),
            descs: Vec::new(),
            workspace_size,
        })
    }

    #[test]
    fn cache_lru_hit_miss() {
        let cache = PlanCache::new(2);
        let k1 = make_key(OpKind::Contract, "f32", 1);
        let k2 = make_key(OpKind::Reduce, "f32", 2);
        let k3 = make_key(OpKind::Permutation, "f32", 3);
        let (p1, p2, p3) = (placeholder_plan(1), placeholder_plan(2), placeholder_plan(3));

        // Pin one strong reference to each placeholder and never release
        // it, not even if an assertion below panics. `CachedPlan::drop`
        // therefore never runs on the null handles, and the test needs no
        // GPU.
        let _pins = std::mem::ManuallyDrop::new([
            Arc::clone(&p1),
            Arc::clone(&p2),
            Arc::clone(&p3),
        ]);

        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
        assert!(cache.get(&k1).is_none());

        cache.put(k1, p1);
        cache.put(k2, p2);
        assert_eq!(cache.len(), 2);

        // Touch k1 so that k2 becomes the least-recently-used entry.
        let hit = cache.get(&k1).expect("k1 was just inserted");
        assert_eq!(hit.workspace_size, 1);

        // Inserting a third key at capacity 2 must evict k2, not k1.
        cache.put(k3, p3);
        assert_eq!(cache.len(), 2);
        assert!(cache.get(&k2).is_none(), "k2 should have been evicted");
        assert_eq!(cache.get(&k1).map(|p| p.workspace_size), Some(1));
        assert_eq!(cache.get(&k3).map(|p| p.workspace_size), Some(3));
    }

    #[test]
    fn plan_keys_compare_by_value() {
        let k1 = make_key(OpKind::Contract, "f32", 1);
        assert_eq!(k1, make_key(OpKind::Contract, "f32", 1));
        assert_ne!(k1, make_key(OpKind::Reduce, "f32", 1));
        assert_ne!(k1, make_key(OpKind::Contract, "f64", 1));
        assert_ne!(k1, make_key(OpKind::Contract, "f32", 2));
        // An autotuned algorithm must not collide with the default slot.
        assert_ne!(k1, PlanKey { algo: 3, ..k1 });
    }

    #[test]
    fn op_kind_tags_are_stable() {
        assert_eq!(OpKind::Contract.tag(), "contract");
        assert_eq!(OpKind::Reduce.tag(), "reduce");
        assert_eq!(OpKind::ElementwiseBinary.tag(), "ewbin");
        assert_eq!(OpKind::ElementwiseTrinary.tag(), "ewtri");
        assert_eq!(OpKind::Permutation.tag(), "permute");
    }

    #[test]
    fn hash_is_order_sensitive() {
        assert_ne!(hash_i64s(&[1, 2, 3]), hash_i64s(&[3, 2, 1]));
        assert_eq!(hash_i64s(&[1, 2, 3]), hash_i64s(&[1, 2, 3]));
        assert_ne!(hash_i32s(&[1, 2]), hash_i32s(&[2, 1]));
    }
}