atomr_accel_cuda/kernel/collective/capabilities.rs
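//! Runtime probe for NCCL capabilities.
//!
//! Surfaces the library version plus opt-in feature gates:
//! - fp8 reduction, which requires NCCL >= 2.20.
//! - NVLS, which needs NCCL >= 2.18 on supported topologies. The probe
//!   only reports whether the `nccl-nvls` build path is compiled in; it
//!   does not check the version or the topology.
//! - SHARP, which is not detected yet and is always reported as absent.
//!
//! The probe is best-effort: if NCCL isn't loadable on this host (e.g. a
//! CPU-only CI runner), the probe returns
//! [`NcclCapabilities::zeroed`] rather than panicking.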
8
/// Snapshot of what the loaded NCCL library, together with this build's
/// feature flags, supports.
///
/// A value equal to [`NcclCapabilities::zeroed`] means the probe found no
/// usable NCCL on this host.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct NcclCapabilities {
    /// Library version as `(major, minor, patch)`; `(0, 0, 0)` when NCCL
    /// could not be loaded.
    pub version: (i32, i32, i32),
    /// Set only when the `nccl-fp8` feature is compiled in *and* the
    /// library reports NCCL 2.20 or newer.
    pub has_fp8: bool,
    /// Set when the `nccl-nvls` feature is compiled in. This does not
    /// mean the current topology can actually use NVLS — only that the
    /// NVLS code path was built.
    pub has_nvls: bool,
    /// SHARP support (exposed through NCCL_NET_PLUGIN). Currently always
    /// `false` because the plugin query is not wired up yet.
    pub has_sharp: bool,
}

impl NcclCapabilities {
    /// Capabilities with every field at its zero value: version
    /// `(0, 0, 0)` and all feature flags cleared. This is what the probe
    /// reports when NCCL isn't usable on this host.
    pub fn zeroed() -> Self {
        Self {
            version: (0, 0, 0),
            has_fp8: false,
            has_nvls: false,
            has_sharp: false,
        }
    }
}
32
33/// Best-effort capability probe. Calls `ncclGetVersion` via cudarc's
34/// safe wrapper; on any error returns [`NcclCapabilities::zeroed`].
35pub fn probe_capabilities() -> NcclCapabilities {
36 let version_int =
37 std::panic::catch_unwind(cudarc::nccl::result::get_nccl_version).unwrap_or(Ok(0));
38 let v = match version_int {
39 Ok(v) => v,
40 Err(_) => return NcclCapabilities::zeroed(),
41 };
42 if v == 0 {
43 return NcclCapabilities::zeroed();
44 }
45 // NCCL packs version as MAJOR*10000 + MINOR*100 + PATCH (NCCL >= 2.9)
46 // or MAJOR*1000 + MINOR*100 + PATCH (older). Detect by magnitude.
47 let (major, minor, patch) = if v >= 20000 {
48 (v / 10000, (v / 100) % 100, v % 100)
49 } else {
50 (v / 1000, (v / 100) % 10, v % 100)
51 };
52
53 let supports_fp8 = (major, minor) >= (2, 20);
54
55 NcclCapabilities {
56 version: (major, minor, patch),
57 has_fp8: cfg!(feature = "nccl-fp8") && supports_fp8,
58 has_nvls: cfg!(feature = "nccl-nvls"),
59 has_sharp: false,
60 }
61}
62
#[cfg(test)]
mod tests {
    use super::*;

    /// On a host without a working NCCL install, `probe_capabilities`
    /// must not panic — it must return `zeroed()`. On a host with NCCL,
    /// the decoded version and the derived flags must be self-consistent.
    #[test]
    fn probe_returns_zeroed_when_nccl_uninit() {
        // Whatever the host has, the probe must return without
        // panicking, reporting either zeros (no NCCL) or a real version.
        let caps = probe_capabilities();
        if caps.version == (0, 0, 0) {
            // No usable NCCL: every field must be at its zero value.
            assert_eq!(caps, NcclCapabilities::zeroed());
            return;
        }

        let (major, minor, patch) = caps.version;
        // Real NCCL: major is >= 2 in practice, but accept >= 1 to avoid
        // version-pinning the test.
        assert!(major >= 1, "implausible NCCL major version: {}", major);
        // Both packing schemes give minor and patch two decimal digits, so
        // a correct decode always lands in 0..100.
        assert!((0..100).contains(&minor), "minor out of range: {}", minor);
        assert!((0..100).contains(&patch), "patch out of range: {}", patch);

        // fp8 is only reported for NCCL >= 2.20 with the feature compiled in.
        if caps.has_fp8 {
            assert!(cfg!(feature = "nccl-fp8"));
            assert!((major, minor) >= (2, 20));
        }
        // NVLS mirrors the build feature once NCCL is loadable.
        assert_eq!(caps.has_nvls, cfg!(feature = "nccl-nvls"));
        // SHARP detection is not wired up yet.
        assert!(!caps.has_sharp);
    }
}