Skip to main content

atomr_accel_cuda/kernel/collective/
custom_op.rs

//! Custom reduce ops — currently `ncclRedOpCreatePreMulSum`.
//!
//! cudarc 0.19.4 does not expose the raw `ncclComm_t` from
//! `cudarc::nccl::Comm` (the field is private), but PreMulSum creation
//! needs that pointer. We rely on cudarc's source-level field order
//! (`comm: ncclComm_t` is the first field of `pub struct Comm`) and read
//! it through a layout-fragile pointer cast to a `#[repr(C)]` shadow type.
//! `Comm` itself is declared with `#[derive(Debug)]`, not `#[repr(C)]`, so
//! Rust does not guarantee that in-memory field order. The shadow's size is
//! sanity-checked against `Comm`, but the field offset cannot be verified.
//!
//! A cudarc upgrade that changes `Comm`'s layout is therefore not guaranteed
//! to be caught by the build. Re-check the layout on every cudarc bump, and
//! move to a vendored Comm constructor if it changes.
14
15use std::marker::PhantomData;
16use std::sync::Arc;
17
18use cudarc::nccl::sys;
19use cudarc::nccl::Comm;
20
21use super::{NcclReduceSupported, LIB};
22use crate::error::GpuError;
23use crate::gpu_ref::GpuRef;
24
25/// Mirror of `cudarc::nccl::Comm` layout. The struct is `#[derive(Debug)]`
26/// in cudarc; field order is documented in safe.rs: `comm`, `stream`,
27/// `rank`, `world_size`. We assert via a runtime check that the
28/// pointer-sized first slot reads back as a non-null pointer
29/// (sanity, not a true memory-safety guarantee).
30#[repr(C)]
31struct CommLayoutShadow {
32    raw_comm: sys::ncclComm_t,
33    _stream: std::mem::ManuallyDrop<Arc<cudarc::driver::CudaStream>>,
34    _rank: usize,
35    _world_size: usize,
36}
37
38/// SAFETY: Reads the first field of a `cudarc::nccl::Comm` via the
39/// shadow layout above. The caller must ensure `comm` outlives the
40/// returned pointer.
41fn raw_comm_ptr(comm: &Comm) -> sys::ncclComm_t {
42    // Layout sanity: the shadow struct must be at most as large as
43    // the real struct.
44    debug_assert!(std::mem::size_of::<CommLayoutShadow>() <= std::mem::size_of::<Comm>());
45    unsafe {
46        let p = comm as *const Comm as *const CommLayoutShadow;
47        (*p).raw_comm
48    }
49}
50
51/// PreMulSum custom reduce op: AllReduce-equivalent with a per-tensor
52/// scalar premultiplier living in device memory. Construct via
53/// [`PreMulSumOp::new`]; destroy via [`PreMulSumOp::destroy`] before
54/// the comm goes away.
55pub struct PreMulSumOp<T: NcclReduceSupported> {
56    handle: sys::ncclRedOp_t,
57    /// Keep the scalar GpuRef alive for the lifetime of the op.
58    #[allow(dead_code)]
59    scalar: GpuRef<T>,
60    /// Comm whose lifetime the op is bound to. We don't keep an Arc
61    /// because cudarc's `Comm` isn't internally Arc-able; instead the
62    /// caller must `destroy()` before the comm is dropped.
63    comm_ptr: sys::ncclComm_t,
64    _phantom: PhantomData<T>,
65}
66
67unsafe impl<T: NcclReduceSupported> Send for PreMulSumOp<T> {}
68
69impl<T: NcclReduceSupported> PreMulSumOp<T> {
70    /// Create a PreMulSum op. The `scalar` buffer holds one element
71    /// (per-tensor scaling) at construction time.
72    pub fn new(comm: &Comm, scalar: GpuRef<T>) -> Result<Self, GpuError> {
73        let mut handle: sys::ncclRedOp_t = sys::ncclRedOp_t::ncclSum;
74        let comm_ptr = raw_comm_ptr(comm);
75        // Scoped borrow so `scalar` is movable into the returned struct.
76        {
77            let slice = scalar.access()?;
78            if slice.len() == 0 {
79                return Err(GpuError::Unrecoverable(
80                    "PreMulSumOp scalar buffer is empty".into(),
81                ));
82            }
83            let stream = comm.stream();
84            let (dptr, _record) = {
85                use cudarc::driver::DevicePtr;
86                slice.device_ptr(&stream)
87            };
88            unsafe {
89                sys::ncclRedOpCreatePreMulSum(
90                    &mut handle as *mut sys::ncclRedOp_t,
91                    dptr as *mut std::ffi::c_void,
92                    <T as cudarc::nccl::NcclType>::as_nccl_type(),
93                    sys::ncclScalarResidence_t::ncclScalarDevice,
94                    comm_ptr,
95                )
96                .result()
97                .map_err(|e| GpuError::LibraryError {
98                    lib: LIB,
99                    msg: format!("ncclRedOpCreatePreMulSum: {e:?}"),
100                })?;
101            }
102        }
103        Ok(Self {
104            handle,
105            scalar,
106            comm_ptr,
107            _phantom: PhantomData,
108        })
109    }
110
111    /// Raw NCCL op handle — pass to FFI calls that take a custom
112    /// `ncclRedOp_t`.
113    pub fn handle(&self) -> sys::ncclRedOp_t {
114        self.handle
115    }
116
117    /// Destroy the underlying NCCL op. Idempotent.
118    pub fn destroy(self) -> Result<(), GpuError> {
119        unsafe {
120            sys::ncclRedOpDestroy(self.handle, self.comm_ptr)
121                .result()
122                .map_err(|e| GpuError::LibraryError {
123                    lib: LIB,
124                    msg: format!("ncclRedOpDestroy: {e:?}"),
125                })?;
126        }
127        Ok(())
128    }
129}