atomr_accel_cuda/kernel/collective/custom_op.rs
1//! Custom reduce ops — currently `ncclRedOpCreatePreMulSum`.
2//!
//! cudarc 0.19.4 does not expose the raw `ncclComm_t` from
//! `cudarc::nccl::Comm` — the field is private — but PreMulSum creation
//! needs that pointer. We rely on `comm: ncclComm_t` being the first field
//! declared in cudarc's `pub struct Comm` and read it via a layout-fragile
//! pointer cast through a `#[repr(C)]` shadow type (`CommLayoutShadow`),
//! with sanity checks in `raw_comm_ptr`.
10//!
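//! Because cudarc's `Comm` is not `#[repr(C)]`, rustc does not guarantee its
//! field order, so those checks cannot prove the layout is right. Re-verify
//! the shadow on every cudarc upgrade; if the assumption breaks, we'll move
//! to a vendored `Comm` constructor.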
14
15use std::marker::PhantomData;
16use std::sync::Arc;
17
18use cudarc::nccl::sys;
19use cudarc::nccl::Comm;
20
21use super::{NcclReduceSupported, LIB};
22use crate::error::GpuError;
23use crate::gpu_ref::GpuRef;
24
25/// Mirror of `cudarc::nccl::Comm` layout. The struct is `#[derive(Debug)]`
26/// in cudarc; field order is documented in safe.rs: `comm`, `stream`,
27/// `rank`, `world_size`. We assert via a runtime check that the
28/// pointer-sized first slot reads back as a non-null pointer
29/// (sanity, not a true memory-safety guarantee).
30#[repr(C)]
31struct CommLayoutShadow {
32 raw_comm: sys::ncclComm_t,
33 _stream: std::mem::ManuallyDrop<Arc<cudarc::driver::CudaStream>>,
34 _rank: usize,
35 _world_size: usize,
36}
37
38/// SAFETY: Reads the first field of a `cudarc::nccl::Comm` via the
39/// shadow layout above. The caller must ensure `comm` outlives the
40/// returned pointer.
41fn raw_comm_ptr(comm: &Comm) -> sys::ncclComm_t {
42 // Layout sanity: the shadow struct must be at most as large as
43 // the real struct.
44 debug_assert!(std::mem::size_of::<CommLayoutShadow>() <= std::mem::size_of::<Comm>());
45 unsafe {
46 let p = comm as *const Comm as *const CommLayoutShadow;
47 (*p).raw_comm
48 }
49}
50
51/// PreMulSum custom reduce op: AllReduce-equivalent with a per-tensor
52/// scalar premultiplier living in device memory. Construct via
53/// [`PreMulSumOp::new`]; destroy via [`PreMulSumOp::destroy`] before
54/// the comm goes away.
55pub struct PreMulSumOp<T: NcclReduceSupported> {
56 handle: sys::ncclRedOp_t,
57 /// Keep the scalar GpuRef alive for the lifetime of the op.
58 #[allow(dead_code)]
59 scalar: GpuRef<T>,
60 /// Comm whose lifetime the op is bound to. We don't keep an Arc
61 /// because cudarc's `Comm` isn't internally Arc-able; instead the
62 /// caller must `destroy()` before the comm is dropped.
63 comm_ptr: sys::ncclComm_t,
64 _phantom: PhantomData<T>,
65}
66
67unsafe impl<T: NcclReduceSupported> Send for PreMulSumOp<T> {}
68
69impl<T: NcclReduceSupported> PreMulSumOp<T> {
70 /// Create a PreMulSum op. The `scalar` buffer holds one element
71 /// (per-tensor scaling) at construction time.
72 pub fn new(comm: &Comm, scalar: GpuRef<T>) -> Result<Self, GpuError> {
73 let mut handle: sys::ncclRedOp_t = sys::ncclRedOp_t::ncclSum;
74 let comm_ptr = raw_comm_ptr(comm);
75 // Scoped borrow so `scalar` is movable into the returned struct.
76 {
77 let slice = scalar.access()?;
78 if slice.len() == 0 {
79 return Err(GpuError::Unrecoverable(
80 "PreMulSumOp scalar buffer is empty".into(),
81 ));
82 }
83 let stream = comm.stream();
84 let (dptr, _record) = {
85 use cudarc::driver::DevicePtr;
86 slice.device_ptr(&stream)
87 };
88 unsafe {
89 sys::ncclRedOpCreatePreMulSum(
90 &mut handle as *mut sys::ncclRedOp_t,
91 dptr as *mut std::ffi::c_void,
92 <T as cudarc::nccl::NcclType>::as_nccl_type(),
93 sys::ncclScalarResidence_t::ncclScalarDevice,
94 comm_ptr,
95 )
96 .result()
97 .map_err(|e| GpuError::LibraryError {
98 lib: LIB,
99 msg: format!("ncclRedOpCreatePreMulSum: {e:?}"),
100 })?;
101 }
102 }
103 Ok(Self {
104 handle,
105 scalar,
106 comm_ptr,
107 _phantom: PhantomData,
108 })
109 }
110
111 /// Raw NCCL op handle — pass to FFI calls that take a custom
112 /// `ncclRedOp_t`.
113 pub fn handle(&self) -> sys::ncclRedOp_t {
114 self.handle
115 }
116
117 /// Destroy the underlying NCCL op. Idempotent.
118 pub fn destroy(self) -> Result<(), GpuError> {
119 unsafe {
120 sys::ncclRedOpDestroy(self.handle, self.comm_ptr)
121 .result()
122 .map_err(|e| GpuError::LibraryError {
123 lib: LIB,
124 msg: format!("ncclRedOpDestroy: {e:?}"),
125 })?;
126 }
127 Ok(())
128 }
129}