atomr_accel_cuda/kernel/tensor/
plan_cache.rs1use std::sync::Arc;
19
20use cudarc::cutensor::result as ct_result;
21use cudarc::cutensor::sys as ct_sys;
22use lru::LruCache;
23use parking_lot::Mutex;
24
/// Capacity used by [`PlanCache::with_default_capacity`].
///
/// Once this many distinct keys are cached, inserting another one evicts the
/// least-recently-used entry.
pub const DEFAULT_PLAN_CACHE_SIZE: usize = 256;
30
/// Family of cuTENSOR operation a cached plan was built for.
///
/// This is one component of `PlanKey`, so plans for different operation kinds
/// never compare equal even when every other key field matches.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum OpKind {
    /// Tensor contraction.
    Contract,
    /// Reduction.
    Reduce,
    /// Elementwise binary operation.
    ElementwiseBinary,
    /// Elementwise trinary operation.
    ElementwiseTrinary,
    /// Mode permutation.
    Permutation,
}

impl OpKind {
    /// Returns a short, stable label for this kind (e.g. `"contract"`, `"ewbin"`).
    ///
    /// The exact strings are pinned by unit tests; treat them as part of the
    /// public contract.
    pub fn tag(self) -> &'static str {
        let label = match self {
            Self::Contract => "contract",
            Self::Reduce => "reduce",
            Self::ElementwiseBinary => "ewbin",
            Self::ElementwiseTrinary => "ewtri",
            Self::Permutation => "permute",
        };
        label
    }
}
52
53#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
56pub struct PlanKey {
57 pub op_kind: OpKind,
58 pub modes_hash: u64,
59 pub extents_hash: u64,
60 pub alignment: u32,
61 pub compute_desc_tag: u32,
62 pub dtype_tag: &'static str,
63 pub algo: i32,
67}
68
69pub struct CachedPlan {
73 pub plan: ct_sys::cutensorPlan_t,
74 pub pref: ct_sys::cutensorPlanPreference_t,
75 pub op: ct_sys::cutensorOperationDescriptor_t,
76 pub descs: Vec<ct_sys::cutensorTensorDescriptor_t>,
77 pub workspace_size: u64,
78}
79
80unsafe impl Send for CachedPlan {}
81unsafe impl Sync for CachedPlan {}
82
83impl Drop for CachedPlan {
84 fn drop(&mut self) {
85 unsafe {
86 let _ = ct_result::destroy_plan(self.plan);
87 let _ = ct_result::destroy_plan_preference(self.pref);
88 let _ = ct_result::destroy_operation_descriptor(self.op);
89 for d in self.descs.drain(..) {
90 let _ = ct_result::destroy_tensor_descriptor(d);
91 }
92 }
93 }
94}
95
96pub struct PlanCache {
100 cache: Mutex<LruCache<PlanKey, Arc<CachedPlan>>>,
101}
102
103impl PlanCache {
104 pub fn new(cap: usize) -> Self {
105 let cap = std::num::NonZeroUsize::new(cap.max(1)).expect("non-zero cap");
106 Self {
107 cache: Mutex::new(LruCache::new(cap)),
108 }
109 }
110
111 pub fn with_default_capacity() -> Self {
112 Self::new(DEFAULT_PLAN_CACHE_SIZE)
113 }
114
115 pub fn get(&self, key: &PlanKey) -> Option<Arc<CachedPlan>> {
116 self.cache.lock().get(key).cloned()
117 }
118
119 pub fn put(&self, key: PlanKey, plan: Arc<CachedPlan>) {
120 self.cache.lock().put(key, plan);
121 }
122
123 pub fn len(&self) -> usize {
125 self.cache.lock().len()
126 }
127
128 pub fn is_empty(&self) -> bool {
129 self.cache.lock().is_empty()
130 }
131}
132
/// Hashes a slice through its `Hash` impl using the standard `DefaultHasher`.
///
/// The slice's `Hash` impl feeds the length followed by each element in order.
/// The result is therefore order- and length-sensitive.
///
/// `DefaultHasher::new()` is deterministic within a given build, which is all
/// an in-process cache key needs. The algorithm may change between Rust
/// releases, so these values must not be persisted.
fn hash_values<T: std::hash::Hash>(values: &[T]) -> u64 {
    use std::hash::{Hash, Hasher};
    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    values.hash(&mut hasher);
    hasher.finish()
}

/// Hashes a list of `i64` values into a single `u64` key component.
///
/// Order matters: `[1, 2, 3]` and `[3, 2, 1]` produce different hashes.
pub fn hash_i64s(values: &[i64]) -> u64 {
    hash_values(values)
}

/// Hashes a list of `i32` values into a single `u64` key component.
///
/// Order matters: `[1, 2]` and `[2, 1]` produce different hashes.
pub fn hash_i32s(values: &[i32]) -> u64 {
    hash_values(values)
}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a key with fixed values for every field except the ones under test.
    fn make_key(op_kind: OpKind, dtype_tag: &'static str, modes: u64) -> PlanKey {
        PlanKey {
            op_kind,
            dtype_tag,
            modes_hash: modes,
            extents_hash: 0,
            alignment: 16,
            compute_desc_tag: 1,
            algo: 0,
        }
    }

    #[test]
    fn cache_lru_hit_miss() {
        // A real `CachedPlan` needs live cuTENSOR handles, so this test covers
        // only the empty-cache miss path and key equality semantics.
        let cache = PlanCache::new(2);
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);

        let first = make_key(OpKind::Contract, "f32", 1);
        let second = make_key(OpKind::Reduce, "f32", 2);
        let third = make_key(OpKind::Permutation, "f32", 3);

        assert!(cache.get(&first).is_none());
        assert_ne!(first, second);
        assert_ne!(second, third);
        assert_eq!(first, make_key(OpKind::Contract, "f32", 1));
    }

    #[test]
    fn op_kind_tags_are_stable() {
        let expected = [
            (OpKind::Contract, "contract"),
            (OpKind::Reduce, "reduce"),
            (OpKind::ElementwiseBinary, "ewbin"),
            (OpKind::ElementwiseTrinary, "ewtri"),
            (OpKind::Permutation, "permute"),
        ];
        for &(kind, tag) in expected.iter() {
            assert_eq!(kind.tag(), tag, "tag for {:?}", kind);
        }
    }

    #[test]
    fn hash_is_order_sensitive() {
        assert_eq!(hash_i64s(&[1, 2, 3]), hash_i64s(&[1, 2, 3]));
        assert_ne!(hash_i64s(&[1, 2, 3]), hash_i64s(&[3, 2, 1]));
        assert_ne!(hash_i32s(&[1, 2]), hash_i32s(&[2, 1]));
    }
}