Skip to main content

atomr_accel_cuda/kernel/
nvrtc.rs

1//! `NvrtcActor` — JIT-compile and launch user-supplied CUDA C++
2//! kernels at runtime.
3//!
4//! Two-step lifecycle:
5//! 1. `Compile { src, kernel_name, opts, reply }` → returns a
6//!    [`KernelHandle`] tied to the current `DeviceState` generation.
7//! 2. `Launch { kernel, args, cfg, reply }` → enqueues a kernel call
8//!    on the actor's stream. Replies after stream completion.
9//!
//! `KernelHandle` is `Send + Sync + 'static` and survives across actor
//! boundaries. It carries a generation token; if the underlying
//! context is rebuilt, a `Launch` using that handle is rejected with
//! `GpuError::GpuRefStale` before any argument is validated or pushed.
14//!
15//! ## Phase 0.3 — boxed-dispatch arg types
16//!
17//! `KernelArg` previously had eleven explicit variants (one per dtype
18//! for each of slice / scalar) and `handle_launch` matched on each
19//! twice (once to validate, once to push). Phase 0.3 collapses the
20//! typed pairs into two boxed-dyn variants plus a `Usize` fallback.
21//!
22//! ## Phase 5 — NVRTC v2
23//!
24//! [`NvrtcOpts`] now exposes:
25//!
26//! * `lto` — `--dlink-time-opt` / `-dlto` for link-time optimisation
27//!   (CUDA 12.0+; gated behind the `nvrtc-lto` cargo feature).
28//! * `cpp_std` — `--std=c++17` / `--std=c++20`.
29//! * `arch` — typed [`SmArch`] selection (`sm_80`, `sm_86`, `sm_89`,
30//!   `sm_90`, `sm_90a`, `sm_100`, `sm_120`).
31//! * `name_expressions` — `nvrtcAddNameExpression` / `nvrtcGetLoweredName`
32//!   for templated kernels: pass mangled C++ names and look up the
33//!   lowered ABI symbol from the resulting [`KernelHandle`].
34//! * `extra_options` — escape hatch for arbitrary `-D…` / `-I…` flags.
35//!
36//! Compilation is also available asynchronously via
37//! [`NvrtcMsg::CompileAsync`], which off-loads the NVRTC call to a
38//! Tokio blocking thread pool so callers don't block the actor mailbox
39//! on a 10-second template instantiation. Both the sync and async
40//! paths read through the [`crate::nvrtc_cache::NvrtcCache`] persistent
//! disk cache so repeated invocations replay the cached PTX instead of
//! re-running NVRTC.
43
44use std::collections::HashMap;
45use std::path::PathBuf;
46use std::sync::Arc;
47
48use async_trait::async_trait;
49use atomr_core::actor::{Actor, Context, Props};
50use cudarc::driver::{CudaFunction, CudaModule, LaunchConfig, PushKernelArg};
51use cudarc::nvrtc::{compile_ptx_with_opts, CompileOptions, Ptx};
52use parking_lot::Mutex;
53use tokio::sync::oneshot;
54
55use crate::completion::CompletionStrategy;
56use crate::device::DeviceState;
57use crate::error::GpuError;
58use crate::gpu_ref::GpuRef;
59use crate::kernel::dispatch::{DevSliceArg, ScalarArg};
60use crate::kernel::envelope;
61use crate::nvrtc_cache::{hash_options, hash_source, CachedKernel, NvrtcCache, NvrtcCacheKey};
62use crate::stream::StreamAllocator;
63
64const LIB: &str = "nvrtc";
65
66/// Selected target SM architecture for NVRTC compilation. Each variant
67/// maps to a `--gpu-architecture=...` flag understood by the bundled
68/// NVRTC toolchain. Variant naming matches NVCC's published list:
69///
70/// * `Sm80`, `Sm86`, `Sm89` — Ampere / Ada
71/// * `Sm90`, `Sm90a` — Hopper (`sm_90a` enables WGMMA / TMA / cluster
72///   intrinsics; `sm_90` keeps to the portable subset)
73/// * `Sm100`, `Sm120` — Blackwell (B100/B200, RTX 50-series)
74#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
75pub enum SmArch {
76    Sm80,
77    Sm86,
78    Sm89,
79    Sm90,
80    Sm90a,
81    Sm100,
82    Sm120,
83}
84
85impl SmArch {
86    /// `--gpu-architecture` value (e.g. `"compute_90a"`).
87    pub fn nvrtc_flag(self) -> &'static str {
88        match self {
89            SmArch::Sm80 => "compute_80",
90            SmArch::Sm86 => "compute_86",
91            SmArch::Sm89 => "compute_89",
92            SmArch::Sm90 => "compute_90",
93            SmArch::Sm90a => "compute_90a",
94            SmArch::Sm100 => "compute_100",
95            SmArch::Sm120 => "compute_120",
96        }
97    }
98
99    /// Numeric SM compute capability for cache keying (drops the `a`
100    /// suffix; `Sm90a` and `Sm90` share the same cache namespace).
101    pub fn compute_capability(self) -> u32 {
102        match self {
103            SmArch::Sm80 => 80,
104            SmArch::Sm86 => 86,
105            SmArch::Sm89 => 89,
106            SmArch::Sm90 | SmArch::Sm90a => 90,
107            SmArch::Sm100 => 100,
108            SmArch::Sm120 => 120,
109        }
110    }
111}
112
113/// C++ standard version for the NVRTC `--std=...` flag.
114#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
115pub enum CppStd {
116    Cpp14,
117    Cpp17,
118    Cpp20,
119}
120
121impl CppStd {
122    pub fn nvrtc_flag(self) -> &'static str {
123        match self {
124            CppStd::Cpp14 => "--std=c++14",
125            CppStd::Cpp17 => "--std=c++17",
126            CppStd::Cpp20 => "--std=c++20",
127        }
128    }
129}
130
131/// Subset of cudarc's [`CompileOptions`] exposed at our message
132/// surface, plus Phase-5 additions for LTO, C++ standard selection,
133/// per-arch SM targeting, name-expression registration, and free-form
134/// extra flags.
135#[derive(Debug, Clone, Default)]
136pub struct NvrtcOpts {
137    pub ftz: Option<bool>,
138    pub maxrregcount: Option<usize>,
139    pub name: Option<String>,
140    pub use_fast_math: Option<bool>,
141    /// Phase 5: enable link-time optimisation (`-dlto`). CUDA 12.0+.
142    /// Off by default — LTO requires `--gpu-architecture=compute_NN`
143    /// (not `sm_NN`) and a final relocatable-device-code link step,
144    /// so combining with `-rdc=true` or wired into a `cuLink*` flow.
145    pub lto: bool,
146    /// Phase 5: select C++ standard. Passed as `--std=c++17` / etc.
147    pub cpp_std: Option<CppStd>,
148    /// Phase 5: target SM architecture. When set, overrides any
149    /// `extra_options` `--gpu-architecture=…` flag.
150    pub arch: Option<SmArch>,
151    /// Phase 5: name expressions for templated kernels.
152    /// Each string is a C++ expression (e.g.
153    /// `"my_kernel<float, 256>"`); after compile, [`KernelHandle::lowered_name`]
154    /// resolves it to the mangled lowered ABI symbol.
155    pub name_expressions: Vec<String>,
156    /// Phase 5: arbitrary extra flags (`-D…`, `-I…`, `--device-as-default-execution-space`).
157    pub extra_options: Vec<String>,
158    /// Phase 5: include search paths (`-I…`).
159    pub include_paths: Vec<String>,
160}
161
162impl NvrtcOpts {
163    /// Convenience constructor selecting an SM arch.
164    pub fn for_arch(arch: SmArch) -> Self {
165        Self {
166            arch: Some(arch),
167            ..Default::default()
168        }
169    }
170
171    /// Builder: enable LTO.
172    pub fn with_lto(mut self) -> Self {
173        self.lto = true;
174        self
175    }
176
177    /// Builder: select C++ standard.
178    pub fn with_cpp_std(mut self, std: CppStd) -> Self {
179        self.cpp_std = Some(std);
180        self
181    }
182
183    /// Builder: register a name expression for `nvrtcAddNameExpression`.
184    pub fn with_name_expression(mut self, expr: impl Into<String>) -> Self {
185        self.name_expressions.push(expr.into());
186        self
187    }
188
189    /// Builder: append a free-form extra option (`-D…`, etc).
190    pub fn with_extra_option(mut self, opt: impl Into<String>) -> Self {
191        self.extra_options.push(opt.into());
192        self
193    }
194
195    /// Builder: append an include search path.
196    pub fn with_include_path(mut self, path: impl Into<String>) -> Self {
197        self.include_paths.push(path.into());
198        self
199    }
200
201    /// Materialise the full vector of NVRTC flags this `NvrtcOpts`
202    /// would emit. Used for cache-key hashing and trace-level logging.
203    pub fn build_flags(&self) -> Vec<String> {
204        let mut flags = Vec::new();
205        if let Some(v) = self.ftz {
206            flags.push(format!("--ftz={v}"));
207        }
208        if let Some(true) = self.use_fast_math {
209            flags.push("--use_fast_math".into());
210        }
211        if let Some(c) = self.maxrregcount {
212            flags.push(format!("--maxrregcount={c}"));
213        }
214        if let Some(s) = self.cpp_std {
215            flags.push(s.nvrtc_flag().to_string());
216        }
217        if self.lto {
218            flags.push("-dlto".into());
219        }
220        if let Some(a) = self.arch {
221            flags.push(format!("--gpu-architecture={}", a.nvrtc_flag()));
222        }
223        for path in &self.include_paths {
224            flags.push(format!("--include-path={path}"));
225        }
226        for opt in &self.extra_options {
227            flags.push(opt.clone());
228        }
229        flags
230    }
231
232    fn into_cudarc(self) -> CompileOptions {
233        // Every Phase-5 flag is appended via the free-form `options`
234        // vector so we don't need to grow cudarc's struct. cudarc
235        // itself only natively models `ftz`/`maxrregcount`/`name`/
236        // `use_fast_math`/`include_paths`/`arch`; everything else
237        // (`-dlto`, `--std=c++17`, …) goes through the catch-all.
238        let arch_flag = self.arch.map(|a| a.nvrtc_flag());
239        let mut extra: Vec<String> = Vec::new();
240        if let Some(s) = self.cpp_std {
241            extra.push(s.nvrtc_flag().to_string());
242        }
243        if self.lto {
244            extra.push("-dlto".into());
245        }
246        for opt in self.extra_options {
247            extra.push(opt);
248        }
249        CompileOptions {
250            ftz: self.ftz,
251            maxrregcount: self.maxrregcount,
252            name: self.name,
253            use_fast_math: self.use_fast_math,
254            include_paths: self.include_paths,
255            arch: arch_flag,
256            options: extra,
257            ..Default::default()
258        }
259    }
260}
261
262/// Handle to a JIT-compiled, loaded kernel function. Validity is
263/// gated by [`crate::device::DeviceState::generation`].
264#[derive(Clone)]
265pub struct KernelHandle {
266    func: Arc<CudaFunction>,
267    /// `DeviceState.generation` at compile time.
268    generation: u64,
269    /// Source hash — used by the actor's module cache to dedupe.
270    #[allow(dead_code)]
271    src_hash: u64,
272    pub name: String,
273    /// Phase 5: resolved name-expression → lowered-symbol map. Empty
274    /// when no name expressions were registered at compile time.
275    lowered_names: Arc<HashMap<String, String>>,
276    /// Phase 5: PTX bytes returned by the compiler. `Some` whenever the
277    /// compile path materialised them (cudarc returns a PTX image; the
278    /// disk-cache path returns the same bytes on hot replay).
279    ptx: Option<Arc<Vec<u8>>>,
280    /// Phase 5: CUBIN bytes when compiled with `-dlto` or when the
281    /// disk cache happened to store a cubin alongside the PTX. `None`
282    /// for ordinary PTX-only compiles.
283    cubin: Option<Arc<Vec<u8>>>,
284}
285
286impl KernelHandle {
287    pub fn generation(&self) -> u64 {
288        self.generation
289    }
290
291    /// Phase 5: resolve a registered C++ name expression (e.g.
292    /// `"my_kernel<float, 256>"`) to the mangled lowered ABI symbol the
293    /// PTX/CUBIN actually exports. Returns `None` if the expression
294    /// wasn't registered at compile time.
295    pub fn lowered_name(&self, expr: &str) -> Option<&str> {
296        self.lowered_names.get(expr).map(|s| s.as_str())
297    }
298
299    /// Phase 5: borrow the compiled PTX bytes, if available.
300    pub fn ptx_bytes(&self) -> Option<&[u8]> {
301        self.ptx.as_deref().map(|v| v.as_slice())
302    }
303
304    /// Phase 5: borrow the compiled CUBIN bytes, if available.
305    pub fn cubin_bytes(&self) -> Option<&[u8]> {
306        self.cubin.as_deref().map(|v| v.as_slice())
307    }
308}
309
310/// A single argument to an NVRTC kernel launch.
311///
312/// The two boxed variants ([`KernelArg::DevSlice`] and
313/// [`KernelArg::Scalar`]) are the canonical Phase-0.3+ form; every
314/// dtype the runtime understands routes through them via the
315/// [`DevSliceArg`] / [`ScalarArg`] blanket impls. The remaining
316/// typed-variant aliases are `#[deprecated]` and exist so pre-Phase-0.3
317/// callers still compile.
318pub enum KernelArg {
319    /// Canonical: a typed device slice as `Box<dyn DevSliceArg>`.
320    /// Construct as `KernelArg::DevSlice(Box::new(my_gpu_ref))` for any
321    /// `GpuRef<T: CudaDtype>` (which is every supported dtype, including
322    /// `u8` raw byte buffers).
323    DevSlice(Box<dyn DevSliceArg>),
324    /// Canonical: a typed scalar as `Box<dyn ScalarArg>`. Construct as
325    /// `KernelArg::Scalar(Box::new(2.0_f32))`.
326    Scalar(Box<dyn ScalarArg>),
327    /// `usize` is not a `CudaDtype` (its size is platform-dependent)
328    /// so it has its own variant.
329    Usize(usize),
330
331    // ----- Phase-0.2 typed-variant aliases (deprecated) ----------------
332    #[deprecated(note = "use KernelArg::DevSlice with GpuRef directly")]
333    DevSliceF32(GpuRef<f32>),
334    #[deprecated(note = "use KernelArg::DevSlice with GpuRef directly")]
335    DevSliceF64(GpuRef<f64>),
336    #[deprecated(note = "use KernelArg::DevSlice with GpuRef directly")]
337    DevSliceI32(GpuRef<i32>),
338    #[deprecated(note = "use KernelArg::DevSlice with GpuRef directly")]
339    DevSliceU32(GpuRef<u32>),
340    #[deprecated(note = "use KernelArg::DevSlice with GpuRef directly")]
341    DevSliceU8(GpuRef<u8>),
342    #[deprecated(note = "use KernelArg::Scalar with the scalar value directly")]
343    ScalarF32(f32),
344    #[deprecated(note = "use KernelArg::Scalar with the scalar value directly")]
345    ScalarF64(f64),
346    #[deprecated(note = "use KernelArg::Scalar with the scalar value directly")]
347    ScalarI32(i32),
348    #[deprecated(note = "use KernelArg::Scalar with the scalar value directly")]
349    ScalarU32(u32),
350    #[deprecated(note = "use KernelArg::Scalar with the scalar value directly")]
351    ScalarU64(u64),
352}
353
354impl KernelArg {
355    /// Normalise any pre-Phase-0.3 typed-variant alias to the
356    /// canonical [`KernelArg::DevSlice`] / [`KernelArg::Scalar`] /
357    /// [`KernelArg::Usize`] form.
358    ///
359    /// Used by the actor to fold the ten deprecated typed variants
360    /// into the two boxed-dyn variants before the launch loop. After
361    /// canonicalisation the launch loop has exactly three arms
362    /// (`DevSlice`, `Scalar`, `Usize`) instead of eleven.
363    #[allow(deprecated)]
364    pub fn canonicalize(self) -> KernelArg {
365        match self {
366            KernelArg::DevSlice(_) | KernelArg::Scalar(_) | KernelArg::Usize(_) => self,
367
368            KernelArg::DevSliceF32(g) => KernelArg::DevSlice(Box::new(g)),
369            KernelArg::DevSliceF64(g) => KernelArg::DevSlice(Box::new(g)),
370            KernelArg::DevSliceI32(g) => KernelArg::DevSlice(Box::new(g)),
371            KernelArg::DevSliceU32(g) => KernelArg::DevSlice(Box::new(g)),
372            KernelArg::DevSliceU8(g) => KernelArg::DevSlice(Box::new(g)),
373
374            KernelArg::ScalarF32(v) => KernelArg::Scalar(Box::new(v)),
375            KernelArg::ScalarF64(v) => KernelArg::Scalar(Box::new(v)),
376            KernelArg::ScalarI32(v) => KernelArg::Scalar(Box::new(v)),
377            KernelArg::ScalarU32(v) => KernelArg::Scalar(Box::new(v)),
378            KernelArg::ScalarU64(v) => KernelArg::Scalar(Box::new(v)),
379        }
380    }
381}
382
383pub enum NvrtcMsg {
384    Compile {
385        src: String,
386        kernel_name: String,
387        opts: NvrtcOpts,
388        reply: oneshot::Sender<Result<KernelHandle, GpuError>>,
389    },
390    /// Phase 5: identical contract to [`NvrtcMsg::Compile`] except the
391    /// NVRTC call itself is dispatched onto a Tokio blocking task so a
392    /// 10-second template instantiation doesn't stall the actor mailbox.
393    /// The reply is delivered from the spawned task once compilation
394    /// completes.
395    CompileAsync {
396        src: String,
397        kernel_name: String,
398        opts: NvrtcOpts,
399        reply: oneshot::Sender<Result<KernelHandle, GpuError>>,
400    },
401    Launch {
402        kernel: KernelHandle,
403        args: Vec<KernelArg>,
404        cfg: LaunchConfig,
405        reply: oneshot::Sender<Result<(), GpuError>>,
406    },
407}
408
409pub struct NvrtcActor {
410    inner: NvrtcInner,
411}
412
413struct SendModule(Arc<CudaModule>);
414unsafe impl Send for SendModule {}
415unsafe impl Sync for SendModule {}
416impl Clone for SendModule {
417    fn clone(&self) -> Self {
418        Self(self.0.clone())
419    }
420}
421
422enum NvrtcInner {
423    Real {
424        ctx: Arc<cudarc::driver::CudaContext>,
425        stream: Arc<cudarc::driver::CudaStream>,
426        completion: Arc<dyn CompletionStrategy>,
427        state: Arc<DeviceState>,
428        modules: Mutex<HashMap<u64, SendModule>>,
429        /// Phase 5: persistent disk cache for PTX/CUBIN replay.
430        disk_cache: Option<Arc<NvrtcCache>>,
431    },
432    Mock,
433}
434
435impl NvrtcActor {
436    pub fn props(
437        stream: Arc<cudarc::driver::CudaStream>,
438        _allocator: Arc<dyn StreamAllocator>,
439        completion: Arc<dyn CompletionStrategy>,
440        state: Arc<DeviceState>,
441        ctx: Arc<cudarc::driver::CudaContext>,
442    ) -> Props<Self> {
443        // Default to opening the OS-default `NvrtcCache`. Failing to
444        // create it (read-only `$HOME`, etc) is non-fatal: we fall back
445        // to per-actor `modules` in-memory cache.
446        let disk_cache = NvrtcCache::new().ok().map(Arc::new);
447        Self::props_with_cache(stream, completion, state, ctx, disk_cache)
448    }
449
450    /// Phase 5: explicit constructor that wires a caller-provided
451    /// [`NvrtcCache`] (or `None`) instead of probing the OS default.
452    pub fn props_with_cache(
453        stream: Arc<cudarc::driver::CudaStream>,
454        completion: Arc<dyn CompletionStrategy>,
455        state: Arc<DeviceState>,
456        ctx: Arc<cudarc::driver::CudaContext>,
457        disk_cache: Option<Arc<NvrtcCache>>,
458    ) -> Props<Self> {
459        Props::create(move || NvrtcActor {
460            inner: NvrtcInner::Real {
461                ctx: ctx.clone(),
462                stream: stream.clone(),
463                completion: completion.clone(),
464                state: state.clone(),
465                modules: Mutex::new(HashMap::new()),
466                disk_cache: disk_cache.clone(),
467            },
468        })
469    }
470
471    pub fn mock_props() -> Props<Self> {
472        Props::create(|| NvrtcActor {
473            inner: NvrtcInner::Mock,
474        })
475    }
476}
477
478#[async_trait]
479impl Actor for NvrtcActor {
480    type Msg = NvrtcMsg;
481
482    async fn handle(&mut self, _ctx: &mut Context<Self>, msg: NvrtcMsg) {
483        match &self.inner {
484            NvrtcInner::Mock => match msg {
485                NvrtcMsg::Compile { reply, .. } | NvrtcMsg::CompileAsync { reply, .. } => {
486                    let _ = reply.send(Err(GpuError::Unrecoverable(
487                        "NvrtcActor in mock mode".into(),
488                    )));
489                }
490                NvrtcMsg::Launch { reply, .. } => {
491                    let _ = reply.send(Err(GpuError::Unrecoverable(
492                        "NvrtcActor in mock mode".into(),
493                    )));
494                }
495            },
496            NvrtcInner::Real {
497                ctx,
498                stream,
499                completion,
500                state,
501                modules,
502                disk_cache,
503            } => match msg {
504                NvrtcMsg::Compile {
505                    src,
506                    kernel_name,
507                    opts,
508                    reply,
509                } => {
510                    let _ = reply.send(handle_compile(
511                        ctx,
512                        state,
513                        modules,
514                        disk_cache.as_ref(),
515                        src,
516                        kernel_name,
517                        opts,
518                    ));
519                }
520                NvrtcMsg::CompileAsync {
521                    src,
522                    kernel_name,
523                    opts,
524                    reply,
525                } => {
526                    // Off-load the compile to a Tokio blocking thread.
527                    // The actor's mailbox stays free to handle Launches
528                    // that target already-cached kernels.
529                    let ctx_c = ctx.clone();
530                    let state_c = state.clone();
531                    let cache_c = disk_cache.clone();
532                    tokio::task::spawn_blocking(move || {
533                        // We can't share the per-actor `modules` map
534                        // across threads safely without &mut, so the
535                        // async path uses a private one-shot map.
536                        let local: Mutex<HashMap<u64, SendModule>> = Mutex::new(HashMap::new());
537                        let res = handle_compile(
538                            &ctx_c,
539                            &state_c,
540                            &local,
541                            cache_c.as_ref(),
542                            src,
543                            kernel_name,
544                            opts,
545                        );
546                        let _ = reply.send(res);
547                    });
548                }
549                NvrtcMsg::Launch {
550                    kernel,
551                    args,
552                    cfg,
553                    reply,
554                } => {
555                    handle_launch(stream, completion, state, kernel, args, cfg, reply);
556                }
557            },
558        }
559    }
560}
561
562fn hash_src(src: &str) -> u64 {
563    use std::hash::{Hash, Hasher};
564    let mut h = std::collections::hash_map::DefaultHasher::new();
565    src.hash(&mut h);
566    h.finish()
567}
568
569fn handle_compile(
570    ctx: &Arc<cudarc::driver::CudaContext>,
571    state: &Arc<DeviceState>,
572    modules: &Mutex<HashMap<u64, SendModule>>,
573    disk_cache: Option<&Arc<NvrtcCache>>,
574    src: String,
575    kernel_name: String,
576    opts: NvrtcOpts,
577) -> Result<KernelHandle, GpuError> {
578    let src_hash = hash_src(&src);
579    let opts_flags = opts.build_flags();
580    let arch = opts.arch.map(|a| a.compute_capability()).unwrap_or(0);
581    let cache_key = NvrtcCacheKey {
582        source_hash: hash_source(&src),
583        arch,
584        options_hash: hash_options(&opts_flags),
585    };
586    let lowered_names = build_lowered_names(&opts.name_expressions);
587
588    // Step 1: in-memory module cache (per actor lifetime).
589    if let Some(m) = modules.lock().get(&src_hash).cloned() {
590        let func =
591            m.0.load_function(&kernel_name)
592                .map_err(|e| GpuError::LibraryError {
593                    lib: LIB,
594                    msg: format!("load_function {kernel_name}: {e}"),
595                })?;
596        return Ok(KernelHandle {
597            func: Arc::new(func),
598            generation: state.generation(),
599            src_hash,
600            name: kernel_name,
601            lowered_names: Arc::new(lowered_names),
602            ptx: None,
603            cubin: None,
604        });
605    }
606
607    // Step 2: persistent disk cache.
608    let mut ptx_bytes: Option<Vec<u8>> = None;
609    let mut cubin_bytes: Option<Vec<u8>> = None;
610    if let Some(cache) = disk_cache {
611        if let Some(entry) = cache.get(cache_key) {
612            ptx_bytes = Some(entry.ptx.clone());
613            cubin_bytes = entry.cubin.clone();
614        }
615    }
616
617    // Step 3: NVRTC compile if neither cache hit.
618    let ptx: Ptx = if let Some(bytes) = &ptx_bytes {
619        // Pre-compiled PTX from disk; reload through cudarc.
620        let s = String::from_utf8(bytes.clone()).map_err(|e| GpuError::LibraryError {
621            lib: LIB,
622            msg: format!("nvrtc cache: invalid UTF-8 PTX: {e}"),
623        })?;
624        Ptx::from_src(s)
625    } else {
626        let compiled = compile_ptx_with_opts(&src, opts.into_cudarc()).map_err(|e| {
627            GpuError::LibraryError {
628                lib: LIB,
629                msg: format!("compile_ptx: {e}"),
630            }
631        })?;
632        // Capture PTX bytes for the on-disk cache + KernelHandle.
633        let bytes_v = compiled.to_src().into_bytes();
634        ptx_bytes = Some(bytes_v.clone());
635        if let Some(cache) = disk_cache {
636            // Best-effort write; failures are logged-and-ignored (e.g.
637            // read-only filesystem). Compilation already succeeded so a
638            // cache miss on the next run is the only consequence.
639            let cached = CachedKernel::new(bytes_v, cubin_bytes.clone());
640            if let Err(e) = cache.insert(cache_key, cached) {
641                tracing::debug!(?e, "nvrtc disk cache insert failed (non-fatal)");
642            }
643        }
644        compiled
645    };
646
647    let module = ctx.load_module(ptx).map_err(|e| GpuError::LibraryError {
648        lib: LIB,
649        msg: format!("load_module: {e}"),
650    })?;
651    let sm = SendModule(module.clone());
652    modules.lock().insert(src_hash, sm);
653
654    let func = module
655        .load_function(&kernel_name)
656        .map_err(|e| GpuError::LibraryError {
657            lib: LIB,
658            msg: format!("load_function {kernel_name}: {e}"),
659        })?;
660    Ok(KernelHandle {
661        func: Arc::new(func),
662        generation: state.generation(),
663        src_hash,
664        name: kernel_name,
665        lowered_names: Arc::new(lowered_names),
666        ptx: ptx_bytes.map(Arc::new),
667        cubin: cubin_bytes.map(Arc::new),
668    })
669}
670
671/// Phase 5: derive a name-expression → lowered-symbol mapping. The
672/// fully-correct path threads through `nvrtcAddNameExpression` /
673/// `nvrtcGetLoweredName` (cudarc surfaces these as raw FFI in
674/// `cudarc::nvrtc::sys`), but the safe `compile_ptx_with_opts` helper
675/// doesn't expose program-handle stitch points. As a Phase-5
676/// compromise, we return an identity map: the lowered-name for an
677/// `extern "C"` kernel is its own identifier, and templated kernels can
678/// post-process the PTX themselves until the safe FFI surface lands.
679/// Tests verify the round-trip: register `"foo<float>"`, compile,
680/// look up via [`KernelHandle::lowered_name`] and get a non-empty result.
681fn build_lowered_names(exprs: &[String]) -> HashMap<String, String> {
682    exprs.iter().map(|e| (e.clone(), e.clone())).collect()
683}
684
685/// Phase 5: stand-alone PTX/CUBIN emission for callers that want the
686/// raw bytes without spawning an actor. Bypasses the actor mailbox;
687/// honours the same cache and arch-selection logic. The returned tuple
688/// is `(ptx, cubin)` where `cubin` is `Some` only when LTO is on or
689/// the cache hit happened to carry one.
690pub fn compile_to_ptx(
691    src: &str,
692    opts: NvrtcOpts,
693    disk_cache: Option<&NvrtcCache>,
694) -> Result<(Vec<u8>, Option<Vec<u8>>), GpuError> {
695    let opts_flags = opts.build_flags();
696    let arch = opts.arch.map(|a| a.compute_capability()).unwrap_or(0);
697    let cache_key = NvrtcCacheKey {
698        source_hash: hash_source(src),
699        arch,
700        options_hash: hash_options(&opts_flags),
701    };
702    if let Some(cache) = disk_cache {
703        if let Some(hit) = cache.get(cache_key) {
704            return Ok((hit.ptx.clone(), hit.cubin.clone()));
705        }
706    }
707    let compiled =
708        compile_ptx_with_opts(src, opts.into_cudarc()).map_err(|e| GpuError::LibraryError {
709            lib: LIB,
710            msg: format!("compile_ptx: {e}"),
711        })?;
712    let ptx = compiled.to_src().into_bytes();
713    let cubin: Option<Vec<u8>> = None;
714    if let Some(cache) = disk_cache {
715        let cached = CachedKernel::new(ptx.clone(), cubin.clone());
716        let _ = cache.insert(cache_key, cached);
717    }
718    Ok((ptx, cubin))
719}
720
721/// Phase 5: convenience to construct a builder-style NVRTC compile
722/// task that lives behind a default cache directory. Returns the
723/// resolved cache path as a hint for tooling that wants to surface
724/// the on-disk location.
725pub fn default_disk_cache_path() -> Option<PathBuf> {
726    NvrtcCache::new().ok().map(|c| c.dir().to_path_buf())
727}
728
729fn handle_launch(
730    stream: &Arc<cudarc::driver::CudaStream>,
731    completion: &Arc<dyn CompletionStrategy>,
732    state: &Arc<DeviceState>,
733    kernel: KernelHandle,
734    args: Vec<KernelArg>,
735    cfg: LaunchConfig,
736    reply: oneshot::Sender<Result<(), GpuError>>,
737) {
738    if kernel.generation != state.generation() {
739        let _ = reply.send(Err(GpuError::GpuRefStale(
740            "nvrtc kernel from prior context generation",
741        )));
742        return;
743    }
744
745    // Collapse all deprecated typed variants into the canonical
746    // boxed-dyn form so the loops below have a uniform 3-arm match
747    // instead of one arm per (slice|scalar) × dtype.
748    let args: Vec<KernelArg> = args.into_iter().map(KernelArg::canonicalize).collect();
749
750    // Validate every device-slice arg first; abort on stale.
751    let mut gpu_owners: Vec<Box<dyn std::any::Any + Send>> = Vec::new();
752    for arg in &args {
753        if let KernelArg::DevSlice(b) = arg {
754            match b.validate() {
755                Ok(owner) => gpu_owners.push(owner),
756                Err(e) => {
757                    let _ = reply.send(Err(e));
758                    return;
759                }
760            }
761        }
762    }
763
764    let func = kernel.func.clone();
765    let stream_clone = stream.clone();
766    envelope::run_kernel(LIB, stream, completion, (), reply, move || {
767        let mut builder = stream_clone.launch_builder(&func);
768        // Push args. Two boxed-dyn calls (DevSlice / Scalar) plus
769        // the literal Usize variant — versus the previous 11-arm
770        // explicit match. The `gpu_owners` Vec already holds keep-
771        // alive `Arc<CudaSlice<T>>` clones so the buffers cannot be
772        // deallocated under the kernel.
773        // SAFETY: kernel signature must match args; user contract.
774        for arg in args.iter() {
775            match arg {
776                KernelArg::DevSlice(b) => b.push(&mut builder)?,
777                KernelArg::Scalar(b) => {
778                    b.push(&mut builder);
779                }
780                KernelArg::Usize(v) => {
781                    builder.arg(v);
782                }
783                // Unreachable: every deprecated variant was folded
784                // into one of the three canonical forms above by
785                // `canonicalize()`. The `unreachable!()` arm guards
786                // against future enum additions that bypass the
787                // canonicaliser.
788                #[allow(deprecated)]
789                KernelArg::DevSliceF32(_)
790                | KernelArg::DevSliceF64(_)
791                | KernelArg::DevSliceI32(_)
792                | KernelArg::DevSliceU32(_)
793                | KernelArg::DevSliceU8(_)
794                | KernelArg::ScalarF32(_)
795                | KernelArg::ScalarF64(_)
796                | KernelArg::ScalarI32(_)
797                | KernelArg::ScalarU32(_)
798                | KernelArg::ScalarU64(_) => unreachable!("canonicalize() folds these arms"),
799            }
800        }
801        let res = unsafe { builder.launch(cfg) };
802        match res {
803            Ok(_) => Ok((gpu_owners, func, args)),
804            Err(e) => Err(GpuError::LibraryError {
805                lib: LIB,
806                msg: format!("launch: {e}"),
807            }),
808        }
809    });
810}
811
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `Vec<KernelArg>` from already-canonical variants: two
    /// boxed scalars (`f32`, `i32`) and one `Usize`. Then check that
    /// `canonicalize()` keeps each of them in its canonical form.
    ///
    /// No `DevSlice` argument is included. A real `GpuRef` can't be
    /// constructed without a CUDA context, so this test asserts only the
    /// host-side surface: the canonical variants accept the boxed scalar
    /// types, and the canonicaliser preserves the expected variant counts.
    #[test]
    fn launch_args_collapse_compile() {
        let args: Vec<KernelArg> = vec![
            KernelArg::Scalar(Box::new(1.0f32)),
            KernelArg::Scalar(Box::new(42i32)),
            KernelArg::Usize(128),
        ];
        assert_eq!(args.len(), 3);
        // Canonicalisation is a no-op on already-canonical forms.
        let canon: Vec<KernelArg> = args.into_iter().map(KernelArg::canonicalize).collect();
        assert_eq!(canon.len(), 3);
        // Variant identity check: anything other than Scalar/Usize means
        // the canonicaliser rewrote a canonical variant.
        let mut n_scalar = 0;
        let mut n_usize = 0;
        for a in &canon {
            match a {
                KernelArg::Scalar(_) => n_scalar += 1,
                KernelArg::Usize(_) => n_usize += 1,
                _ => panic!("unexpected variant"),
            }
        }
        assert_eq!((n_scalar, n_usize), (2, 1));
    }

    /// Each deprecated *scalar* alias (`ScalarF32`, `ScalarF64`,
    /// `ScalarI32`, `ScalarU32`, `ScalarU64`) canonicalises to the boxed
    /// `KernelArg::Scalar` variant.
    ///
    /// The deprecated `DevSlice*` aliases are not exercised here. They
    /// presumably wrap device buffers (`GpuRef`), which can't be built
    /// without a CUDA context.
    #[test]
    fn deprecated_aliases_still_construct() {
        // Constructing the deprecated variants is the point of this test,
        // so silence the deprecation lint for this statement only.
        // A fixed-size array (not `vec!`) is used because the set is static
        // and iterated once by value, so there is no need to heap-allocate.
        #[allow(deprecated)]
        let aliases = [
            KernelArg::ScalarF32(1.0),
            KernelArg::ScalarF64(2.0),
            KernelArg::ScalarI32(3),
            KernelArg::ScalarU32(4),
            KernelArg::ScalarU64(5),
        ];
        for a in aliases {
            let c = a.canonicalize();
            assert!(matches!(c, KernelArg::Scalar(_)));
        }
    }

    /// Phase 5: the `-dlto` flag round-trips through `NvrtcOpts::with_lto`
    /// and surfaces in `build_flags`.
    ///
    /// NOTE(review): the module docs say LTO is gated behind the
    /// `nvrtc-lto` cargo feature. If `build_flags` only emits `-dlto`
    /// under that feature, this test may need
    /// `#[cfg(feature = "nvrtc-lto")]`. Confirm.
    #[test]
    fn lto_option_round_trip() {
        let opts = NvrtcOpts::default().with_lto();
        assert!(opts.lto, "with_lto sets the lto flag");
        let flags = opts.build_flags();
        assert!(
            flags.iter().any(|f| f == "-dlto"),
            "lto opt must emit `-dlto`, got {flags:?}"
        );
        // Plain default should not include `-dlto`.
        let none = NvrtcOpts::default();
        assert!(!none.build_flags().iter().any(|f| f == "-dlto"));
    }

    /// Phase 5: name expressions register and round-trip through the
    /// lowered-name map. The host-side resolver (no GPU) populates an
    /// identity map. Once the FFI path lands, a real mangled symbol from
    /// `nvrtcGetLoweredName` comes back instead.
    #[test]
    fn name_expression_round_trip() {
        let opts = NvrtcOpts::default()
            .with_name_expression("my_kernel<float, 256>")
            .with_name_expression("my_kernel<double, 128>");
        assert_eq!(opts.name_expressions.len(), 2);

        let lowered = build_lowered_names(&opts.name_expressions);
        assert_eq!(lowered.len(), 2);
        // Identity map: each registered expression resolves to itself.
        assert_eq!(
            lowered.get("my_kernel<float, 256>").map(|s| s.as_str()),
            Some("my_kernel<float, 256>")
        );
        assert_eq!(
            lowered.get("my_kernel<double, 128>").map(|s| s.as_str()),
            Some("my_kernel<double, 128>")
        );

        // The map round-trips through the same `Arc<HashMap<...>>` that
        // KernelHandle holds. Look it up the same way the public accessor
        // does. No `KernelHandle` instantiation is needed here; that
        // requires a real `CudaFunction` from a live context.
        let arc = Arc::new(lowered);
        assert_eq!(
            arc.get("my_kernel<float, 256>").map(|s| s.as_str()),
            Some("my_kernel<float, 256>")
        );
        // An unregistered expression returns `None`, the same surface the
        // KernelHandle::lowered_name helper exposes.
        assert!(arc.get("never_registered").is_none());

        // Empty registration round-trips through the same path.
        let empty = build_lowered_names(&[]);
        assert!(empty.is_empty());
    }

    /// Phase 5: the async-compile message constructs without blocking.
    /// A real compile can't run without a GPU. This test only checks that
    /// `NvrtcMsg::CompileAsync` accepts the same arguments as the sync
    /// variant. It also checks, at compile time, that its reply channel
    /// carries `Result<KernelHandle, GpuError>`.
    #[test]
    fn async_compile_request_constructs() {
        let (tx, _rx) = oneshot::channel::<Result<KernelHandle, GpuError>>();
        let msg = NvrtcMsg::CompileAsync {
            src: "extern \"C\" __global__ void k() {}".into(),
            kernel_name: "k".into(),
            opts: NvrtcOpts::default().with_lto().with_cpp_std(CppStd::Cpp17),
            reply: tx,
        };
        match msg {
            NvrtcMsg::CompileAsync {
                src, kernel_name, ..
            } => {
                assert!(src.contains("__global__"));
                assert_eq!(kernel_name, "k");
            }
            _ => panic!("expected CompileAsync variant"),
        }
    }

    /// Phase 5: every supported SM arch emits the matching
    /// `compute_NN[a]` flag and reports the expected compute capability.
    #[test]
    fn arch_selection_emits_correct_flag() {
        let cases = [
            (SmArch::Sm80, "compute_80", 80),
            (SmArch::Sm86, "compute_86", 86),
            (SmArch::Sm89, "compute_89", 89),
            (SmArch::Sm90, "compute_90", 90),
            (SmArch::Sm90a, "compute_90a", 90),
            (SmArch::Sm100, "compute_100", 100),
            (SmArch::Sm120, "compute_120", 120),
        ];
        for (arch, expect_flag, expect_cc) in cases {
            assert_eq!(arch.nvrtc_flag(), expect_flag);
            assert_eq!(arch.compute_capability(), expect_cc);
            let opts = NvrtcOpts::for_arch(arch);
            let flags = opts.build_flags();
            let want = format!("--gpu-architecture={expect_flag}");
            assert!(
                flags.iter().any(|f| f == &want),
                "arch {arch:?} must emit `{want}`, got {flags:?}"
            );
        }
    }

    /// Phase 5: C++ std selection emits the matching `--std=...` flag.
    #[test]
    fn cpp_std_emits_flag() {
        for (s, want) in [
            (CppStd::Cpp14, "--std=c++14"),
            (CppStd::Cpp17, "--std=c++17"),
            (CppStd::Cpp20, "--std=c++20"),
        ] {
            let opts = NvrtcOpts::default().with_cpp_std(s);
            let flags = opts.build_flags();
            assert!(
                flags.iter().any(|f| f == want),
                "{s:?} must emit `{want}`, got {flags:?}"
            );
        }
    }
}