1use std::collections::HashMap;
45use std::path::PathBuf;
46use std::sync::Arc;
47
48use async_trait::async_trait;
49use atomr_core::actor::{Actor, Context, Props};
50use cudarc::driver::{CudaFunction, CudaModule, LaunchConfig, PushKernelArg};
51use cudarc::nvrtc::{compile_ptx_with_opts, CompileOptions, Ptx};
52use parking_lot::Mutex;
53use tokio::sync::oneshot;
54
55use crate::completion::CompletionStrategy;
56use crate::device::DeviceState;
57use crate::error::GpuError;
58use crate::gpu_ref::GpuRef;
59use crate::kernel::dispatch::{DevSliceArg, ScalarArg};
60use crate::kernel::envelope;
61use crate::nvrtc_cache::{hash_options, hash_source, CachedKernel, NvrtcCache, NvrtcCacheKey};
62use crate::stream::StreamAllocator;
63
/// Library tag attached to every `GpuError::LibraryError` raised in this
/// module and passed to `envelope::run_kernel` for launches.
const LIB: &str = "nvrtc";
65
/// GPU architectures this module can target through NVRTC.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SmArch {
    Sm80,
    Sm86,
    Sm89,
    Sm90,
    /// Architecture-specific flavour of 9.0 (`compute_90a`). It reports the
    /// same compute capability (90) as [`SmArch::Sm90`].
    Sm90a,
    Sm100,
    Sm120,
}

impl SmArch {
    /// Single lookup table: `(NVRTC virtual-architecture name, compute capability)`.
    fn spec(self) -> (&'static str, u32) {
        match self {
            Self::Sm80 => ("compute_80", 80),
            Self::Sm86 => ("compute_86", 86),
            Self::Sm89 => ("compute_89", 89),
            Self::Sm90 => ("compute_90", 90),
            Self::Sm90a => ("compute_90a", 90),
            Self::Sm100 => ("compute_100", 100),
            Self::Sm120 => ("compute_120", 120),
        }
    }

    /// Virtual-architecture name used as `--gpu-architecture=<name>`.
    pub fn nvrtc_flag(self) -> &'static str {
        let (flag, _) = self.spec();
        flag
    }

    /// Numeric compute capability (e.g. `86` for `Sm86`). This is the `arch`
    /// component of the NVRTC disk-cache key.
    pub fn compute_capability(self) -> u32 {
        self.spec().1
    }
}
112
/// C++ language dialect requested from NVRTC.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CppStd {
    Cpp14,
    Cpp17,
    Cpp20,
}

