Skip to main content

atomr_accel_cuda/kernel/blas_lt/
heuristic.rs

//! Heuristic-cache for cuBLASLt matmul algorithms.
//!
//! cuBLASLt's `cublasLtMatmulAlgoGetHeuristic` is a synchronous
//! library call that takes single-digit milliseconds. For repeated
//! shapes (every iteration of a transformer step) we cache the
//! algorithm picked from the heuristic's candidates under
//! `(m, n, k, dtype, transa/transb, epilogue, sm_arch)` and reuse it.
//!
//! Key design points:
//! - LRU eviction (capacity defaults to 256 entries — large enough to
//!   cover a model's full shape repertoire, small enough to stay in the
//!   tens of KiB of host RAM; each entry carries a 64-byte
//!   `cublasLtMatmulAlgo_t`).
//! - The cache lives in a `parking_lot::Mutex<lru::LruCache>` behind
//!   an `Arc`, so a cloneable `HeuristicCacheRef` can flow into
//!   per-message `BlasLtDispatchCtx` without `Send` headaches.
//! - We store the raw `cublasLtMatmulAlgo_t` plus a `workspace_size`
//!   hint; the actor's `WorkspacePool` uses the workspace size to
//!   recycle the right slot.
19
20use std::num::NonZeroUsize;
21use std::sync::Arc;
22
23use cudarc::cublaslt::sys::cublasLtMatmulAlgo_t;
24use lru::LruCache;
25use parking_lot::Mutex;
26
27use crate::dtype::DTypeKind;
28use crate::kernel::blas_lt::epilogue::Epilogue;
29
30/// Cache key — fully self-describing so two requests with the same
31/// shape/layout/dtype/epilogue/arch trio land in the same bucket.
32#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
33pub struct HeuristicKey {
34    pub m: i32,
35    pub n: i32,
36    pub k: i32,
37    /// Stable dtype tag. Captured as a u32 (rather than `DTypeKind`)
38    /// so the key derives cleanly even if `DTypeKind` grows new
39    /// variants.
40    pub dtype: u32,
41    pub transa: bool,
42    pub transb: bool,
43    pub epilogue: Epilogue,
44    pub sm_arch: u32,
45}
46
47impl HeuristicKey {
48    pub fn new(
49        m: i32,
50        n: i32,
51        k: i32,
52        dtype: DTypeKind,
53        transa: bool,
54        transb: bool,
55        epilogue: Epilogue,
56        sm_arch: u32,
57    ) -> Self {
58        Self {
59            m,
60            n,
61            k,
62            dtype: dtype as u32,
63            transa,
64            transb,
65            epilogue,
66            sm_arch,
67        }
68    }
69}
70
71/// Cached value — best algorithm by wall-time plus the workspace size
72/// the heuristic reported.
73#[derive(Debug, Clone, Copy)]
74pub struct HeuristicEntry {
75    pub algo: cublasLtMatmulAlgo_t,
76    pub workspace_size: usize,
77    /// Reported wall time (`wavesCount` from
78    /// `cublasLtMatmulHeuristicResult_t`); lower is better. Stored so
79    /// callers can decide to re-run the search if a better algorithm
80    /// might be available after a tuning sweep.
81    pub waves_count: f32,
82}
83
84// SAFETY: `cublasLtMatmulAlgo_t` is `repr(C) [u64; 8]` — pure POD.
85unsafe impl Send for HeuristicEntry {}
86unsafe impl Sync for HeuristicEntry {}
87
/// Number of entries a cache built via `default_size()` can hold before
/// LRU eviction kicks in.
pub const DEFAULT_HEURISTIC_CAPACITY: usize = 256;

