1use std::marker::PhantomData;
15use std::ptr::NonNull;
16use std::sync::atomic::{AtomicBool, Ordering};
17use std::sync::Arc;
18
19use async_trait::async_trait;
20use atomr_core::actor::{Actor, Context, Props};
21use cudarc::driver::sys as driver_sys;
22use cudarc::runtime::sys as runtime_sys;
23use tokio::sync::oneshot;
24
25use crate::error::GpuError;
26
27fn driver_location(target: PrefetchTarget) -> driver_sys::CUmemLocation {
28 unsafe {
32 let mut loc: driver_sys::CUmemLocation = std::mem::zeroed();
33 loc.type_ = match target {
34 PrefetchTarget::Device(_) => driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE,
35 PrefetchTarget::Cpu => driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_HOST,
36 };
37 loc
38 }
39}
40
/// Attach mode requested when allocating managed memory.
///
/// Each variant is translated to the matching runtime constant by
/// `ManagedFlags::raw`. Equality and hashing are derived so callers and tests
/// can compare or index by flag value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ManagedFlags {
    /// Translated to `cudaMemAttachGlobal`.
    AttachGlobal,
    /// Translated to `cudaMemAttachHost`.
    AttachHost,
}
46
47impl ManagedFlags {
48 fn raw(self) -> u32 {
49 match self {
50 ManagedFlags::AttachGlobal => runtime_sys::cudaMemAttachGlobal,
51 ManagedFlags::AttachHost => runtime_sys::cudaMemAttachHost,
52 }
53 }
54}
55
/// Destination of a prefetch request.
///
/// Equality and hashing are derived so targets can be compared in tests and
/// used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PrefetchTarget {
    /// A GPU, identified by the contained `u32`.
    /// NOTE(review): presumably the CUDA device ordinal — confirm with callers.
    Device(u32),
    /// Host (CPU) memory.
    Cpu,
}
61
/// Counters reported by the allocator actor in reply to `ManagedMsg::Stats`.
///
/// Both counters only ever grow in this file: they are bumped on each
/// successful allocation and are not reduced when memory is released.
/// `PartialEq`/`Eq` are derived so a snapshot can be asserted on directly.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ManagedStats {
    /// Number of allocations that succeeded.
    pub allocations: usize,
    /// Sum of the byte sizes of all successful allocations.
    pub bytes_allocated: usize,
}
67
68pub struct ManagedRef<T> {
72 inner: Option<Arc<ManagedRefInner>>,
73 _marker: PhantomData<T>,
74}
75
/// Untyped record of one allocation, shared by every `ManagedRef` clone.
struct ManagedRefInner {
    /// Start of the allocation, stored as a byte pointer.
    ptr: NonNull<u8>,
    /// Size of the allocation in bytes.
    bytes: usize,
    /// Number of typed elements the allocation was requested for.
    elements: usize,
    /// Liveness flag shared with the allocator actor; the actor clears it in
    /// `post_stop`, after which handles report themselves as invalid.
    system_alive: Arc<AtomicBool>,
}
84
85impl Drop for ManagedRefInner {
86 fn drop(&mut self) {
87 if self.system_alive.load(Ordering::Acquire) {
88 unsafe {
92 let _ = runtime_sys::cudaFree(self.ptr.as_ptr() as *mut _);
93 }
94 }
95 }
96}
97
98unsafe impl Send for ManagedRefInner {}
99unsafe impl Sync for ManagedRefInner {}
100
101impl<T> ManagedRef<T> {
102 pub fn is_valid(&self) -> bool {
104 self.inner
105 .as_ref()
106 .map(|i| i.system_alive.load(Ordering::Acquire))
107 .unwrap_or(false)
108 }
109
110 pub fn len(&self) -> usize {
111 self.inner.as_ref().map(|i| i.elements).unwrap_or(0)
112 }
113
114 pub fn is_empty(&self) -> bool {
115 self.len() == 0
116 }
117
118 pub fn as_ptr(&self) -> *const T {
122 self.inner
123 .as_ref()
124 .map(|i| i.ptr.as_ptr() as *const T)
125 .unwrap_or(std::ptr::null())
126 }
127
128 pub fn as_mut_ptr(&self) -> *mut T {
129 self.inner
130 .as_ref()
131 .map(|i| i.ptr.as_ptr() as *mut T)
132 .unwrap_or(std::ptr::null_mut())
133 }
134}
135
136impl<T: Copy> ManagedRef<T> {
137 pub fn as_slice(&self) -> &[T] {
146 match self.inner.as_ref() {
147 None => &[],
148 Some(i) => {
149 if !i.system_alive.load(Ordering::Acquire) {
150 return &[];
151 }
152 unsafe { std::slice::from_raw_parts(i.ptr.as_ptr() as *const T, i.elements) }
153 }
154 }
155 }
156
157 pub fn as_mut_slice(&mut self) -> &mut [T] {
165 match self.inner.as_ref() {
166 None => &mut [],
167 Some(i) => {
168 if !i.system_alive.load(Ordering::Acquire) {
169 return &mut [];
170 }
171 unsafe { std::slice::from_raw_parts_mut(i.ptr.as_ptr() as *mut T, i.elements) }
172 }
173 }
174 }
175}
176
177impl<T> Clone for ManagedRef<T> {
178 fn clone(&self) -> Self {
179 Self {
180 inner: self.inner.clone(),
181 _marker: PhantomData,
182 }
183 }
184}
185
186unsafe impl<T: Send> Send for ManagedRef<T> {}
187unsafe impl<T: Sync> Sync for ManagedRef<T> {}
188
189pub enum ManagedMsg {
190 AllocateManagedF32 {
191 len: usize,
192 flags: ManagedFlags,
193 reply: oneshot::Sender<Result<ManagedRef<f32>, GpuError>>,
194 },
195 PrefetchF32 {
199 mem: ManagedRef<f32>,
200 target: PrefetchTarget,
201 reply: oneshot::Sender<Result<(), GpuError>>,
202 },
203 AdviseF32 {
206 mem: ManagedRef<f32>,
207 advice: super::advise::MemAdvice,
208 reply: oneshot::Sender<Result<(), GpuError>>,
209 },
210 Stats {
211 reply: oneshot::Sender<ManagedStats>,
212 },
213}
214
215pub struct ManagedAllocatorActor {
216 system_alive: Arc<AtomicBool>,
217 stats: ManagedStats,
218}
219
220impl ManagedAllocatorActor {
221 pub fn props() -> Props<Self> {
222 Props::create(|| ManagedAllocatorActor {
223 system_alive: Arc::new(AtomicBool::new(true)),
224 stats: ManagedStats::default(),
225 })
226 }
227
228 fn allocate_f32(
229 &mut self,
230 len: usize,
231 flags: ManagedFlags,
232 ) -> Result<ManagedRef<f32>, GpuError> {
233 let bytes = len.checked_mul(std::mem::size_of::<f32>()).ok_or_else(|| {
234 GpuError::Unrecoverable("managed alloc: len * size_of overflowed".into())
235 })?;
236 let mut raw: *mut std::ffi::c_void = std::ptr::null_mut();
240 let raw_ref = &mut raw as *mut *mut std::ffi::c_void;
241 let raw_ref = raw_ref as usize; let status_res = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
243 unsafe {
247 runtime_sys::cudaMallocManaged(
248 raw_ref as *mut *mut std::ffi::c_void,
249 bytes,
250 flags.raw(),
251 )
252 }
253 }));
254 let status = match status_res {
255 Ok(s) => s,
256 Err(_) => {
257 return Err(GpuError::Unrecoverable(
258 "cudaMallocManaged: CUDA runtime not loadable".into(),
259 ));
260 }
261 };
262 if status != runtime_sys::cudaError::cudaSuccess {
263 return Err(GpuError::OutOfMemory(format!(
264 "cudaMallocManaged({bytes}B): {status:?}"
265 )));
266 }
267 let ptr = NonNull::new(raw as *mut u8)
268 .ok_or_else(|| GpuError::Unrecoverable("cudaMallocManaged returned null".into()))?;
269 self.stats.allocations += 1;
270 self.stats.bytes_allocated += bytes;
271 Ok(ManagedRef {
272 inner: Some(Arc::new(ManagedRefInner {
273 ptr,
274 bytes,
275 elements: len,
276 system_alive: self.system_alive.clone(),
277 })),
278 _marker: PhantomData,
279 })
280 }
281}
282
283#[async_trait]
284impl Actor for ManagedAllocatorActor {
285 type Msg = ManagedMsg;
286
287 async fn handle(&mut self, _ctx: &mut Context<Self>, msg: ManagedMsg) {
288 match msg {
289 ManagedMsg::AllocateManagedF32 { len, flags, reply } => {
290 let _ = reply.send(self.allocate_f32(len, flags));
291 }
292 ManagedMsg::PrefetchF32 { mem, target, reply } => {
293 let Some(inner) = mem.inner.as_ref() else {
294 let _ = reply.send(Err(GpuError::Unrecoverable(
295 "PrefetchF32: invalid ManagedRef".into(),
296 )));
297 return;
298 };
299 if !inner.system_alive.load(Ordering::Acquire) {
300 let _ = reply.send(Err(GpuError::Unrecoverable(
301 "PrefetchF32: allocator stopped".into(),
302 )));
303 return;
304 }
305 let location = driver_location(target);
308 let dev_ptr = inner.ptr.as_ptr() as cudarc::driver::sys::CUdeviceptr;
309 let r = crate::sys::cuda_driver::mem_prefetch_async_v2(
310 dev_ptr,
311 inner.bytes,
312 location,
313 0,
314 std::ptr::null_mut(),
315 );
316 let _ = reply.send(r);
317 }
318 ManagedMsg::AdviseF32 { mem, advice, reply } => {
319 let Some(inner) = mem.inner.as_ref() else {
320 let _ = reply.send(Err(GpuError::Unrecoverable(
321 "AdviseF32: invalid ManagedRef".into(),
322 )));
323 return;
324 };
325 if !inner.system_alive.load(Ordering::Acquire) {
326 let _ = reply.send(Err(GpuError::Unrecoverable(
327 "AdviseF32: allocator stopped".into(),
328 )));
329 return;
330 }
331 let dev_ptr = inner.ptr.as_ptr() as cudarc::driver::sys::CUdeviceptr;
332 let r = crate::memory::advise::advise(dev_ptr, inner.bytes, advice);
333 let _ = reply.send(r);
334 }
335 ManagedMsg::Stats { reply } => {
336 let _ = reply.send(self.stats);
337 }
338 }
339 }
340
341 async fn post_stop(&mut self, _ctx: &mut Context<Self>) {
342 self.system_alive.store(false, Ordering::Release);
352 }
353}
354
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_config::Config;
    use atomr_core::actor::ActorSystem;
    use std::time::Duration;

    /// Longest time any test waits for the actor to answer.
    const REPLY_TIMEOUT: Duration = Duration::from_secs(2);

    /// Accepts every outcome that shows the message reached the actor and was
    /// answered: success, `Unrecoverable`, or `LibraryError`. Anything else
    /// fails the test.
    fn assert_routed(outcome: Result<(), GpuError>) {
        let acceptable = matches!(
            outcome,
            Ok(()) | Err(GpuError::Unrecoverable(_)) | Err(GpuError::LibraryError { .. })
        );
        if !acceptable {
            panic!("unexpected: {outcome:?}");
        }
    }

    /// Builds a `ManagedRef` around a leaked one-byte heap cell instead of
    /// CUDA memory, plus the liveness flag that controls it.
    ///
    /// Callers must clear the flag before dropping the handle so that `Drop`
    /// does not pass the non-CUDA pointer to `cudaFree`.
    fn synthetic_managed_ref<T>(elements: usize) -> (ManagedRef<T>, Arc<AtomicBool>) {
        let alive = Arc::new(AtomicBool::new(true));
        // Leaked on purpose: the pointer has to stay valid for the handle's
        // whole life and is never released.
        let backing: &'static mut u8 = Box::leak(Box::new(0u8));
        let record = ManagedRefInner {
            ptr: NonNull::from(backing),
            bytes: elements * std::mem::size_of::<T>(),
            elements,
            system_alive: Arc::clone(&alive),
        };
        let handle = ManagedRef {
            inner: Some(Arc::new(record)),
            _marker: PhantomData,
        };
        (handle, alive)
    }

    /// The actor answers an allocation request and a stats request within the
    /// timeout; the allocation itself may succeed or fail.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn allocate_replies_then_invalidate_on_post_stop() {
        let sys = ActorSystem::create("managed-test", Config::empty())
            .await
            .unwrap();
        let mgr = sys
            .actor_of(ManagedAllocatorActor::props(), "managed")
            .unwrap();

        let (alloc_tx, alloc_rx) = oneshot::channel();
        mgr.tell(ManagedMsg::AllocateManagedF32 {
            len: 1024,
            flags: ManagedFlags::AttachGlobal,
            reply: alloc_tx,
        });
        let alloc_reply = tokio::time::timeout(REPLY_TIMEOUT, alloc_rx)
            .await
            .unwrap()
            .unwrap();
        // Only the arrival of a reply matters; its content is not inspected.
        let _ = alloc_reply;

        let (stats_tx, stats_rx) = oneshot::channel();
        mgr.tell(ManagedMsg::Stats { reply: stats_tx });
        let _stats = tokio::time::timeout(REPLY_TIMEOUT, stats_rx)
            .await
            .unwrap()
            .unwrap();

        sys.terminate().await;
    }

    /// A prefetch request for a synthetic handle is answered by the actor.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn prefetch_message_routes_through_actor() {
        let sys = ActorSystem::create("managed-prefetch-test", Config::empty())
            .await
            .unwrap();
        let mgr = sys
            .actor_of(ManagedAllocatorActor::props(), "managed")
            .unwrap();

        let (mref, alive) = synthetic_managed_ref::<f32>(64);
        let (reply_tx, reply_rx) = oneshot::channel();
        mgr.tell(ManagedMsg::PrefetchF32 {
            mem: mref.clone(),
            target: PrefetchTarget::Cpu,
            reply: reply_tx,
        });
        let outcome = tokio::time::timeout(REPLY_TIMEOUT, reply_rx)
            .await
            .unwrap()
            .unwrap();
        assert_routed(outcome);

        // Clear the flag first so dropping the handle skips `cudaFree`.
        alive.store(false, Ordering::Release);
        drop(mref);
        sys.terminate().await;
    }

    /// An advise request for a synthetic handle is answered by the actor.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn advise_message_routes_through_actor() {
        let sys = ActorSystem::create("managed-advise-test", Config::empty())
            .await
            .unwrap();
        let mgr = sys
            .actor_of(ManagedAllocatorActor::props(), "managed")
            .unwrap();

        let (mref, alive) = synthetic_managed_ref::<f32>(64);
        let (reply_tx, reply_rx) = oneshot::channel();
        mgr.tell(ManagedMsg::AdviseF32 {
            mem: mref.clone(),
            advice: super::super::advise::MemAdvice::SetReadMostly,
            reply: reply_tx,
        });
        let outcome = tokio::time::timeout(REPLY_TIMEOUT, reply_rx)
            .await
            .unwrap()
            .unwrap();
        assert_routed(outcome);

        // Clear the flag first so dropping the handle skips `cudaFree`.
        alive.store(false, Ordering::Release);
        drop(mref);
        sys.terminate().await;
    }
}