impl CppStd {
    /// The complete `--std=...` flag selecting this dialect.
    pub fn nvrtc_flag(self) -> &'static str {
        let flag = match self {
            Self::Cpp14 => "--std=c++14",
            Self::Cpp17 => "--std=c++17",
            Self::Cpp20 => "--std=c++20",
        };
        flag
    }
}
130
131#[derive(Debug, Clone, Default)]
136pub struct NvrtcOpts {
137 pub ftz: Option<bool>,
138 pub maxrregcount: Option<usize>,
139 pub name: Option<String>,
140 pub use_fast_math: Option<bool>,
141 pub lto: bool,
146 pub cpp_std: Option<CppStd>,
148 pub arch: Option<SmArch>,
151 pub name_expressions: Vec<String>,
156 pub extra_options: Vec<String>,
158 pub include_paths: Vec<String>,
160}
161
162impl NvrtcOpts {
163 pub fn for_arch(arch: SmArch) -> Self {
165 Self {
166 arch: Some(arch),
167 ..Default::default()
168 }
169 }
170
171 pub fn with_lto(mut self) -> Self {
173 self.lto = true;
174 self
175 }
176
177 pub fn with_cpp_std(mut self, std: CppStd) -> Self {
179 self.cpp_std = Some(std);
180 self
181 }
182
183 pub fn with_name_expression(mut self, expr: impl Into<String>) -> Self {
185 self.name_expressions.push(expr.into());
186 self
187 }
188
189 pub fn with_extra_option(mut self, opt: impl Into<String>) -> Self {
191 self.extra_options.push(opt.into());
192 self
193 }
194
195 pub fn with_include_path(mut self, path: impl Into<String>) -> Self {
197 self.include_paths.push(path.into());
198 self
199 }
200
201 pub fn build_flags(&self) -> Vec<String> {
204 let mut flags = Vec::new();
205 if let Some(v) = self.ftz {
206 flags.push(format!("--ftz={v}"));
207 }
208 if let Some(true) = self.use_fast_math {
209 flags.push("--use_fast_math".into());
210 }
211 if let Some(c) = self.maxrregcount {
212 flags.push(format!("--maxrregcount={c}"));
213 }
214 if let Some(s) = self.cpp_std {
215 flags.push(s.nvrtc_flag().to_string());
216 }
217 if self.lto {
218 flags.push("-dlto".into());
219 }
220 if let Some(a) = self.arch {
221 flags.push(format!("--gpu-architecture={}", a.nvrtc_flag()));
222 }
223 for path in &self.include_paths {
224 flags.push(format!("--include-path={path}"));
225 }
226 for opt in &self.extra_options {
227 flags.push(opt.clone());
228 }
229 flags
230 }
231
232 fn into_cudarc(self) -> CompileOptions {
233 let arch_flag = self.arch.map(|a| a.nvrtc_flag());
239 let mut extra: Vec<String> = Vec::new();
240 if let Some(s) = self.cpp_std {
241 extra.push(s.nvrtc_flag().to_string());
242 }
243 if self.lto {
244 extra.push("-dlto".into());
245 }
246 for opt in self.extra_options {
247 extra.push(opt);
248 }
249 CompileOptions {
250 ftz: self.ftz,
251 maxrregcount: self.maxrregcount,
252 name: self.name,
253 use_fast_math: self.use_fast_math,
254 include_paths: self.include_paths,
255 arch: arch_flag,
256 options: extra,
257 ..Default::default()
258 }
259 }
260}
261
262#[derive(Clone)]
265pub struct KernelHandle {
266 func: Arc<CudaFunction>,
267 generation: u64,
269 #[allow(dead_code)]
271 src_hash: u64,
272 pub name: String,
273 lowered_names: Arc<HashMap<String, String>>,
276 ptx: Option<Arc<Vec<u8>>>,
280 cubin: Option<Arc<Vec<u8>>>,
284}
285
286impl KernelHandle {
287 pub fn generation(&self) -> u64 {
288 self.generation
289 }
290
291 pub fn lowered_name(&self, expr: &str) -> Option<&str> {
296 self.lowered_names.get(expr).map(|s| s.as_str())
297 }
298
299 pub fn ptx_bytes(&self) -> Option<&[u8]> {
301 self.ptx.as_deref().map(|v| v.as_slice())
302 }
303
304 pub fn cubin_bytes(&self) -> Option<&[u8]> {
306 self.cubin.as_deref().map(|v| v.as_slice())
307 }
308}
309
/// A single kernel launch argument.
///
/// Only `DevSlice`, `Scalar` and `Usize` are canonical. The typed aliases
/// below are deprecated, and `KernelArg::canonicalize` rewrites them before
/// a launch.
pub enum KernelArg {
    /// A device buffer. `handle_launch` calls `validate()` on it before
    /// launching and holds the returned owner for the launch.
    DevSlice(Box<dyn DevSliceArg>),
    /// A by-value scalar, pushed onto the launch builder via `ScalarArg::push`.
    Scalar(Box<dyn ScalarArg>),
    /// A raw `usize`, pushed directly onto the launch builder.
    Usize(usize),

