1use crate::error::GpuError;
56use serde::{Deserialize, Serialize};
57use std::collections::HashMap;
58use std::fs;
59use std::hash::Hasher;
60use std::io::Write;
61use std::path::{Path, PathBuf};
62use std::sync::{Arc, RwLock};
63
/// Identity of one NVRTC compilation result.
///
/// Two compilations share a cache entry only when all three components match. Each
/// component also appears in the entry's on-disk file name.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct NvrtcCacheKey {
    /// 64-bit hash of the kernel source text.
    pub source_hash: u64,
    /// Target architecture number; rendered in decimal in the entry file name.
    pub arch: u32,
    /// 64-bit hash of the compile-option list.
    pub options_hash: u64,
}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct CachedKernel {
81 pub ptx: Vec<u8>,
82 pub cubin: Option<Vec<u8>>,
83 pub atomr_accel_version: String,
84}
85
86impl CachedKernel {
87 pub fn new(ptx: Vec<u8>, cubin: Option<Vec<u8>>) -> Self {
89 Self {
90 ptx,
91 cubin,
92 atomr_accel_version: env!("CARGO_PKG_VERSION").to_string(),
93 }
94 }
95}
96
97#[derive(Debug)]
103pub struct NvrtcCache {
104 dir: PathBuf,
105 memory: RwLock<HashMap<NvrtcCacheKey, Arc<CachedKernel>>>,
106}
107
108impl NvrtcCache {
109 pub fn new() -> Result<Self, GpuError> {
116 Self::with_dir(default_cache_dir())
117 }
118
119 pub fn with_dir(path: PathBuf) -> Result<Self, GpuError> {
122 fs::create_dir_all(&path).map_err(|e| {
123 GpuError::Unrecoverable(format!(
124 "NvrtcCache: failed to create cache dir {}: {}",
125 path.display(),
126 e
127 ))
128 })?;
129 Ok(Self {
130 dir: path,
131 memory: RwLock::new(HashMap::new()),
132 })
133 }
134
135 pub fn dir(&self) -> &Path {
137 &self.dir
138 }
139
140 pub fn get(&self, key: NvrtcCacheKey) -> Option<Arc<CachedKernel>> {
147 if let Some(hit) = self
148 .memory
149 .read()
150 .ok()
151 .and_then(|guard| guard.get(&key).cloned())
152 {
153 return Some(hit);
154 }
155
156 let path = self.entry_path(&key);
157 let bytes = fs::read(&path).ok()?;
158 let entry: CachedKernel = bincode::deserialize(&bytes).ok()?;
159 if entry.atomr_accel_version != env!("CARGO_PKG_VERSION") {
160 return None;
161 }
162 let arc = Arc::new(entry);
163 if let Ok(mut guard) = self.memory.write() {
164 guard.insert(key, arc.clone());
165 }
166 Some(arc)
167 }
168
169 pub fn insert(&self, key: NvrtcCacheKey, value: CachedKernel) -> Result<(), GpuError> {
176 let bytes = bincode::serialize(&value).map_err(|e| {
177 GpuError::Unrecoverable(format!("NvrtcCache: bincode serialize: {}", e))
178 })?;
179 let final_path = self.entry_path(&key);
180 let tmp_path = final_path.with_extension("bin.tmp");
181
182 {
183 let mut f = fs::File::create(&tmp_path).map_err(|e| {
184 GpuError::Unrecoverable(format!("NvrtcCache: create {}: {}", tmp_path.display(), e))
185 })?;
186 f.write_all(&bytes).map_err(|e| {
187 GpuError::Unrecoverable(format!("NvrtcCache: write {}: {}", tmp_path.display(), e))
188 })?;
189 f.sync_all().map_err(|e| {
190 GpuError::Unrecoverable(format!("NvrtcCache: fsync {}: {}", tmp_path.display(), e))
191 })?;
192 }
193
194 fs::rename(&tmp_path, &final_path).map_err(|e| {
195 let _ = fs::remove_file(&tmp_path);
197 GpuError::Unrecoverable(format!(
198 "NvrtcCache: rename {} -> {}: {}",
199 tmp_path.display(),
200 final_path.display(),
201 e
202 ))
203 })?;
204
205 if let Ok(mut guard) = self.memory.write() {
206 guard.insert(key, Arc::new(value));
207 }
208 Ok(())
209 }
210
211 pub fn clear_in_memory(&self) {
214 if let Ok(mut guard) = self.memory.write() {
215 guard.clear();
216 }
217 }
218
219 fn entry_path(&self, key: &NvrtcCacheKey) -> PathBuf {
220 self.dir.join(format!(
221 "{:016x}-{}-{:016x}.bin",
222 key.source_hash, key.arch, key.options_hash
223 ))
224 }
225}
226
227pub fn hash_source(src: &str) -> u64 {
230 let mut h = FnvHasher::new();
231 h.write(src.as_bytes());
232 h.finish()
233}
234
235pub fn hash_options<I, S>(opts: I) -> u64
242where
243 I: IntoIterator<Item = S>,
244 S: AsRef<str>,
245{
246 let mut h = FnvHasher::new();
247 for opt in opts {
248 let bytes = opt.as_ref().as_bytes();
249 h.write_u64(bytes.len() as u64);
251 h.write(bytes);
252 h.write_u8(0xff);
254 }
255 h.finish()
256}
257
258fn default_cache_dir() -> PathBuf {
259 if let Some(xdg) = std::env::var_os("XDG_CACHE_HOME") {
260 let p = PathBuf::from(xdg);
261 if !p.as_os_str().is_empty() {
262 return p.join("atomr-accel").join("nvrtc");
263 }
264 }
265 if let Some(cache) = dirs::cache_dir() {
266 return cache.join("atomr-accel").join("nvrtc");
267 }
268 std::env::temp_dir().join("atomr-accel").join("nvrtc")
269}
270
/// Initial state of 64-bit FNV-1a.
const FNV_OFFSET_BASIS_64: u64 = 0xcbf2_9ce4_8422_2325;
/// Multiplier of 64-bit FNV-1a.
const FNV_PRIME_64: u64 = 0x0000_0100_0000_01b3;

