Skip to main content

atomr_accel_cuda/memory/
prefetch.rs

//! `cuMemPrefetchAsync_v2` wrapper.
//!
//! Lives next to [`super::managed`] but can be called independently:
//! callers holding a raw `CUdeviceptr` (e.g. from a custom allocator or
//! an IPC-opened handle) can still issue prefetch hints.
//!
//! Use [`super::managed::PrefetchTarget`] to describe the destination.
8
9use std::sync::Arc;
10
11use cudarc::driver::sys as driver_sys;
12use cudarc::driver::CudaStream;
13
14use crate::error::GpuError;
15use crate::sys::cuda_driver;
16
17use super::managed::PrefetchTarget;
18
19/// Prefetch the byte range `[dev_ptr .. dev_ptr+bytes)` to `target`,
20/// issued onto `stream`. Wraps `cuMemPrefetchAsync_v2`.
21///
22/// Returns `Unrecoverable` on hosts where `libcuda.so` isn't loadable.
23pub fn prefetch_async(
24    dev_ptr: driver_sys::CUdeviceptr,
25    bytes: usize,
26    target: PrefetchTarget,
27    stream: &Arc<CudaStream>,
28) -> Result<(), GpuError> {
29    // CUDA 13.0+'s `CUmemLocation` uses an anonymous union for the
30    // location id; zero-init and set type only — concrete location
31    // pinning is wired in a follow-up PR.
32    let location = unsafe {
33        let mut loc: driver_sys::CUmemLocation = std::mem::zeroed();
34        loc.type_ = match target {
35            PrefetchTarget::Device(_) => driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE,
36            PrefetchTarget::Cpu => driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_HOST,
37        };
38        loc
39    };
40    let _ = target;
41    cuda_driver::mem_prefetch_async_v2(dev_ptr, bytes, location, 0, stream.cu_stream())
42}
43
#[cfg(test)]
mod tests {
    use super::*;

    // Phase 3 mock-mode smoke test. It drives the raw driver helper
    // directly, not `prefetch_async`, using a null device pointer, a zero
    // length and a null stream handle. That way it needs neither a
    // `CudaStream` nor a GPU. It only checks that the call comes back as a
    // `Result` instead of panicking. The full managed-memory path is
    // exercised by `memory::managed::tests` via `ManagedAllocatorActor`.

    #[test]
    fn prefetch_async_returns_typed_error_on_no_driver() {
        // SAFETY: `CUmemLocation` is a C FFI struct. An all-zero value is
        // presumably valid, and its type is set immediately below.
        let mut host_loc: driver_sys::CUmemLocation = unsafe { std::mem::zeroed() };
        host_loc.type_ = driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_HOST;

        let outcome = cuda_driver::mem_prefetch_async_v2(0, 0, host_loc, 0, std::ptr::null_mut());

        // Three outcomes are acceptable:
        // - success;
        // - `Unrecoverable`, when libcuda.so is not loadable;
        // - `LibraryError`, when the driver is present but rejects the
        //   null arguments.
        match outcome {
            Ok(()) | Err(GpuError::Unrecoverable(_)) | Err(GpuError::LibraryError { .. }) => {}
            other => panic!("unexpected: {other:?}"),
        }
    }
}