    // Legacy typed aliases, kept for source compatibility only.
    #[deprecated(note = "use KernelArg::DevSlice with GpuRef directly")]
    DevSliceF32(GpuRef<f32>),
    #[deprecated(note = "use KernelArg::DevSlice with GpuRef directly")]
    DevSliceF64(GpuRef<f64>),
    #[deprecated(note = "use KernelArg::DevSlice with GpuRef directly")]
    DevSliceI32(GpuRef<i32>),
    #[deprecated(note = "use KernelArg::DevSlice with GpuRef directly")]
    DevSliceU32(GpuRef<u32>),
    #[deprecated(note = "use KernelArg::DevSlice with GpuRef directly")]
    DevSliceU8(GpuRef<u8>),
    #[deprecated(note = "use KernelArg::Scalar with the scalar value directly")]
    ScalarF32(f32),
    #[deprecated(note = "use KernelArg::Scalar with the scalar value directly")]
    ScalarF64(f64),
    #[deprecated(note = "use KernelArg::Scalar with the scalar value directly")]
    ScalarI32(i32),
    #[deprecated(note = "use KernelArg::Scalar with the scalar value directly")]
    ScalarU32(u32),
    #[deprecated(note = "use KernelArg::Scalar with the scalar value directly")]
    ScalarU64(u64),
}
353
354impl KernelArg {
355 #[allow(deprecated)]
364 pub fn canonicalize(self) -> KernelArg {
365 match self {
366 KernelArg::DevSlice(_) | KernelArg::Scalar(_) | KernelArg::Usize(_) => self,
367
368 KernelArg::DevSliceF32(g) => KernelArg::DevSlice(Box::new(g)),
369 KernelArg::DevSliceF64(g) => KernelArg::DevSlice(Box::new(g)),
370 KernelArg::DevSliceI32(g) => KernelArg::DevSlice(Box::new(g)),
371 KernelArg::DevSliceU32(g) => KernelArg::DevSlice(Box::new(g)),
372 KernelArg::DevSliceU8(g) => KernelArg::DevSlice(Box::new(g)),
373
374 KernelArg::ScalarF32(v) => KernelArg::Scalar(Box::new(v)),
375 KernelArg::ScalarF64(v) => KernelArg::Scalar(Box::new(v)),
376 KernelArg::ScalarI32(v) => KernelArg::Scalar(Box::new(v)),
377 KernelArg::ScalarU32(v) => KernelArg::Scalar(Box::new(v)),
378 KernelArg::ScalarU64(v) => KernelArg::Scalar(Box::new(v)),
379 }
380 }
381}
382
/// Requests understood by [`NvrtcActor`].
pub enum NvrtcMsg {
    /// Compile `src` with `opts` on the actor's own task, resolve
    /// `kernel_name` in the resulting module, and reply with the handle or
    /// the first error.
    Compile {
        src: String,
        kernel_name: String,
        opts: NvrtcOpts,
        reply: oneshot::Sender<Result<KernelHandle, GpuError>>,
    },
    /// Same as `Compile`, but the work runs in `tokio::task::spawn_blocking`,
    /// so the actor's handler returns immediately. Modules built this way go
    /// into a throwaway map and are not reused by later requests. The disk
    /// cache, if any, is still consulted and populated.
    CompileAsync {
        src: String,
        kernel_name: String,
        opts: NvrtcOpts,
        reply: oneshot::Sender<Result<KernelHandle, GpuError>>,
    },
    /// Launch a previously compiled kernel with `args` and `cfg`.
    ///
    /// Stale handles and invalid device buffers are answered immediately.
    /// Otherwise the reply is delivered through `envelope::run_kernel`.
    Launch {
        kernel: KernelHandle,
        args: Vec<KernelArg>,
        cfg: LaunchConfig,
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
}
408
/// Actor that compiles CUDA C++ source with NVRTC and launches the resulting
/// kernels, driven by [`NvrtcMsg`].
///
/// Build its props with `NvrtcActor::props`, `NvrtcActor::props_with_cache`
/// or `NvrtcActor::mock_props`.
pub struct NvrtcActor {
    /// Real CUDA backend, or a mock that rejects every request.
    inner: NvrtcInner,
}
412
413struct SendModule(Arc<CudaModule>);
414unsafe impl Send for SendModule {}
415unsafe impl Sync for SendModule {}
416impl Clone for SendModule {
417 fn clone(&self) -> Self {
418 Self(self.0.clone())
419 }
420}
421
/// Backend chosen when the actor's props are built.
enum NvrtcInner {
    /// Real CUDA backend.
    Real {
        /// Context compiled PTX is loaded into (`load_module`).
        ctx: Arc<cudarc::driver::CudaContext>,
        /// Stream kernels are launched on.
        stream: Arc<cudarc::driver::CudaStream>,
        /// Completion strategy handed to `envelope::run_kernel`.
        completion: Arc<dyn CompletionStrategy>,
        /// Source of the generation stamped into, and checked against, `KernelHandle`s.
        state: Arc<DeviceState>,
        /// Modules loaded by synchronous `Compile` requests, keyed by a `u64`
        /// hash computed in `handle_compile`.
        modules: Mutex<HashMap<u64, SendModule>>,
        /// Optional on-disk PTX cache, also shared with `CompileAsync` tasks.
        disk_cache: Option<Arc<NvrtcCache>>,
    },
    /// Placeholder backend: every request is answered with `GpuError::Unrecoverable`.
    Mock,
}
434
435impl NvrtcActor {
436 pub fn props(
437 stream: Arc<cudarc::driver::CudaStream>,
438 _allocator: Arc<dyn StreamAllocator>,
439 completion: Arc<dyn CompletionStrategy>,
440 state: Arc<DeviceState>,
441 ctx: Arc<cudarc::driver::CudaContext>,
442 ) -> Props<Self> {
443 let disk_cache = NvrtcCache::new().ok().map(Arc::new);
447 Self::props_with_cache(stream, completion, state, ctx, disk_cache)
448 }
449
450 pub fn props_with_cache(
453 stream: Arc<cudarc::driver::CudaStream>,
454 completion: Arc<dyn CompletionStrategy>,
455 state: Arc<DeviceState>,
456 ctx: Arc<cudarc::driver::CudaContext>,
457 disk_cache: Option<Arc<NvrtcCache>>,
458 ) -> Props<Self> {
459 Props::create(move || NvrtcActor {
460 inner: NvrtcInner::Real {
461 ctx: ctx.clone(),
462 stream: stream.clone(),
463 completion: completion.clone(),
464 state: state.clone(),
465 modules: Mutex::new(HashMap::new()),
466 disk_cache: disk_cache.clone(),
467 },
468 })
469 }
470
471 pub fn mock_props() -> Props<Self> {
472 Props::create(|| NvrtcActor {
473 inner: NvrtcInner::Mock,
474 })
475 }
476}
477
478#[async_trait]
479impl Actor for NvrtcActor {
480 type Msg = NvrtcMsg;
481
482 async fn handle(&mut self, _ctx: &mut Context<Self>, msg: NvrtcMsg) {
483 match &self.inner {
484 NvrtcInner::Mock => match msg {
485 NvrtcMsg::Compile { reply, .. } | NvrtcMsg::CompileAsync { reply, .. } => {
486 let _ = reply.send(Err(GpuError::Unrecoverable(
487 "NvrtcActor in mock mode".into(),
488 )));
489 }
490 NvrtcMsg::Launch { reply, .. } => {
491 let _ = reply.send(Err(GpuError::Unrecoverable(
492 "NvrtcActor in mock mode".into(),
493 )));
494 }
495 },
496 NvrtcInner::Real {
497 ctx,
498 stream,
499 completion,
500 state,
501 modules,
502 disk_cache,
503 } => match msg {
504 NvrtcMsg::Compile {
505 src,
506 kernel_name,
507 opts,
508 reply,
509 } => {
510 let _ = reply.send(handle_compile(
511 ctx,
512 state,
513 modules,
514 disk_cache.as_ref(),
515 src,
516 kernel_name,
517 opts,
518 ));
519 }
520 NvrtcMsg::CompileAsync {
521 src,
522 kernel_name,
523 opts,
524 reply,
525 } => {
526 let ctx_c = ctx.clone();
530 let state_c = state.clone();
531 let cache_c = disk_cache.clone();
532 tokio::task::spawn_blocking(move || {
533 let local: Mutex<HashMap<u64, SendModule>> = Mutex::new(HashMap::new());
537 let res = handle_compile(
538 &ctx_c,
539 &state_c,
540 &local,
541 cache_c.as_ref(),
542 src,
543 kernel_name,
544 opts,
545 );
546 let _ = reply.send(res);
547 });
548 }
549 NvrtcMsg::Launch {
550 kernel,
551 args,
552 cfg,
553 reply,
554 } => {
555 handle_launch(stream, completion, state, kernel, args, cfg, reply);
556 }
557 },
558 }
559 }
560}
561
/// Deterministic 64-bit hash of the kernel source text.
///
/// `DefaultHasher::new()` uses fixed keys, so the result is stable within a
/// build. It is stored in `KernelHandle::src_hash`.
fn hash_src(src: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut hasher = DefaultHasher::new();
    Hash::hash(src, &mut hasher);
    hasher.finish()
}
568
569fn handle_compile(
570 ctx: &Arc<cudarc::driver::CudaContext>,
571 state: &Arc<DeviceState>,
572 modules: &Mutex<HashMap<u64, SendModule>>,
573 disk_cache: Option<&Arc<NvrtcCache>>,
574 src: String,
575 kernel_name: String,
576 opts: NvrtcOpts,
577) -> Result<KernelHandle, GpuError> {
578 let src_hash = hash_src(&src);
579 let opts_flags = opts.build_flags();
580 let arch = opts.arch.map(|a| a.compute_capability()).unwrap_or(0);
581 let cache_key = NvrtcCacheKey {
582 source_hash: hash_source(&src),
583 arch,
584 options_hash: hash_options(&opts_flags),
585 };
586 let lowered_names = build_lowered_names(&opts.name_expressions);
587
588 if let Some(m) = modules.lock().get(&src_hash).cloned() {
590 let func =
591 m.0.load_function(&kernel_name)
592 .map_err(|e| GpuError::LibraryError {
593 lib: LIB,
594 msg: format!("load_function {kernel_name}: {e}"),
595 })?;
596 return Ok(KernelHandle {
597 func: Arc::new(func),
598 generation: state.generation(),
599 src_hash,
600 name: kernel_name,
601 lowered_names: Arc::new(lowered_names),
602 ptx: None,
603 cubin: None,
604 });
605 }
606
607 let mut ptx_bytes: Option<Vec<u8>> = None;
609 let mut cubin_bytes: Option<Vec<u8>> = None;
610 if let Some(cache) = disk_cache {
611 if let Some(entry) = cache.get(cache_key) {
612 ptx_bytes = Some(entry.ptx.clone());
613 cubin_bytes = entry.cubin.clone();
614 }
615 }
616
617 let ptx: Ptx = if let Some(bytes) = &ptx_bytes {
619 let s = String::from_utf8(bytes.clone()).map_err(|e| GpuError::LibraryError {
621 lib: LIB,
622 msg: format!("nvrtc cache: invalid UTF-8 PTX: {e}"),
623 })?;
624 Ptx::from_src(s)
625 } else {
626 let compiled = compile_ptx_with_opts(&src, opts.into_cudarc()).map_err(|e| {
627 GpuError::LibraryError {
628 lib: LIB,
629 msg: format!("compile_ptx: {e}"),
630 }
631 })?;
632 let bytes_v = compiled.to_src().into_bytes();
634 ptx_bytes = Some(bytes_v.clone());
635 if let Some(cache) = disk_cache {
636 let cached = CachedKernel::new(bytes_v, cubin_bytes.clone());
640 if let Err(e) = cache.insert(cache_key, cached) {
641 tracing::debug!(?e, "nvrtc disk cache insert failed (non-fatal)");
642 }
643 }
644 compiled
645 };
646
647 let module = ctx.load_module(ptx).map_err(|e| GpuError::LibraryError {
648 lib: LIB,
649 msg: format!("load_module: {e}"),
650 })?;
651 let sm = SendModule(module.clone());
652 modules.lock().insert(src_hash, sm);
653
654 let func = module
655 .load_function(&kernel_name)
656 .map_err(|e| GpuError::LibraryError {
657 lib: LIB,
658 msg: format!("load_function {kernel_name}: {e}"),
659 })?;
660 Ok(KernelHandle {
661 func: Arc::new(func),
662 generation: state.generation(),
663 src_hash,
664 name: kernel_name,
665 lowered_names: Arc::new(lowered_names),
666 ptx: ptx_bytes.map(Arc::new),
667 cubin: cubin_bytes.map(Arc::new),
668 })
669}
670
/// Builds the name-expression table stored in `KernelHandle`.
///
/// Every expression currently maps to itself. No NVRTC lowering is performed
/// here, and duplicate expressions collapse to a single entry.
fn build_lowered_names(exprs: &[String]) -> HashMap<String, String> {
    let mut table = HashMap::with_capacity(exprs.len());
    for expr in exprs {
        table.insert(expr.clone(), expr.clone());
    }
    table
}
684
685pub fn compile_to_ptx(
691 src: &str,
692 opts: NvrtcOpts,
693 disk_cache: Option<&NvrtcCache>,
694) -> Result<(Vec<u8>, Option<Vec<u8>>), GpuError> {
695 let opts_flags = opts.build_flags();
696 let arch = opts.arch.map(|a| a.compute_capability()).unwrap_or(0);
697 let cache_key = NvrtcCacheKey {
698 source_hash: hash_source(src),
699 arch,
700 options_hash: hash_options(&opts_flags),
701 };
702 if let Some(cache) = disk_cache {
703 if let Some(hit) = cache.get(cache_key) {
704 return Ok((hit.ptx.clone(), hit.cubin.clone()));
705 }
706 }
707 let compiled =
708 compile_ptx_with_opts(src, opts.into_cudarc()).map_err(|e| GpuError::LibraryError {
709 lib: LIB,
710 msg: format!("compile_ptx: {e}"),
711 })?;
712 let ptx = compiled.to_src().into_bytes();
713 let cubin: Option<Vec<u8>> = None;
714 if let Some(cache) = disk_cache {
715 let cached = CachedKernel::new(ptx.clone(), cubin.clone());
716 let _ = cache.insert(cache_key, cached);
717 }
718 Ok((ptx, cubin))
719}
720
721pub fn default_disk_cache_path() -> Option<PathBuf> {
726 NvrtcCache::new().ok().map(|c| c.dir().to_path_buf())
727}
728
/// Launches `kernel` on `stream` with `args` and `cfg`. The outcome is reported
/// on `reply`, either directly or via `envelope::run_kernel`.
///
/// Steps:
/// 1. Rejects a handle whose `generation` differs from `state.generation()`
///    with `GpuError::GpuRefStale`.
/// 2. Folds deprecated `KernelArg` aliases with `KernelArg::canonicalize`.
/// 3. Calls `validate()` on every `DevSlice` argument and replies with the
///    first error. The returned owners are collected in `gpu_owners`.
/// 4. Pushes the arguments, in order, onto a launch builder and launches.
///
/// On success the closure hands `(gpu_owners, func, args)` back to
/// `run_kernel`. NOTE(review): presumably this is so the buffers and argument
/// storage outlive the launch; confirm against `envelope::run_kernel`.
fn handle_launch(
    stream: &Arc<cudarc::driver::CudaStream>,
    completion: &Arc<dyn CompletionStrategy>,
    state: &Arc<DeviceState>,
    kernel: KernelHandle,
    args: Vec<KernelArg>,
    cfg: LaunchConfig,
    reply: oneshot::Sender<Result<(), GpuError>>,
) {
    // A handle compiled under a previous device generation must not be launched.
    if kernel.generation != state.generation() {
        let _ = reply.send(Err(GpuError::GpuRefStale(
            "nvrtc kernel from prior context generation",
        )));
        return;
    }

    // After this, only `DevSlice`, `Scalar` and `Usize` can appear in `args`.
    let args: Vec<KernelArg> = args.into_iter().map(KernelArg::canonicalize).collect();

    // Owners returned by `validate()`. They are moved into the launch closure
    // and returned from it on success.
    let mut gpu_owners: Vec<Box<dyn std::any::Any + Send>> = Vec::new();
    for arg in &args {
        if let KernelArg::DevSlice(b) = arg {
            match b.validate() {
                Ok(owner) => gpu_owners.push(owner),
                Err(e) => {
                    let _ = reply.send(Err(e));
                    return;
                }
            }
        }
    }

    let func = kernel.func.clone();
    let stream_clone = stream.clone();
    envelope::run_kernel(LIB, stream, completion, (), reply, move || {
        let mut builder = stream_clone.launch_builder(&func);
        for arg in args.iter() {
            // Arguments are pushed by reference into `args`. `args` is owned
            // by this closure and returned on success.
            match arg {
                KernelArg::DevSlice(b) => b.push(&mut builder)?,
                KernelArg::Scalar(b) => {
                    b.push(&mut builder);
                }
                KernelArg::Usize(v) => {
                    builder.arg(v);
                }
                // Deprecated aliases were removed by `canonicalize` above.
                #[allow(deprecated)]
                KernelArg::DevSliceF32(_)
                | KernelArg::DevSliceF64(_)
                | KernelArg::DevSliceI32(_)
                | KernelArg::DevSliceU32(_)
                | KernelArg::DevSliceU8(_)
                | KernelArg::ScalarF32(_)
                | KernelArg::ScalarF64(_)
                | KernelArg::ScalarI32(_)
                | KernelArg::ScalarU32(_)
                | KernelArg::ScalarU64(_) => unreachable!("canonicalize() folds these arms"),
            }
        }
        // SAFETY: NOTE(review) — nothing here checks the argument count or
        // types against the kernel's parameter list. Soundness relies on the
        // caller supplying `args`/`cfg` that match the kernel signature.
        let res = unsafe { builder.launch(cfg) };
        match res {
            Ok(_) => Ok((gpu_owners, func, args)),
            Err(e) => Err(GpuError::LibraryError {
                lib: LIB,
                msg: format!("launch: {e}"),
            }),
        }
    });
}
811
// Unit tests for the host-side pieces that need no GPU: argument
// canonicalisation, option-to-flag mapping and the name-expression table.
#[cfg(test)]
mod tests {
    use super::*;