/// How many candidate algorithms a cold lookup asks cuBLASLt for. Only
/// one survives into the cache (picked by `waves_count`); the others are
/// dropped.
pub const DEFAULT_TOP_K: usize = 3;
94
95/// Shareable handle to the heuristic cache. Cheap to clone.
96#[derive(Clone)]
97pub struct HeuristicCacheRef {
98    inner: Arc<Mutex<LruCache<HeuristicKey, HeuristicEntry>>>,
99    top_k: usize,
100}
101
102impl HeuristicCacheRef {
103    pub fn with_capacity(capacity: usize) -> Self {
104        let cap = NonZeroUsize::new(capacity.max(1))
105            .expect("HeuristicCacheRef::with_capacity: cap.max(1) is non-zero");
106        Self {
107            inner: Arc::new(Mutex::new(LruCache::new(cap))),
108            top_k: DEFAULT_TOP_K,
109        }
110    }
111
112    pub fn default_size() -> Self {
113        Self::with_capacity(DEFAULT_HEURISTIC_CAPACITY)
114    }
115
116    /// Number of algorithms to request from cuBLASLt on cold lookup.
117    pub fn top_k(&self) -> usize {
118        self.top_k
119    }
120
121    /// Cache hit (refreshes LRU order) or miss.
122    pub fn get(&self, key: &HeuristicKey) -> Option<HeuristicEntry> {
123        self.inner.lock().get(key).copied()
124    }
125
126    /// Insert a (key, entry) pair, possibly evicting the LRU tail.
127    pub fn insert(&self, key: HeuristicKey, entry: HeuristicEntry) {
128        self.inner.lock().put(key, entry);
129    }
130
131    /// Snapshot of cache occupancy. Diagnostic only.
132    pub fn len(&self) -> usize {
133        self.inner.lock().len()
134    }
135
136    pub fn is_empty(&self) -> bool {
137        self.len() == 0
138    }
139}
140
#[cfg(test)]
mod tests {
    use super::*;

    /// Entry whose only distinguishing field is `waves_count`.
    fn entry_with_waves(waves_count: f32) -> HeuristicEntry {
        let algo = cublasLtMatmulAlgo_t { data: [0u64; 8] };
        HeuristicEntry {
            algo,
            workspace_size: 4 << 20,
            waves_count,
        }
    }

    /// F32, untransposed, no epilogue, arch 0 — only the shape varies.
    fn shape_key(m: i32, n: i32, k: i32) -> HeuristicKey {
        HeuristicKey::new(m, n, k, DTypeKind::F32, false, false, Epilogue::None, 0)
    }

    #[test]
    fn cache_lru_hit_miss() {
        let cache = HeuristicCacheRef::with_capacity(2);
        assert!(cache.is_empty());

        let [small, medium, large] = [64, 128, 256].map(|d| shape_key(d, d, d));

        // Nothing has been inserted yet, so the first lookup must miss.
        assert!(cache.get(&small).is_none());
        cache.insert(small, entry_with_waves(1.5));
        cache.insert(medium, entry_with_waves(2.5));
        assert_eq!(cache.len(), 2);

        // Reading `small` promotes it, leaving `medium` least-recently-used.
        let hit = cache.get(&small).expect("k1 should hit");
        assert_eq!(hit.waves_count, 1.5);

        // A third insert into a two-slot cache has to push out `medium`.
        cache.insert(large, entry_with_waves(3.5));
        assert_eq!(cache.len(), 2);
        assert!(cache.get(&medium).is_none(), "k2 should have been evicted");
        assert!(cache.get(&small).is_some(), "k1 should still be present");
        assert!(cache.get(&large).is_some(), "k3 should be present");
    }

    #[test]
    fn capacity_min_one() {
        // A zero capacity request is clamped, so one entry still fits.
        let cache = HeuristicCacheRef::with_capacity(0);
        cache.insert(shape_key(1, 1, 1), entry_with_waves(0.0));
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn distinct_keys_for_different_axes() {
        let base = shape_key(64, 64, 64);
        let variants = [
            HeuristicKey { transa: true, ..base },
            HeuristicKey { sm_arch: 90, ..base },
            HeuristicKey { epilogue: Epilogue::Bias, ..base },
        ];
        for variant in variants {
            assert_ne!(base, variant);
        }
    }
}