Skip to main content

atomr_accel_cuda/kernel/collective/
capabilities.rs

1//! Runtime probe for NCCL capabilities.
2//!
3//! Surfaces version + opt-in feature gates: fp8 reduction (NCCL >=
4//! 2.20), NVLS (NCCL >= 2.18 on supported topologies), SHARP. The
5//! probe is best-effort: if NCCL isn't loadable on this host (e.g.
6//! a CPU-only CI runner), the probe returns
7//! [`NcclCapabilities::zeroed`] rather than panicking.
8
/// Snapshot of what the loaded NCCL library offers, as reported by the
/// capability probe.
///
/// The `Default` value (version `(0, 0, 0)`, every flag `false`) is the
/// "nothing available" state.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct NcclCapabilities {
    /// NCCL version as `(major, minor, patch)`; stays at `(0, 0, 0)`
    /// when NCCL could not be loaded.
    pub version: (i32, i32, i32),
    /// Set only when the `nccl-fp8` cargo feature is enabled *and* the
    /// detected NCCL version is at least 2.20.
    pub has_fp8: bool,
    /// Mirrors the `nccl-nvls` cargo feature. This says the NVLS build
    /// path is compiled in, not that NVLS works on the current
    /// topology.
    pub has_nvls: bool,
    /// SHARP availability is exposed through NCCL_NET_PLUGIN; until
    /// that plugin query is wired up the probe always reports `false`.
    pub has_sharp: bool,
}
24
25impl NcclCapabilities {
26    /// All-zero capabilities — the value returned when NCCL isn't
27    /// initialised on this host.
28    pub fn zeroed() -> Self {
29        Self::default()
30    }
31}
32
33/// Best-effort capability probe. Calls `ncclGetVersion` via cudarc's
34/// safe wrapper; on any error returns [`NcclCapabilities::zeroed`].
35pub fn probe_capabilities() -> NcclCapabilities {
36    let version_int =
37        std::panic::catch_unwind(cudarc::nccl::result::get_nccl_version).unwrap_or(Ok(0));
38    let v = match version_int {
39        Ok(v) => v,
40        Err(_) => return NcclCapabilities::zeroed(),
41    };
42    if v == 0 {
43        return NcclCapabilities::zeroed();
44    }
45    // NCCL packs version as MAJOR*10000 + MINOR*100 + PATCH (NCCL >= 2.9)
46    // or MAJOR*1000 + MINOR*100 + PATCH (older). Detect by magnitude.
47    let (major, minor, patch) = if v >= 20000 {
48        (v / 10000, (v / 100) % 100, v % 100)
49    } else {
50        (v / 1000, (v / 100) % 10, v % 100)
51    };
52
53    let supports_fp8 = (major, minor) >= (2, 20);
54
55    NcclCapabilities {
56        version: (major, minor, patch),
57        has_fp8: cfg!(feature = "nccl-fp8") && supports_fp8,
58        has_nvls: cfg!(feature = "nccl-nvls"),
59        has_sharp: false,
60    }
61}
62
#[cfg(test)]
mod tests {
    use super::{probe_capabilities, NcclCapabilities};

    /// The probe must return without panicking on any host, with or
    /// without a working NCCL install.
    #[test]
    fn probe_returns_zeroed_when_nccl_uninit() {
        // Two outcomes are acceptable: an all-zero result (no NCCL) or
        // a real version. The main assertion is that the call returns.
        let caps = probe_capabilities();
        match caps.version {
            // No NCCL: every other field must be cleared too.
            (0, 0, 0) => assert_eq!(caps, NcclCapabilities::zeroed()),
            // Real NCCL: major is >= 2 in practice, but >= 1 is
            // accepted so the test isn't pinned to a version.
            (major, _, _) => assert!(major >= 1),
        }
    }
}
84}