    // Canonical arguments pass through `canonicalize` unchanged.
    #[test]
    fn launch_args_collapse_compile() {
        let args: Vec<KernelArg> = vec![
            KernelArg::Scalar(Box::new(1.0f32)),
            KernelArg::Scalar(Box::new(42i32)),
            KernelArg::Usize(128),
        ];
        assert_eq!(args.len(), 3);
        let canon: Vec<KernelArg> = args.into_iter().map(KernelArg::canonicalize).collect();
        assert_eq!(canon.len(), 3);
        let mut n_scalar = 0;
        let mut n_usize = 0;
        for a in &canon {
            match a {
                KernelArg::Scalar(_) => n_scalar += 1,
                KernelArg::Usize(_) => n_usize += 1,
                _ => panic!("unexpected variant"),
            }
        }
        assert_eq!((n_scalar, n_usize), (2, 1));
    }

    // Every deprecated scalar alias folds into `KernelArg::Scalar`.
    #[test]
    fn deprecated_aliases_still_construct() {
        #[allow(deprecated)]
        let aliases = vec![
            KernelArg::ScalarF32(1.0),
            KernelArg::ScalarF64(2.0),
            KernelArg::ScalarI32(3),
            KernelArg::ScalarU32(4),
            KernelArg::ScalarU64(5),
        ];
        for a in aliases {
            let c = a.canonicalize();
            assert!(matches!(c, KernelArg::Scalar(_)));
        }
    }

