Skip to main content

atomr_accel_cuda/memory/
managed.rs

//! `ManagedAllocatorActor` + `ManagedRef<T>`.
//!
//! `ManagedRef<T>` is distinct from [`crate::gpu_ref::GpuRef<T>`]
//! because validity rules differ: managed memory survives
//! `ContextActor` rebuilds, so generation tokens don't apply.
//! Validity is tied to the allocator actor's lifetime via
//! `Arc<AtomicBool>`.
//!
//! Backed by raw `cudaMallocManaged` + `cudaFree` from cudarc's
//! runtime sys layer. The pointer is allocated once at construction
//! and freed when the last `ManagedRef` clone drops, provided the
//! allocator actor is still running at that point. Once the actor has
//! stopped, the remaining clones report themselves as invalid and
//! dropping them no longer calls `cudaFree`.
13
14use std::marker::PhantomData;
15use std::ptr::NonNull;
16use std::sync::atomic::{AtomicBool, Ordering};
17use std::sync::Arc;
18
19use async_trait::async_trait;
20use atomr_core::actor::{Actor, Context, Props};
21use cudarc::driver::sys as driver_sys;
22use cudarc::runtime::sys as runtime_sys;
23use tokio::sync::oneshot;
24
25use crate::error::GpuError;
26
27fn driver_location(target: PrefetchTarget) -> driver_sys::CUmemLocation {
28    // CUDA 13.0+'s `CUmemLocation` uses an anonymous union for the
29    // location id; zero-init and set type only — concrete location
30    // pinning is wired in a follow-up PR.
31    unsafe {
32        let mut loc: driver_sys::CUmemLocation = std::mem::zeroed();
33        loc.type_ = match target {
34            PrefetchTarget::Device(_) => driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE,
35            PrefetchTarget::Cpu => driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_HOST,
36        };
37        loc
38    }
39}
40
/// Attachment mode handed to `cudaMallocManaged` as its `flags`
/// argument (see `ManagedFlags::raw` for the numeric mapping).
#[derive(Debug, Clone, Copy)]
pub enum ManagedFlags {
    /// Maps to `runtime_sys::cudaMemAttachGlobal`.
    AttachGlobal,
    /// Maps to `runtime_sys::cudaMemAttachHost`.
    AttachHost,
}
46
47impl ManagedFlags {
48    fn raw(self) -> u32 {
49        match self {
50            ManagedFlags::AttachGlobal => runtime_sys::cudaMemAttachGlobal,
51            ManagedFlags::AttachHost => runtime_sys::cudaMemAttachHost,
52        }
53    }
54}
55
/// Destination requested by a `ManagedMsg::PrefetchF32` message.
#[derive(Debug, Clone, Copy)]
pub enum PrefetchTarget {
    /// Prefetch towards a GPU (`CU_MEM_LOCATION_TYPE_DEVICE`).
    /// NOTE(review): the `u32` payload is presumably a device ordinal,
    /// but `driver_location` currently ignores it — confirm intent.
    Device(u32),
    /// Prefetch towards the host (`CU_MEM_LOCATION_TYPE_HOST`).
    Cpu,
}
61
/// Counters reported by the allocator actor in reply to
/// `ManagedMsg::Stats`. Both start at zero (`Default`).
#[derive(Debug, Clone, Copy, Default)]
pub struct ManagedStats {
    /// Number of allocation requests that completed successfully.
    pub allocations: usize,
    /// Sum of the byte sizes of all successful allocations. The
    /// visible code only ever adds to this value; it is not reduced
    /// when an allocation is dropped.
    pub bytes_allocated: usize,
}
67
/// Owning handle to a managed-memory region. `Arc`-cloned across
/// agents; the underlying allocation is released with `cudaFree` when
/// the last clone drops, as long as the allocator's `system_alive`
/// flag is still set at that moment.
pub struct ManagedRef<T> {
    /// Shared allocation record. Every accessor treats `None` as an
    /// invalid, zero-length handle.
    inner: Option<Arc<ManagedRefInner>>,
    /// Carries the element type; no `T` value is stored in the handle.
    _marker: PhantomData<T>,
}
75
/// Allocation record shared by every clone of a `ManagedRef`.
struct ManagedRefInner {
    /// Base address of the allocation, stored untyped.
    ptr: NonNull<u8>,
    /// Size of the allocation in bytes.
    bytes: usize,
    /// Number of elements the allocation was requested for.
    elements: usize,
    /// Flag shared with the allocator actor. While true the handle is
    /// considered valid and `Drop` will call `cudaFree`; once it is
    /// false the handle must be treated as dead and `Drop` leaves the
    /// pointer untouched.
    system_alive: Arc<AtomicBool>,
}
84
85impl Drop for ManagedRefInner {
86    fn drop(&mut self) {
87        if self.system_alive.load(Ordering::Acquire) {
88            // SAFETY: ptr was returned by cudaMallocManaged with
89            // the same allocator. cudaFree is the documented release
90            // call. We swallow the error — Drop can't propagate.
91            unsafe {
92                let _ = runtime_sys::cudaFree(self.ptr.as_ptr() as *mut _);
93            }
94        }
95    }
96}
97
// SAFETY (NOTE(review): the original states no justification — please
// confirm): `NonNull<u8>` is the only field that suppresses the auto
// traits. The record itself never reads or writes through the pointer;
// it only hands it to `cudaFree` in `Drop`, which is presumably
// callable from any thread.
unsafe impl Send for ManagedRefInner {}
unsafe impl Sync for ManagedRefInner {}
100
101impl<T> ManagedRef<T> {
102    /// True if the underlying allocation is still live.
103    pub fn is_valid(&self) -> bool {
104        self.inner
105            .as_ref()
106            .map(|i| i.system_alive.load(Ordering::Acquire))
107            .unwrap_or(false)
108    }
109
110    pub fn len(&self) -> usize {
111        self.inner.as_ref().map(|i| i.elements).unwrap_or(0)
112    }
113
114    pub fn is_empty(&self) -> bool {
115        self.len() == 0
116    }
117
118    /// Raw device pointer. Valid for both host and device access while
119    /// the allocator is alive. Caller is responsible for
120    /// synchronization.
121    pub fn as_ptr(&self) -> *const T {
122        self.inner
123            .as_ref()
124            .map(|i| i.ptr.as_ptr() as *const T)
125            .unwrap_or(std::ptr::null())
126    }
127
128    pub fn as_mut_ptr(&self) -> *mut T {
129        self.inner
130            .as_ref()
131            .map(|i| i.ptr.as_ptr() as *mut T)
132            .unwrap_or(std::ptr::null_mut())
133    }
134}
135
136impl<T: Copy> ManagedRef<T> {
137    /// Host-side immutable view of the managed memory.
138    ///
139    /// SAFETY contract: managed memory is coherent host/device, but
140    /// reads from the host before the device has finished writing
141    /// produce undefined values. Caller must synchronize the
142    /// relevant device stream first (e.g. via
143    /// `cudaDeviceSynchronize`) when reading data the device wrote.
144    /// Returns an empty slice if the allocator has stopped.
145    pub fn as_slice(&self) -> &[T] {
146        match self.inner.as_ref() {
147            None => &[],
148            Some(i) => {
149                if !i.system_alive.load(Ordering::Acquire) {
150                    return &[];
151                }
152                unsafe { std::slice::from_raw_parts(i.ptr.as_ptr() as *const T, i.elements) }
153            }
154        }
155    }
156
157    /// Host-side mutable view of the managed memory.
158    ///
159    /// Like [`Self::as_slice`] this is a host alias of device-visible
160    /// memory. The caller must hold a `WriteToken` (e.g. from
161    /// [`crate::memory::ManagedAllocatorActor`] or
162    /// `SharedGpuStateCoordinator`) to avoid concurrent device
163    /// writes while writing from the host.
164    pub fn as_mut_slice(&mut self) -> &mut [T] {
165        match self.inner.as_ref() {
166            None => &mut [],
167            Some(i) => {
168                if !i.system_alive.load(Ordering::Acquire) {
169                    return &mut [];
170                }
171                unsafe { std::slice::from_raw_parts_mut(i.ptr.as_ptr() as *mut T, i.elements) }
172            }
173        }
174    }
175}
176
177impl<T> Clone for ManagedRef<T> {
178    fn clone(&self) -> Self {
179        Self {
180            inner: self.inner.clone(),
181            _marker: PhantomData,
182        }
183    }
184}
185
// The handle is only as thread-safe as its element type: the impls
// forward the `T: Send` / `T: Sync` bounds, mirroring what the auto
// traits would derive from `PhantomData<T>`.
// NOTE(review): these impls look redundant with the auto traits given
// the manual impls on `ManagedRefInner` — confirm they are needed.
unsafe impl<T: Send> Send for ManagedRef<T> {}
unsafe impl<T: Sync> Sync for ManagedRef<T> {}
188
/// Messages understood by `ManagedAllocatorActor`. Every variant
/// carries a oneshot `reply` channel on which the actor sends the
/// outcome; a dropped receiver is silently ignored by the actor.
pub enum ManagedMsg {
    /// Allocate `len` `f32` elements of managed memory with the given
    /// attachment `flags`.
    AllocateManagedF32 {
        len: usize,
        flags: ManagedFlags,
        reply: oneshot::Sender<Result<ManagedRef<f32>, GpuError>>,
    },
    /// Prefetch a managed allocation to a specific target. The
    /// `mem` argument is a clone of the `ManagedRef` returned from
    /// allocation.
    PrefetchF32 {
        mem: ManagedRef<f32>,
        target: PrefetchTarget,
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
    /// Apply a memory advisory hint to a managed allocation.
    /// Wraps `cuMemAdvise_v2`.
    AdviseF32 {
        mem: ManagedRef<f32>,
        advice: super::advise::MemAdvice,
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
    /// Request a copy of the actor's current `ManagedStats`.
    Stats {
        reply: oneshot::Sender<ManagedStats>,
    },
}
214
/// Actor that serves managed-memory allocation, prefetch and advise
/// requests (see `ManagedMsg`).
pub struct ManagedAllocatorActor {
    /// Liveness flag shared with every `ManagedRef` this actor hands
    /// out; cleared in `post_stop`.
    system_alive: Arc<AtomicBool>,
    /// Running allocation counters returned for `ManagedMsg::Stats`.
    stats: ManagedStats,
}
219
220impl ManagedAllocatorActor {
221    pub fn props() -> Props<Self> {
222        Props::create(|| ManagedAllocatorActor {
223            system_alive: Arc::new(AtomicBool::new(true)),
224            stats: ManagedStats::default(),
225        })
226    }
227
228    fn allocate_f32(
229        &mut self,
230        len: usize,
231        flags: ManagedFlags,
232    ) -> Result<ManagedRef<f32>, GpuError> {
233        let bytes = len.checked_mul(std::mem::size_of::<f32>()).ok_or_else(|| {
234            GpuError::Unrecoverable("managed alloc: len * size_of overflowed".into())
235        })?;
236        // cudarc's dynamic-loader panics if the CUDA runtime library
237        // isn't loadable on the host (e.g. no driver in CI). Catch
238        // that here so the actor stays alive on no-GPU machines.
239        let mut raw: *mut std::ffi::c_void = std::ptr::null_mut();
240        let raw_ref = &mut raw as *mut *mut std::ffi::c_void;
241        let raw_ref = raw_ref as usize; // copy for closure
242        let status_res = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
243            // SAFETY: cudaMallocManaged contract — writable out-ptr,
244            // valid size. The pointer-as-usize cast is needed
245            // because raw pointers aren't UnwindSafe.
246            unsafe {
247                runtime_sys::cudaMallocManaged(
248                    raw_ref as *mut *mut std::ffi::c_void,
249                    bytes,
250                    flags.raw(),
251                )
252            }
253        }));
254        let status = match status_res {
255            Ok(s) => s,
256            Err(_) => {
257                return Err(GpuError::Unrecoverable(
258                    "cudaMallocManaged: CUDA runtime not loadable".into(),
259                ));
260            }
261        };
262        if status != runtime_sys::cudaError::cudaSuccess {
263            return Err(GpuError::OutOfMemory(format!(
264                "cudaMallocManaged({bytes}B): {status:?}"
265            )));
266        }
267        let ptr = NonNull::new(raw as *mut u8)
268            .ok_or_else(|| GpuError::Unrecoverable("cudaMallocManaged returned null".into()))?;
269        self.stats.allocations += 1;
270        self.stats.bytes_allocated += bytes;
271        Ok(ManagedRef {
272            inner: Some(Arc::new(ManagedRefInner {
273                ptr,
274                bytes,
275                elements: len,
276                system_alive: self.system_alive.clone(),
277            })),
278            _marker: PhantomData,
279        })
280    }
281}
282
283#[async_trait]
284impl Actor for ManagedAllocatorActor {
285    type Msg = ManagedMsg;
286
287    async fn handle(&mut self, _ctx: &mut Context<Self>, msg: ManagedMsg) {
288        match msg {
289            ManagedMsg::AllocateManagedF32 { len, flags, reply } => {
290                let _ = reply.send(self.allocate_f32(len, flags));
291            }
292            ManagedMsg::PrefetchF32 { mem, target, reply } => {
293                let Some(inner) = mem.inner.as_ref() else {
294                    let _ = reply.send(Err(GpuError::Unrecoverable(
295                        "PrefetchF32: invalid ManagedRef".into(),
296                    )));
297                    return;
298                };
299                if !inner.system_alive.load(Ordering::Acquire) {
300                    let _ = reply.send(Err(GpuError::Unrecoverable(
301                        "PrefetchF32: allocator stopped".into(),
302                    )));
303                    return;
304                }
305                // Route through the driver-API helper so the
306                // runtime/driver paths share one panic-safe wrapper.
307                let location = driver_location(target);
308                let dev_ptr = inner.ptr.as_ptr() as cudarc::driver::sys::CUdeviceptr;
309                let r = crate::sys::cuda_driver::mem_prefetch_async_v2(
310                    dev_ptr,
311                    inner.bytes,
312                    location,
313                    0,
314                    std::ptr::null_mut(),
315                );
316                let _ = reply.send(r);
317            }
318            ManagedMsg::AdviseF32 { mem, advice, reply } => {
319                let Some(inner) = mem.inner.as_ref() else {
320                    let _ = reply.send(Err(GpuError::Unrecoverable(
321                        "AdviseF32: invalid ManagedRef".into(),
322                    )));
323                    return;
324                };
325                if !inner.system_alive.load(Ordering::Acquire) {
326                    let _ = reply.send(Err(GpuError::Unrecoverable(
327                        "AdviseF32: allocator stopped".into(),
328                    )));
329                    return;
330                }
331                let dev_ptr = inner.ptr.as_ptr() as cudarc::driver::sys::CUdeviceptr;
332                let r = crate::memory::advise::advise(dev_ptr, inner.bytes, advice);
333                let _ = reply.send(r);
334            }
335            ManagedMsg::Stats { reply } => {
336                let _ = reply.send(self.stats);
337            }
338        }
339    }
340
341    async fn post_stop(&mut self, _ctx: &mut Context<Self>) {
342        // Mark the system dead. ManagedRefInner::Drop calls cudaFree
343        // only while system_alive is true; flipping it here makes
344        // outstanding clones into safe inert handles. The strong
345        // ref the actor doesn't hold (pointers are tracked by
346        // ManagedRefInner Arc) means the runtime frees the memory
347        // when the last clone drops anyway — but only if the
348        // system is alive. Trade-off documented in the module
349        // doc: allocations can outlive the actor only if at least
350        // one ManagedRef clone is alive when the actor stops.
351        self.system_alive.store(false, Ordering::Release);
352    }
353}
354
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_config::Config;
    use atomr_core::actor::ActorSystem;
    use std::time::Duration;

    /// Smoke test of the actor surface: send an allocation request,
    /// then ask for stats.
    ///
    /// We can't rely on a CUDA driver being present. With a real GPU
    /// the allocation succeeds; when the runtime library cannot be
    /// loaded the actor replies `Unrecoverable`; when the runtime
    /// loads but the call fails it replies `OutOfMemory`. All of these
    /// are acceptable — the test only requires that a reply arrives
    /// and that the stats agree with the outcome.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn allocate_replies_then_invalidate_on_post_stop() {
        let sys = ActorSystem::create("managed-test", Config::empty())
            .await
            .unwrap();
        let mgr = sys
            .actor_of(ManagedAllocatorActor::props(), "managed")
            .unwrap();

        let (tx, rx) = oneshot::channel();
        mgr.tell(ManagedMsg::AllocateManagedF32 {
            len: 1024,
            flags: ManagedFlags::AttachGlobal,
            reply: tx,
        });
        let allocation = tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .unwrap()
            .unwrap();
        let succeeded = allocation.is_ok();

        let (tx, rx) = oneshot::channel();
        mgr.tell(ManagedMsg::Stats { reply: tx });
        let stats = tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(stats.allocations, usize::from(succeeded));

        // Release the allocation while the allocator is still alive.
        // `let _ = allocation;` would NOT drop it: the handle would
        // outlive `terminate()`, after which Drop skips `cudaFree` and
        // the managed memory would leak on machines with a real GPU.
        drop(allocation);
        sys.terminate().await;
    }

    /// Construct a synthetic `ManagedRef` that stands in for a real
    /// allocation. It points at a leaked 1-byte heap buffer so the
    /// pointer is non-null; the `system_alive` flag is returned to the
    /// caller. The pointer is never dereferenced, and the caller must
    /// clear the flag before the handle is dropped so that
    /// `ManagedRefInner::drop` does not pass a heap pointer to
    /// `cudaFree` (which would corrupt heap state) — the tests inspect
    /// the message-routing code path only.
    fn synthetic_managed_ref<T>(elements: usize) -> (ManagedRef<T>, Arc<AtomicBool>) {
        let alive = Arc::new(AtomicBool::new(true));
        // Intentionally leaked: nothing ever frees this byte.
        let raw = NonNull::from(Box::leak(Box::new(0u8)));
        let mref = ManagedRef::<T> {
            inner: Some(Arc::new(ManagedRefInner {
                ptr: raw,
                bytes: elements * std::mem::size_of::<T>(),
                elements,
                system_alive: alive.clone(),
            })),
            _marker: PhantomData,
        };
        (mref, alive)
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn prefetch_message_routes_through_actor() {
        let sys = ActorSystem::create("managed-prefetch-test", Config::empty())
            .await
            .unwrap();
        let mgr = sys
            .actor_of(ManagedAllocatorActor::props(), "managed")
            .unwrap();

        let (mref, alive) = synthetic_managed_ref::<f32>(64);
        let (tx, rx) = oneshot::channel();
        mgr.tell(ManagedMsg::PrefetchF32 {
            mem: mref.clone(),
            target: PrefetchTarget::Cpu,
            reply: tx,
        });
        let reply = tokio::time::timeout(Duration::from_secs(2), rx).await;
        // Mark the synthetic alloc as inert before anything below can
        // panic, so that unwinding never lets ManagedRefInner::drop
        // hand the leaked heap byte to cudaFree.
        alive.store(false, Ordering::Release);
        let r = reply.unwrap().unwrap();
        // No driver → Unrecoverable, real driver → LibraryError on bogus
        // ptr, both acceptable. We never want a panic to propagate.
        match r {
            Ok(()) => {}
            Err(GpuError::Unrecoverable(_)) => {}
            Err(GpuError::LibraryError { .. }) => {}
            other => panic!("unexpected: {other:?}"),
        }

        drop(mref);
        sys.terminate().await;
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn advise_message_routes_through_actor() {
        let sys = ActorSystem::create("managed-advise-test", Config::empty())
            .await
            .unwrap();
        let mgr = sys
            .actor_of(ManagedAllocatorActor::props(), "managed")
            .unwrap();

        let (mref, alive) = synthetic_managed_ref::<f32>(64);
        let (tx, rx) = oneshot::channel();
        mgr.tell(ManagedMsg::AdviseF32 {
            mem: mref.clone(),
            advice: super::super::advise::MemAdvice::SetReadMostly,
            reply: tx,
        });
        let reply = tokio::time::timeout(Duration::from_secs(2), rx).await;
        // Same ordering as the prefetch test: disarm the synthetic
        // handle before any assertion can unwind.
        alive.store(false, Ordering::Release);
        let r = reply.unwrap().unwrap();
        match r {
            Ok(()) => {}
            Err(GpuError::Unrecoverable(_)) => {}
            Err(GpuError::LibraryError { .. }) => {}
            other => panic!("unexpected: {other:?}"),
        }

        drop(mref);
        sys.terminate().await;
    }
}