Skip to main content

atomr_accel_cuda/memory/
advise.rs

1//! `cuMemAdvise` wrapper.
2//!
3//! Public typed enum [`MemAdvice`] mirrors the six driver-level
4//! `CU_MEM_ADVISE_*` variants. The actor surface routes
5//! `ManagedMsg::Advise` through here, but the wrapper is also callable
6//! directly for callers operating on managed allocations they own.
7
8use cudarc::driver::sys as driver_sys;
9
10use crate::error::GpuError;
11use crate::sys::cuda_driver;
12
13use super::managed::PrefetchTarget;
14
15/// Memory advisory hints. Each maps 1:1 to a `CU_MEM_ADVISE_*`
16/// variant. The `Set*` variants take a target location; the `Unset*`
17/// variants ignore it but require it to typecheck (CUDA accepts the
18/// arg either way).
19#[derive(Debug, Clone, Copy)]
20pub enum MemAdvice {
21    /// Hint that the range is read predominantly. CUDA may duplicate
22    /// pages across processors.
23    SetReadMostly,
24    UnsetReadMostly,
25    /// Pin the range to `target`'s memory.
26    SetPreferredLocation(PrefetchTarget),
27    UnsetPreferredLocation,
28    /// Advise that `target` will access the range; CUDA configures
29    /// page-table mappings accordingly.
30    SetAccessedBy(PrefetchTarget),
31    UnsetAccessedBy(PrefetchTarget),
32}
33
34impl MemAdvice {
35    fn raw(self) -> (driver_sys::CUmem_advise, driver_sys::CUmemLocation) {
36        // CUDA 13.0+'s `CUmemLocation` uses an anonymous union for the
37        // location id (`__bindgen_anon_1`) rather than a plain `id: i32`.
38        // We zero-initialize and write the type only — concrete location
39        // pinning is wired in a follow-up PR with the right layout.
40        let host_loc: driver_sys::CUmemLocation = unsafe {
41            let mut loc: driver_sys::CUmemLocation = std::mem::zeroed();
42            loc.type_ = driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_HOST;
43            loc
44        };
45        match self {
46            MemAdvice::SetReadMostly => (
47                driver_sys::CUmem_advise::CU_MEM_ADVISE_SET_READ_MOSTLY,
48                host_loc,
49            ),
50            MemAdvice::UnsetReadMostly => (
51                driver_sys::CUmem_advise::CU_MEM_ADVISE_UNSET_READ_MOSTLY,
52                host_loc,
53            ),
54            MemAdvice::SetPreferredLocation(t) => (
55                driver_sys::CUmem_advise::CU_MEM_ADVISE_SET_PREFERRED_LOCATION,
56                location_for(t),
57            ),
58            MemAdvice::UnsetPreferredLocation => (
59                driver_sys::CUmem_advise::CU_MEM_ADVISE_UNSET_PREFERRED_LOCATION,
60                host_loc,
61            ),
62            MemAdvice::SetAccessedBy(t) => (
63                driver_sys::CUmem_advise::CU_MEM_ADVISE_SET_ACCESSED_BY,
64                location_for(t),
65            ),
66            MemAdvice::UnsetAccessedBy(t) => (
67                driver_sys::CUmem_advise::CU_MEM_ADVISE_UNSET_ACCESSED_BY,
68                location_for(t),
69            ),
70        }
71    }
72}
73
74fn location_for(t: PrefetchTarget) -> driver_sys::CUmemLocation {
75    unsafe {
76        let mut loc: driver_sys::CUmemLocation = std::mem::zeroed();
77        loc.type_ = match t {
78            PrefetchTarget::Device(_) => driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE,
79            PrefetchTarget::Cpu => driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_HOST,
80        };
81        loc
82    }
83}
84
85/// Apply `advice` to the byte range `[dev_ptr .. dev_ptr+bytes)`.
86/// Wraps `cuMemAdvise_v2`.
87pub fn advise(
88    dev_ptr: driver_sys::CUdeviceptr,
89    bytes: usize,
90    advice: MemAdvice,
91) -> Result<(), GpuError> {
92    let (a, loc) = advice.raw();
93    cuda_driver::mem_advise_v2(dev_ptr, bytes, a, loc)
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant must lower to its own `CU_MEM_ADVISE_*` constant, and
    /// the target-carrying variants must produce a location whose type
    /// matches the requested target.
    #[test]
    fn mem_advice_constructs_for_each_variant() {
        let (advice, _loc) = MemAdvice::SetReadMostly.raw();
        assert!(matches!(
            advice,
            driver_sys::CUmem_advise::CU_MEM_ADVISE_SET_READ_MOSTLY
        ));

        let (advice, _loc) = MemAdvice::UnsetReadMostly.raw();
        assert!(matches!(
            advice,
            driver_sys::CUmem_advise::CU_MEM_ADVISE_UNSET_READ_MOSTLY
        ));

        let (advice, loc) = MemAdvice::SetPreferredLocation(PrefetchTarget::Device(0)).raw();
        assert!(matches!(
            advice,
            driver_sys::CUmem_advise::CU_MEM_ADVISE_SET_PREFERRED_LOCATION
        ));
        assert!(matches!(
            loc.type_,
            driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE
        ));

        let (advice, _loc) = MemAdvice::UnsetPreferredLocation.raw();
        assert!(matches!(
            advice,
            driver_sys::CUmem_advise::CU_MEM_ADVISE_UNSET_PREFERRED_LOCATION
        ));

        let (advice, loc) = MemAdvice::SetAccessedBy(PrefetchTarget::Device(1)).raw();
        assert!(matches!(
            advice,
            driver_sys::CUmem_advise::CU_MEM_ADVISE_SET_ACCESSED_BY
        ));
        assert!(matches!(
            loc.type_,
            driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE
        ));

        let (advice, loc) = MemAdvice::UnsetAccessedBy(PrefetchTarget::Cpu).raw();
        assert!(matches!(
            advice,
            driver_sys::CUmem_advise::CU_MEM_ADVISE_UNSET_ACCESSED_BY
        ));
        assert!(matches!(
            loc.type_,
            driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_HOST
        ));
    }

    /// `advise` must never panic: it either succeeds or surfaces one of the
    /// typed `GpuError` variants. `Ok` is accepted as well, so the test is
    /// valid whether or not a driver is present.
    #[test]
    fn advise_returns_typed_error_on_no_driver() {
        let r = advise(0, 0, MemAdvice::SetReadMostly);
        match r {
            Ok(()) => {}
            Err(GpuError::Unrecoverable(_)) => {}
            Err(GpuError::LibraryError { .. }) => {}
            other => panic!("unexpected: {other:?}"),
        }
    }
}