atomr_accel_cuda/memory/prefetch.rs
1//! `cuMemPrefetchAsync` wrapper.
2//!
3//! Sits next to [`super::managed`] but is callable independently —
4//! callers that have a raw `CUdeviceptr` (e.g. from a custom
5//! allocator or an IPC-opened handle) can still issue prefetch hints.
6//!
7//! Use [`super::managed::PrefetchTarget`] to describe the destination.
8
9use std::sync::Arc;
10
11use cudarc::driver::sys as driver_sys;
12use cudarc::driver::CudaStream;
13
14use crate::error::GpuError;
15use crate::sys::cuda_driver;
16
17use super::managed::PrefetchTarget;
18
19/// Prefetch the byte range `[dev_ptr .. dev_ptr+bytes)` to `target`,
20/// issued onto `stream`. Wraps `cuMemPrefetchAsync_v2`.
21///
22/// Returns `Unrecoverable` on hosts where `libcuda.so` isn't loadable.
23pub fn prefetch_async(
24 dev_ptr: driver_sys::CUdeviceptr,
25 bytes: usize,
26 target: PrefetchTarget,
27 stream: &Arc<CudaStream>,
28) -> Result<(), GpuError> {
29 // CUDA 13.0+'s `CUmemLocation` uses an anonymous union for the
30 // location id; zero-init and set type only — concrete location
31 // pinning is wired in a follow-up PR.
32 let location = unsafe {
33 let mut loc: driver_sys::CUmemLocation = std::mem::zeroed();
34 loc.type_ = match target {
35 PrefetchTarget::Device(_) => driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE,
36 PrefetchTarget::Cpu => driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_HOST,
37 };
38 loc
39 };
40 let _ = target;
41 cuda_driver::mem_prefetch_async_v2(dev_ptr, bytes, location, 0, stream.cu_stream())
42}
43
#[cfg(test)]
mod tests {
    use super::*;

    // Phase 3 mock-mode smoke test.
    //
    // This test calls the underlying driver shim
    // `cuda_driver::mem_prefetch_async_v2` directly, with a null device
    // pointer, zero length and null stream. It does not go through the
    // `prefetch_async` wrapper, presumably because building an
    // `Arc<CudaStream>` needs a live device.
    //
    // The only property checked is that the call surfaces a typed `Result`
    // instead of panicking on hosts without a GPU. The managed-memory path
    // itself is exercised by `memory::managed::tests`, which goes through
    // `ManagedAllocatorActor`.

    #[test]
    fn prefetch_async_returns_typed_error_on_no_driver() {
        // SAFETY: `CUmemLocation` is a plain-old-data FFI struct for which
        // all-zero bytes is a valid value; only `type_` is set afterwards.
        let mut host_location: driver_sys::CUmemLocation = unsafe { std::mem::zeroed() };
        host_location.type_ = driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_HOST;

        let outcome =
            cuda_driver::mem_prefetch_async_v2(0, 0, host_location, 0, std::ptr::null_mut());

        // Tolerated outcomes:
        // - `Unrecoverable`: libcuda.so could not be loaded (no driver);
        // - `LibraryError`: a driver is present and rejected the null args;
        // - `Ok(())`: also accepted, e.g. if the driver treats the
        //   zero-length request as a no-op.
        assert!(
            matches!(
                outcome,
                Ok(()) | Err(GpuError::Unrecoverable(_)) | Err(GpuError::LibraryError { .. })
            ),
            "unexpected: {outcome:?}"
        );
    }
}