atomr_accel_cuda/kernel/blas_lt/
heuristic.rs1use std::num::NonZeroUsize;
21use std::sync::Arc;
22
23use cudarc::cublaslt::sys::cublasLtMatmulAlgo_t;
24use lru::LruCache;
25use parking_lot::Mutex;
26
27use crate::dtype::DTypeKind;
28use crate::kernel::blas_lt::epilogue::Epilogue;
29
30#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
33pub struct HeuristicKey {
34 pub m: i32,
35 pub n: i32,
36 pub k: i32,
37 pub dtype: u32,
41 pub transa: bool,
42 pub transb: bool,
43 pub epilogue: Epilogue,
44 pub sm_arch: u32,
45}
46
47impl HeuristicKey {
48 pub fn new(
49 m: i32,
50 n: i32,
51 k: i32,
52 dtype: DTypeKind,
53 transa: bool,
54 transb: bool,
55 epilogue: Epilogue,
56 sm_arch: u32,
57 ) -> Self {
58 Self {
59 m,
60 n,
61 k,
62 dtype: dtype as u32,
63 transa,
64 transb,
65 epilogue,
66 sm_arch,
67 }
68 }
69}
70
71#[derive(Debug, Clone, Copy)]
74pub struct HeuristicEntry {
75 pub algo: cublasLtMatmulAlgo_t,
76 pub workspace_size: usize,
77 pub waves_count: f32,
82}
83
84unsafe impl Send for HeuristicEntry {}
86unsafe impl Sync for HeuristicEntry {}
87
88pub const DEFAULT_HEURISTIC_CAPACITY: usize = 256;
90
91pub const DEFAULT_TOP_K: usize = 3;
94
95#[derive(Clone)]
97pub struct HeuristicCacheRef {
98 inner: Arc<Mutex<LruCache<HeuristicKey, HeuristicEntry>>>,
99 top_k: usize,
100}
101
102impl HeuristicCacheRef {
103 pub fn with_capacity(capacity: usize) -> Self {
104 let cap = NonZeroUsize::new(capacity.max(1))
105 .expect("HeuristicCacheRef::with_capacity: cap.max(1) is non-zero");
106 Self {
107 inner: Arc::new(Mutex::new(LruCache::new(cap))),
108 top_k: DEFAULT_TOP_K,
109 }
110 }
111
112 pub fn default_size() -> Self {
113 Self::with_capacity(DEFAULT_HEURISTIC_CAPACITY)
114 }
115
116 pub fn top_k(&self) -> usize {
118 self.top_k
119 }
120
121 pub fn get(&self, key: &HeuristicKey) -> Option<HeuristicEntry> {
123 self.inner.lock().get(key).copied()
124 }
125
126 pub fn insert(&self, key: HeuristicKey, entry: HeuristicEntry) {
128 self.inner.lock().put(key, entry);
129 }
130
131 pub fn len(&self) -> usize {
133 self.inner.lock().len()
134 }
135
136 pub fn is_empty(&self) -> bool {
137 self.len() == 0
138 }
139}
140
#[cfg(test)]
mod tests {
    use super::*;

    /// Entry with a zeroed algorithm descriptor and a 4 MiB workspace size.
    /// `waves` is stored as `waves_count` so tests can tell entries apart.
    fn dummy_entry(waves: f32) -> HeuristicEntry {
        let algo = cublasLtMatmulAlgo_t { data: [0u64; 8] };
        HeuristicEntry {
            algo,
            workspace_size: 4 << 20,
            waves_count: waves,
        }
    }

    /// Key for an untransposed F32 matmul of shape `(m, n, depth)` with no
    /// epilogue and `sm_arch` 0.
    fn key(m: i32, n: i32, depth: i32) -> HeuristicKey {
        HeuristicKey::new(m, n, depth, DTypeKind::F32, false, false, Epilogue::None, 0)
    }

    #[test]
    fn cache_lru_hit_miss() {
        let cache = HeuristicCacheRef::with_capacity(2);
        assert!(cache.is_empty());

        let [k1, k2, k3] = [64, 128, 256].map(|d| key(d, d, d));

        assert!(cache.get(&k1).is_none());
        for (slot, waves) in [(k1, 1.5), (k2, 2.5)] {
            cache.insert(slot, dummy_entry(waves));
        }
        assert_eq!(cache.len(), 2);

        // Touching k1 makes k2 the least recently used entry.
        assert_eq!(cache.get(&k1).expect("k1 should hit").waves_count, 1.5);

        cache.insert(k3, dummy_entry(3.5));
        assert_eq!(cache.len(), 2);
        assert!(cache.get(&k2).is_none(), "k2 should have been evicted");
        assert!(cache.get(&k1).is_some(), "k1 should still be present");
        assert!(cache.get(&k3).is_some(), "k3 should be present");
    }

    #[test]
    fn capacity_min_one() {
        let zero_cap = HeuristicCacheRef::with_capacity(0);
        zero_cap.insert(key(1, 1, 1), dummy_entry(0.0));
        assert_eq!(zero_cap.len(), 1);
    }

    #[test]
    fn distinct_keys_for_different_axes() {
        let base = key(64, 64, 64);
        let variants = [
            HeuristicKey { transa: true, ..base },
            HeuristicKey { sm_arch: 90, ..base },
            HeuristicKey { epilogue: Epilogue::Bias, ..base },
        ];
        for variant in &variants {
            assert_ne!(base, *variant);
        }
    }
}