/// Streaming 64-bit FNV-1a hasher.
///
/// Splitting the input across several `write` calls yields the same result as one call
/// over the concatenated bytes.
struct FnvHasher(u64);

impl FnvHasher {
    /// Starts a fresh hash at the FNV-1a offset basis.
    fn new() -> Self {
        FnvHasher(FNV_OFFSET_BASIS_64)
    }
}

impl Hasher for FnvHasher {
    fn finish(&self) -> u64 {
        self.0
    }

    /// FNV-1a step per byte: xor the byte in, then multiply by the prime (mod 2^64).
    fn write(&mut self, bytes: &[u8]) {
        self.0 = bytes.iter().fold(self.0, |state, &byte| {
            (state ^ u64::from(byte)).wrapping_mul(FNV_PRIME_64)
        });
    }
}
300
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Deterministic kernel whose PTX is `[seed; 64]` and CUBIN is `[seed + 1; 32]`.
    fn sample_kernel(seed: u8) -> CachedKernel {
        CachedKernel::new(vec![seed; 64], Some(vec![seed.wrapping_add(1); 32]))
    }

    fn key(source_hash: u64, arch: u32, options_hash: u64) -> NvrtcCacheKey {
        NvrtcCacheKey {
            source_hash,
            arch,
            options_hash,
        }
    }

    #[test]
    fn round_trip_via_with_dir() {
        let tmp = tempdir().unwrap();
        let cache = NvrtcCache::with_dir(tmp.path().to_path_buf()).unwrap();
        let k = key(0xdead_beef, 80, 0x1234);
        let v = sample_kernel(7);

        assert!(cache.get(k).is_none(), "cold cache should miss");

        cache.insert(k, v.clone()).unwrap();

        let got = cache.get(k).expect("hot lookup must hit");
        assert_eq!(got.ptx, v.ptx);
        assert_eq!(got.cubin, v.cubin);
        assert_eq!(got.atomr_accel_version, env!("CARGO_PKG_VERSION"));
    }

    #[test]
    fn cache_persists_across_fresh_handles() {
        let tmp = tempdir().unwrap();
        let dir = tmp.path().to_path_buf();
        let k = key(0x9999, 90, 0xabcd);
        let v = sample_kernel(42);

        {
            let cache = NvrtcCache::with_dir(dir.clone()).unwrap();
            cache.insert(k, v.clone()).unwrap();
        }

        // A brand-new handle starts with an empty in-memory map, so this hit must come from disk.
        let cache2 = NvrtcCache::with_dir(dir).unwrap();
        let got = cache2.get(k).expect("disk-backed entry must survive");
        assert_eq!(got.ptx, v.ptx);
        assert_eq!(got.cubin, v.cubin);
    }

    #[test]
    fn distinct_keys_distinct_paths() {
        let tmp = tempdir().unwrap();
        let cache = NvrtcCache::with_dir(tmp.path().to_path_buf()).unwrap();

        let k_src = key(1, 80, 0);
        let k_arch = key(0, 90, 0);
        let k_opts = key(0, 80, 1);
        let k_zero = key(0, 80, 0);

        let p_src = cache.entry_path(&k_src);
        let p_arch = cache.entry_path(&k_arch);
        let p_opts = cache.entry_path(&k_opts);
        let p_zero = cache.entry_path(&k_zero);

        assert_ne!(p_src, p_arch);
        assert_ne!(p_src, p_opts);
        assert_ne!(p_src, p_zero);
        assert_ne!(p_arch, p_opts);
        assert_ne!(p_arch, p_zero);
        assert_ne!(p_opts, p_zero);

        cache.insert(k_src, sample_kernel(1)).unwrap();
        cache.insert(k_arch, sample_kernel(2)).unwrap();
        cache.insert(k_opts, sample_kernel(3)).unwrap();
        cache.insert(k_zero, sample_kernel(4)).unwrap();

        assert!(p_src.exists());
        assert!(p_arch.exists());
        assert!(p_opts.exists());
        assert!(p_zero.exists());

        assert_eq!(cache.get(k_src).unwrap().ptx, vec![1u8; 64]);
        assert_eq!(cache.get(k_arch).unwrap().ptx, vec![2u8; 64]);
        assert_eq!(cache.get(k_opts).unwrap().ptx, vec![3u8; 64]);
        assert_eq!(cache.get(k_zero).unwrap().ptx, vec![4u8; 64]);
    }

    #[test]
    fn version_mismatch_rejected_on_load() {
        let tmp = tempdir().unwrap();
        let cache = NvrtcCache::with_dir(tmp.path().to_path_buf()).unwrap();
        let k = key(11, 80, 22);

        // Hand-craft an entry claiming to come from a crate version that cannot match.
        let stale = CachedKernel {
            ptx: vec![0xaa; 16],
            cubin: None,
            atomr_accel_version: "0.0.0-impossible".to_string(),
        };
        let bytes = bincode::serialize(&stale).unwrap();
        let path = cache.entry_path(&k);
        fs::write(&path, &bytes).unwrap();

        // Ensure the lookup has to go to disk.
        cache.clear_in_memory();
        assert!(
            cache.get(k).is_none(),
            "entry with mismatched atomr_accel_version must not be returned"
        );
    }

    #[test]
    fn insert_leaves_only_final_entry_file() {
        let tmp = tempdir().unwrap();
        let cache = NvrtcCache::with_dir(tmp.path().to_path_buf()).unwrap();
        let k = key(3, 86, 5);

        cache.insert(k, sample_kernel(5)).unwrap();
        // Overwriting an existing entry must also succeed and leave no temp files behind.
        cache.insert(k, sample_kernel(6)).unwrap();

        let names: Vec<_> = fs::read_dir(tmp.path())
            .unwrap()
            .map(|entry| entry.unwrap().file_name())
            .collect();
        let expected = cache.entry_path(&k).file_name().unwrap().to_os_string();
        assert_eq!(names, vec![expected]);

        cache.clear_in_memory();
        assert_eq!(cache.get(k).unwrap().ptx, vec![6u8; 64]);
    }

    #[test]
    fn hash_source_is_deterministic() {
        let a = hash_source("__global__ void k() {}");
        let b = hash_source("__global__ void k() {}");
        let c = hash_source("__global__ void other() {}");
        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    #[test]
    fn hash_source_matches_fnv1a_reference_vectors() {
        // Pinned because these values become on-disk file names: changing the algorithm
        // would silently orphan every existing cache entry.
        assert_eq!(hash_source(""), 0xcbf2_9ce4_8422_2325);
        assert_eq!(hash_source("a"), 0xaf63_dc4c_8601_ec8c);
    }

    #[test]
    fn hash_options_is_deterministic_and_order_sensitive() {
        let a = hash_options(["-std=c++17", "--use_fast_math"]);
        let b = hash_options(["-std=c++17", "--use_fast_math"]);
        let c = hash_options(["--use_fast_math", "-std=c++17"]);
        let d = hash_options(["-std=c++17"]);
        let e = hash_options(Vec::<&str>::new());

        assert_eq!(a, b, "same input must produce same hash");
        assert_ne!(a, c, "option order must change the hash");
        assert_ne!(a, d, "option count must change the hash");
        assert_ne!(a, e);
        assert_ne!(d, e);

        // Moving a character across an option boundary must change the hash.
        let split1 = hash_options(["ab", "c"]);
        let split2 = hash_options(["a", "bc"]);
        assert_ne!(split1, split2);
    }

    #[test]
    fn clear_in_memory_keeps_disk() {
        let tmp = tempdir().unwrap();
        let cache = NvrtcCache::with_dir(tmp.path().to_path_buf()).unwrap();
        let k = key(7, 80, 7);
        cache.insert(k, sample_kernel(9)).unwrap();
        cache.clear_in_memory();
        let got = cache.get(k).expect("disk entry survives clear_in_memory");
        assert_eq!(got.ptx, vec![9u8; 64]);
    }
}
460}