    // `with_lto` sets the flag, and only then is `-dlto` emitted.
    #[test]
    fn lto_option_round_trip() {
        let opts = NvrtcOpts::default().with_lto();
        assert!(opts.lto, "with_lto sets the lto flag");
        let flags = opts.build_flags();
        assert!(
            flags.iter().any(|f| f == "-dlto"),
            "lto opt must emit `-dlto`, got {flags:?}"
        );
        let none = NvrtcOpts::default();
        assert!(!none.build_flags().iter().any(|f| f == "-dlto"));
    }

    // `build_lowered_names` is currently an identity map, so each registered
    // expression maps to itself and unregistered ones are absent.
    #[test]
    fn name_expression_round_trip() {
        let opts = NvrtcOpts::default()
            .with_name_expression("my_kernel<float, 256>")
            .with_name_expression("my_kernel<double, 128>");
        assert_eq!(opts.name_expressions.len(), 2);

        let lowered = build_lowered_names(&opts.name_expressions);
        assert_eq!(lowered.len(), 2);
        assert_eq!(
            lowered.get("my_kernel<float, 256>").map(|s| s.as_str()),
            Some("my_kernel<float, 256>")
        );
        assert_eq!(
            lowered.get("my_kernel<double, 128>").map(|s| s.as_str()),
            Some("my_kernel<double, 128>")
        );

        // Same lookup through the `Arc` wrapper that `KernelHandle` stores.
        let arc = Arc::new(lowered);
        assert_eq!(
            arc.get("my_kernel<float, 256>").map(|s| s.as_str()),
            Some("my_kernel<float, 256>")
        );
        assert!(arc.get("never_registered").is_none());

        let empty = build_lowered_names(&[]);
        assert!(empty.is_empty());
    }

