Skip to main content

atomr_accel_cuda/
nvrtc_cache.rs

1//! Persistent disk cache for NVRTC-compiled CUDA kernels (Phase 0.6).
2//!
3//! Modern CUDA kernels — Hopper/Blackwell hand-rolled CUDA-C, CUTLASS
4//! template instantiations, FlashAttention 2/3 variants — take 10s to
5//! 60s each through NVRTC. A persistent disk cache turns subsequent
6//! runs into single-digit-millisecond hot starts.
7//!
8//! ## Design
9//!
10//! - **Key**: `(source_hash, arch, options_hash)` where `arch` is the
11//!   SM compute capability (e.g. `80`, `90`, `100`) and `options_hash`
12//!   is FNV-1a of the NVRTC compile options in their original order
13//!   (callers should sort beforehand if they want order-insensitive
14//!   keys — see [`hash_options`]).
15//! - **Value**: serialised PTX (and optional CUBIN) bytes wrapped in
16//!   [`CachedKernel`].
17//! - **Storage**: filesystem under
18//!   `$XDG_CACHE_HOME/atomr-accel/nvrtc/` (or
19//!   `$HOME/.cache/atomr-accel/nvrtc/`, falling back to
20//!   `$TMPDIR/atomr-accel/nvrtc/`). One file per cache entry, named
21//!   `{source_hash:016x}-{arch}-{options_hash:016x}.bin`.
22//! - **Format**: bincode of [`CachedKernel`]. Entries whose
23//!   `atomr_accel_version` does not match
24//!   [`env!("CARGO_PKG_VERSION")`] are rejected on load.
25//! - **Concurrency**: in-process [`RwLock`]ed [`HashMap`] read-through
26//!   cache. Cross-process safety via atomic file write
27//!   (`<name>.tmp` then `rename`).
28//!
29//! ## Usage
30//!
31//! ```no_run
32//! use atomr_accel_cuda::nvrtc_cache::{
33//!     hash_options, hash_source, CachedKernel, NvrtcCache, NvrtcCacheKey,
34//! };
35//!
36//! let cache = NvrtcCache::new().unwrap();
37//! let src = "extern \"C\" __global__ void noop() {}";
38//! let key = NvrtcCacheKey {
39//!     source_hash: hash_source(src),
40//!     arch: 80,
41//!     options_hash: hash_options(["-std=c++17", "--use_fast_math"]),
42//! };
43//! if let Some(entry) = cache.get(key) {
44//!     println!("hot: {} bytes of PTX", entry.ptx.len());
45//! } else {
46//!     // ... NVRTC compile ...
47//!     let ptx: Vec<u8> = b"PTX...".to_vec();
48//!     cache.insert(key, CachedKernel::new(ptx, None)).unwrap();
49//! }
50//! ```
51//!
52//! Phase 5 will wire `NvrtcActor` through this cache; this module ships
53//! the storage layer alone.
54
55use crate::error::GpuError;
56use serde::{Deserialize, Serialize};
57use std::collections::HashMap;
58use std::fs;
59use std::hash::Hasher;
60use std::io::Write;
61use std::path::{Path, PathBuf};
62use std::sync::{Arc, RwLock};
63
/// Identity of one cache entry: which source, for which GPU
/// architecture, compiled with which options.
///
/// The two hash fields are meant to be filled from [`hash_source`] and
/// [`hash_options`]. All three fields take part in equality, hashing
/// and the on-disk file name.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct NvrtcCacheKey {
    /// Hash of the kernel source text.
    pub source_hash: u64,
    /// SM compute capability as an integer (e.g. `80`, `90`, `100`).
    pub arch: u32,
    /// Hash of the NVRTC compile options.
    pub options_hash: u64,
}
73
74/// On-disk and in-memory cache value.
75///
76/// `atomr_accel_version` is checked on load: entries from older
77/// crate versions are silently rejected so a cache built against a
78/// stale `cudarc` / NVRTC ABI never gets loaded into a newer build.
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct CachedKernel {
81    pub ptx: Vec<u8>,
82    pub cubin: Option<Vec<u8>>,
83    pub atomr_accel_version: String,
84}
85
86impl CachedKernel {
87    /// Build a [`CachedKernel`] stamped with the current crate version.
88    pub fn new(ptx: Vec<u8>, cubin: Option<Vec<u8>>) -> Self {
89        Self {
90            ptx,
91            cubin,
92            atomr_accel_version: env!("CARGO_PKG_VERSION").to_string(),
93        }
94    }
95}
96
97/// Read-through disk cache for compiled NVRTC kernels.
98///
99/// Cross-process safe: writes go to a temp file and are atomically
100/// renamed into place. Reads tolerate concurrent writers (a partial
101/// `.tmp` is invisible to readers; a corrupt `.bin` returns `None`).
102#[derive(Debug)]
103pub struct NvrtcCache {
104    dir: PathBuf,
105    memory: RwLock<HashMap<NvrtcCacheKey, Arc<CachedKernel>>>,
106}
107
108impl NvrtcCache {
109    /// Construct a cache rooted at the OS-default location.
110    ///
111    /// Probe order:
112    /// 1. `$XDG_CACHE_HOME/atomr-accel/nvrtc/`
113    /// 2. `$HOME/.cache/atomr-accel/nvrtc/` (via [`dirs::cache_dir`])
114    /// 3. `<temp_dir>/atomr-accel/nvrtc/`
115    pub fn new() -> Result<Self, GpuError> {
116        Self::with_dir(default_cache_dir())
117    }
118
119    /// Construct a cache rooted at an explicit directory. Creates the
120    /// directory (recursively) if it does not exist.
121    pub fn with_dir(path: PathBuf) -> Result<Self, GpuError> {
122        fs::create_dir_all(&path).map_err(|e| {
123            GpuError::Unrecoverable(format!(
124                "NvrtcCache: failed to create cache dir {}: {}",
125                path.display(),
126                e
127            ))
128        })?;
129        Ok(Self {
130            dir: path,
131            memory: RwLock::new(HashMap::new()),
132        })
133    }
134
135    /// Cache root directory.
136    pub fn dir(&self) -> &Path {
137        &self.dir
138    }
139
140    /// Look up a cached kernel.
141    ///
142    /// Returns `None` if the entry is absent, unreadable, mis-encoded,
143    /// or stamped with a non-matching `atomr_accel_version`. A read
144    /// hit on disk but cold in memory promotes the entry into the
145    /// in-memory map for subsequent lookups.
146    pub fn get(&self, key: NvrtcCacheKey) -> Option<Arc<CachedKernel>> {
147        if let Some(hit) = self
148            .memory
149            .read()
150            .ok()
151            .and_then(|guard| guard.get(&key).cloned())
152        {
153            return Some(hit);
154        }
155
156        let path = self.entry_path(&key);
157        let bytes = fs::read(&path).ok()?;
158        let entry: CachedKernel = bincode::deserialize(&bytes).ok()?;
159        if entry.atomr_accel_version != env!("CARGO_PKG_VERSION") {
160            return None;
161        }
162        let arc = Arc::new(entry);
163        if let Ok(mut guard) = self.memory.write() {
164            guard.insert(key, arc.clone());
165        }
166        Some(arc)
167    }
168
169    /// Store a kernel.
170    ///
171    /// The on-disk write is atomic (write to `<name>.tmp` then
172    /// rename). Concurrent writers race the rename; the loser's bytes
173    /// are silently overwritten — bincode payloads of the same key are
174    /// expected to be identical so this is benign.
175    pub fn insert(&self, key: NvrtcCacheKey, value: CachedKernel) -> Result<(), GpuError> {
176        let bytes = bincode::serialize(&value).map_err(|e| {
177            GpuError::Unrecoverable(format!("NvrtcCache: bincode serialize: {}", e))
178        })?;
179        let final_path = self.entry_path(&key);
180        let tmp_path = final_path.with_extension("bin.tmp");
181
182        {
183            let mut f = fs::File::create(&tmp_path).map_err(|e| {
184                GpuError::Unrecoverable(format!("NvrtcCache: create {}: {}", tmp_path.display(), e))
185            })?;
186            f.write_all(&bytes).map_err(|e| {
187                GpuError::Unrecoverable(format!("NvrtcCache: write {}: {}", tmp_path.display(), e))
188            })?;
189            f.sync_all().map_err(|e| {
190                GpuError::Unrecoverable(format!("NvrtcCache: fsync {}: {}", tmp_path.display(), e))
191            })?;
192        }
193
194        fs::rename(&tmp_path, &final_path).map_err(|e| {
195            // Best-effort clean up the temp file on rename failure.
196            let _ = fs::remove_file(&tmp_path);
197            GpuError::Unrecoverable(format!(
198                "NvrtcCache: rename {} -> {}: {}",
199                tmp_path.display(),
200                final_path.display(),
201                e
202            ))
203        })?;
204
205        if let Ok(mut guard) = self.memory.write() {
206            guard.insert(key, Arc::new(value));
207        }
208        Ok(())
209    }
210
211    /// Drop every in-memory entry. Disk contents are left untouched
212    /// — subsequent `get` calls will re-populate from disk.
213    pub fn clear_in_memory(&self) {
214        if let Ok(mut guard) = self.memory.write() {
215            guard.clear();
216        }
217    }
218
219    fn entry_path(&self, key: &NvrtcCacheKey) -> PathBuf {
220        self.dir.join(format!(
221            "{:016x}-{}-{:016x}.bin",
222            key.source_hash, key.arch, key.options_hash
223        ))
224    }
225}
226
227/// FNV-1a 64-bit hash of a kernel source string. Stable across
228/// processes and across crate compilations.
229pub fn hash_source(src: &str) -> u64 {
230    let mut h = FnvHasher::new();
231    h.write(src.as_bytes());
232    h.finish()
233}
234
235/// FNV-1a 64-bit hash of an iterable of NVRTC compile options.
236///
237/// **Order-sensitive**: callers that want order-insensitive keys must
238/// sort the iterable before passing it in. NVRTC's `--define-macro`
239/// and `--include-path` flags are order-significant in general, so
240/// the cache key preserves the caller's order.
241pub fn hash_options<I, S>(opts: I) -> u64
242where
243    I: IntoIterator<Item = S>,
244    S: AsRef<str>,
245{
246    let mut h = FnvHasher::new();
247    for opt in opts {
248        let bytes = opt.as_ref().as_bytes();
249        // Length prefix so ["ab", "c"] != ["a", "bc"].
250        h.write_u64(bytes.len() as u64);
251        h.write(bytes);
252        // Separator byte to make the boundary unambiguous.
253        h.write_u8(0xff);
254    }
255    h.finish()
256}
257
258fn default_cache_dir() -> PathBuf {
259    if let Some(xdg) = std::env::var_os("XDG_CACHE_HOME") {
260        let p = PathBuf::from(xdg);
261        if !p.as_os_str().is_empty() {
262            return p.join("atomr-accel").join("nvrtc");
263        }
264    }
265    if let Some(cache) = dirs::cache_dir() {
266        return cache.join("atomr-accel").join("nvrtc");
267    }
268    std::env::temp_dir().join("atomr-accel").join("nvrtc")
269}
270
271// ---------------------------------------------------------------------------
272// FNV-1a 64-bit. Tiny, dependency-free, deterministic — good enough for a
273// kernel-source content hash. Not cryptographic.
274// ---------------------------------------------------------------------------
275
/// Initial state of the 64-bit FNV-1a hash.
const FNV_OFFSET_BASIS_64: u64 = 0xcbf2_9ce4_8422_2325;
/// Multiplier applied after each byte is mixed in.
const FNV_PRIME_64: u64 = 0x0000_0100_0000_01b3;