    // The `CompileAsync` message can be built and destructured without a GPU.
    #[test]
    fn async_compile_request_constructs() {
        let (tx, _rx) = oneshot::channel::<Result<KernelHandle, GpuError>>();
        let msg = NvrtcMsg::CompileAsync {
            src: "extern \"C\" __global__ void k() {}".into(),
            kernel_name: "k".into(),
            opts: NvrtcOpts::default().with_lto().with_cpp_std(CppStd::Cpp17),
            reply: tx,
        };
        match msg {
            NvrtcMsg::CompileAsync {
                src, kernel_name, ..
            } => {
                assert!(src.contains("__global__"));
                assert_eq!(kernel_name, "k");
            }
            _ => panic!("expected CompileAsync variant"),
        }
    }

    // Each `SmArch` maps to its NVRTC name and compute capability, and
    // `for_arch` emits the matching `--gpu-architecture` flag.
    #[test]
    fn arch_selection_emits_correct_flag() {
        let cases = [
            (SmArch::Sm80, "compute_80", 80),
            (SmArch::Sm86, "compute_86", 86),
            (SmArch::Sm89, "compute_89", 89),
            (SmArch::Sm90, "compute_90", 90),
            (SmArch::Sm90a, "compute_90a", 90),
            (SmArch::Sm100, "compute_100", 100),
            (SmArch::Sm120, "compute_120", 120),
        ];
        for (arch, expect_flag, expect_cc) in cases {
            assert_eq!(arch.nvrtc_flag(), expect_flag);
            assert_eq!(arch.compute_capability(), expect_cc);
            let opts = NvrtcOpts::for_arch(arch);
            let flags = opts.build_flags();
            let want = format!("--gpu-architecture={expect_flag}");
            assert!(
                flags.iter().any(|f| f == &want),
                "arch {arch:?} must emit `{want}`, got {flags:?}"
            );
        }
    }

    // Each `CppStd` emits its `--std=` flag through `build_flags`.
    #[test]
    fn cpp_std_emits_flag() {
        for (s, want) in [
            (CppStd::Cpp14, "--std=c++14"),
            (CppStd::Cpp17, "--std=c++17"),
            (CppStd::Cpp20, "--std=c++20"),
        ] {
            let opts = NvrtcOpts::default().with_cpp_std(s);
            let flags = opts.build_flags();
            assert!(
                flags.iter().any(|f| f == want),
                "{s:?} must emit `{want}`, got {flags:?}"
            );
        }
    }
}