/// Streaming 64-bit FNV-1a hasher; the wrapped value is the running
/// hash state.
struct FnvHasher(u64);

impl FnvHasher {
    /// Start a fresh hash at the offset basis.
    fn new() -> Self {
        FnvHasher(FNV_OFFSET_BASIS_64)
    }
}

impl Hasher for FnvHasher {
    /// Current hash value; does not reset or consume the state.
    fn finish(&self) -> u64 {
        let FnvHasher(state) = *self;
        state
    }

    /// Mix `bytes` into the state one byte at a time: XOR the byte in,
    /// then multiply by the prime (wrapping on overflow).
    fn write(&mut self, bytes: &[u8]) {
        self.0 = bytes.iter().fold(self.0, |state, &byte| {
            (state ^ u64::from(byte)).wrapping_mul(FNV_PRIME_64)
        });
    }
}
300
301// ---------------------------------------------------------------------------
302// Tests
303// ---------------------------------------------------------------------------
304
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Kernel with 64 PTX bytes of `seed` and 32 CUBIN bytes of
    /// `seed + 1` (wrapping).
    fn sample_kernel(seed: u8) -> CachedKernel {
        let ptx = vec![seed; 64];
        let cubin = vec![seed.wrapping_add(1); 32];
        CachedKernel::new(ptx, Some(cubin))
    }

    /// Shorthand for building a key from its three components.
    fn key(source_hash: u64, arch: u32, options_hash: u64) -> NvrtcCacheKey {
        NvrtcCacheKey {
            source_hash,
            arch,
            options_hash,
        }
    }

    /// Miss on an empty cache, then hit after an insert.
    #[test]
    fn round_trip_via_with_dir() {
        let scratch = tempdir().unwrap();
        let cache = NvrtcCache::with_dir(scratch.path().to_path_buf()).unwrap();
        let cache_key = key(0xdead_beef, 80, 0x1234);
        let kernel = sample_kernel(7);

        assert!(cache.get(cache_key).is_none(), "cold cache should miss");

        cache.insert(cache_key, kernel.clone()).unwrap();

        let fetched = cache.get(cache_key).expect("hot lookup must hit");
        assert_eq!(fetched.ptx, kernel.ptx);
        assert_eq!(fetched.cubin, kernel.cubin);
        assert_eq!(fetched.atomr_accel_version, env!("CARGO_PKG_VERSION"));
    }

    /// An entry written through one handle is readable through a
    /// brand-new handle on the same directory.
    #[test]
    fn cache_persists_across_fresh_handles() {
        let scratch = tempdir().unwrap();
        let root = scratch.path().to_path_buf();
        let cache_key = key(0x9999, 90, 0xabcd);
        let kernel = sample_kernel(42);

        // The writer handle is a temporary: once this statement ends
        // only the file on disk remains.
        NvrtcCache::with_dir(root.clone())
            .unwrap()
            .insert(cache_key, kernel.clone())
            .unwrap();

        let reader = NvrtcCache::with_dir(root).unwrap();
        let fetched = reader
            .get(cache_key)
            .expect("disk-backed entry must survive");
        assert_eq!(fetched.ptx, kernel.ptx);
        assert_eq!(fetched.cubin, kernel.cubin);
    }

    /// Keys that differ in any single component map to different
    /// files and read back different payloads.
    #[test]
    fn distinct_keys_distinct_paths() {
        let scratch = tempdir().unwrap();
        let cache = NvrtcCache::with_dir(scratch.path().to_path_buf()).unwrap();

        // (key, seed) pairs: source differs, arch differs, options
        // differ, and the all-baseline key.
        let cases = [
            (key(1, 80, 0), 1u8),
            (key(0, 90, 0), 2u8),
            (key(0, 80, 1), 3u8),
            (key(0, 80, 0), 4u8),
        ];

        let paths: Vec<_> = cases.iter().map(|(k, _)| cache.entry_path(k)).collect();

        // Every unordered pair of paths must differ.
        for (idx, left) in paths.iter().enumerate() {
            for right in &paths[idx + 1..] {
                assert_ne!(left, right);
            }
        }

        // Inserting under each key writes a separate file.
        for &(k, seed) in &cases {
            cache.insert(k, sample_kernel(seed)).unwrap();
        }
        for path in &paths {
            assert!(path.exists());
        }

        // And reads come back distinct.
        for &(k, seed) in &cases {
            assert_eq!(cache.get(k).unwrap().ptx, vec![seed; 64]);
        }
    }

    /// A well-formed file carrying a foreign version stamp is a miss.
    #[test]
    fn version_mismatch_rejected_on_load() {
        let scratch = tempdir().unwrap();
        let cache = NvrtcCache::with_dir(scratch.path().to_path_buf()).unwrap();
        let cache_key = key(11, 80, 22);

        // Hand-write an entry with a stale version stamp.
        let stale = CachedKernel {
            ptx: vec![0xaa; 16],
            cubin: None,
            atomr_accel_version: "0.0.0-impossible".to_string(),
        };
        let encoded = bincode::serialize(&stale).unwrap();
        fs::write(cache.entry_path(&cache_key), &encoded).unwrap();

        // Skip the in-memory shortcut by clearing it explicitly.
        cache.clear_in_memory();
        assert!(
            cache.get(cache_key).is_none(),
            "entry with mismatched atomr_accel_version must not be returned"
        );
    }

    /// Same source gives the same hash; different source differs.
    #[test]
    fn hash_source_is_deterministic() {
        let first = hash_source("__global__ void k() {}");
        let repeat = hash_source("__global__ void k() {}");
        let other = hash_source("__global__ void other() {}");
        assert_eq!(first, repeat);
        assert_ne!(first, other);
    }

    /// Option hashing is repeatable and reacts to order, count and
    /// where the boundaries between options fall.
    #[test]
    fn hash_options_is_deterministic_and_order_sensitive() {
        let base = hash_options(["-std=c++17", "--use_fast_math"]);
        let repeat = hash_options(["-std=c++17", "--use_fast_math"]);
        let swapped = hash_options(["--use_fast_math", "-std=c++17"]);
        let single = hash_options(["-std=c++17"]);
        let empty = hash_options(Vec::<&str>::new());

        assert_eq!(base, repeat, "same input must produce same hash");
        assert_ne!(base, swapped, "option order must change the hash");
        assert_ne!(base, single, "option count must change the hash");
        assert_ne!(base, empty);
        assert_ne!(single, empty);

        // Length-prefix invariant: ["ab","c"] != ["a","bc"].
        assert_ne!(hash_options(["ab", "c"]), hash_options(["a", "bc"]));
    }

    /// Clearing the in-memory map does not touch the file on disk.
    #[test]
    fn clear_in_memory_keeps_disk() {
        let scratch = tempdir().unwrap();
        let cache = NvrtcCache::with_dir(scratch.path().to_path_buf()).unwrap();
        let cache_key = key(7, 80, 7);
        cache.insert(cache_key, sample_kernel(9)).unwrap();
        cache.clear_in_memory();
        let fetched = cache
            .get(cache_key)
            .expect("disk entry survives clear_in_memory");
        assert_eq!(fetched.ptx, vec![9u8; 64]);
    }
}