Skip to main content

atomr_accel_cuda/kernel/blas_lt/
workspace.rs

//! cuBLASLt workspace pool — recycles per-heuristic device buffers.
//!
//! cuBLASLt's `cublasLtMatmul` takes an opaque `workspace` device
//! pointer plus a `workspaceSizeInBytes`. The required size depends on
//! the selected algorithm: the heuristic cache reports a per-algorithm
//! `workspaceSize`, and we want to avoid allocating a fresh slab on
//! every call. This pool buckets free slabs by **rounded-up size
//! class** (the next power of two ≥ the request, with a 1 KiB minimum)
//! and hands them out under a `WorkspaceLease` RAII guard that returns
//! the slab to the pool on Drop.
//!
//! The pool is intentionally a **plain struct** (not a separate
//! actor) because `BlasLtActor` already has single-threaded ownership
//! of all matmul calls; wrapping it in another actor would just add a
//! mailbox hop on the hot path. If a future phase needs cross-actor
//! sharing, we'll lift this into its own actor, structured the same way
//! as `PinnedBufferPool` (whose allocator pattern this module mirrors).
18
19use std::collections::HashMap;
20use std::sync::Arc;
21
22use cudarc::driver::{CudaSlice, CudaStream};
23use parking_lot::Mutex;
24
25use crate::error::GpuError;
26
/// Default cap on the number of free slabs retained per size class.
/// Slabs returned beyond this cap are dropped instead of pooled.
///
/// The heuristic cache can track many distinct shapes (256 by default),
/// but their workspace requests typically collapse into 2–3 size classes
/// (e.g. 4 MiB, 32 MiB, 256 MiB). When leases are used one at a time,
/// each class retains at most one slab (≈292 MiB for those three
/// classes). The hard upper bound is this cap times the sum of the class
/// sizes: 4 × 292 MiB ≈ 1.1 GiB of retained device memory for those
/// classes.
pub const DEFAULT_POOL_CAPACITY_PER_CLASS: usize = 4;
33
/// Round a workspace request up to its size class: the next power of
/// two, with a 1 KiB minimum (so `0` and `1` both map to `1024`).
///
/// Bucketing by power of two limits the long tail of unique sizes the
/// pool tracks. The result is always `>= bytes`.
///
/// If `bytes` exceeds the largest representable power of two, the
/// request itself is returned instead of overflowing. Plain
/// `next_power_of_two` would panic in debug builds and wrap to `0` in
/// release builds. Callers therefore see an allocation failure rather
/// than a silently truncated zero-byte class.
pub fn size_class(bytes: usize) -> usize {
    let floor = bytes.max(1024);
    floor.checked_next_power_of_two().unwrap_or(floor)
}
40
41/// Inner pool state — guarded by a single mutex.
42struct WorkspacePoolInner {
43    /// Free slabs grouped by size class.
44    free: HashMap<usize, Vec<Arc<CudaSlice<u8>>>>,
45    /// Maximum number of free slabs to retain per size class.
46    per_class_capacity: usize,
47    /// Sum of bytes currently held in `free`. Tracked for
48    /// observability + to feed a future high-watermark eviction.
49    bytes_pooled: usize,
50}
51
52/// Cloneable handle to the workspace pool.
53#[derive(Clone)]
54pub struct WorkspacePool {
55    inner: Arc<Mutex<WorkspacePoolInner>>,
56}
57
58impl WorkspacePool {
59    pub fn new() -> Self {
60        Self::with_capacity(DEFAULT_POOL_CAPACITY_PER_CLASS)
61    }
62
63    pub fn with_capacity(per_class: usize) -> Self {
64        Self {
65            inner: Arc::new(Mutex::new(WorkspacePoolInner {
66                free: HashMap::new(),
67                per_class_capacity: per_class.max(1),
68                bytes_pooled: 0,
69            })),
70        }
71    }
72
73    /// Acquire a workspace slab of at least `requested_bytes`. If the
74    /// pool has a free slab in the matching size class it's reused;
75    /// otherwise a fresh `CudaSlice<u8>` is allocated against `stream`.
76    ///
77    /// The returned [`WorkspaceLease`] auto-returns to the pool on
78    /// Drop. Callers should *not* manually clone the inner slice
79    /// across the boundary — the kernel envelope already keeps the
80    /// `Arc<CudaSlice<u8>>` alive for the duration of the kernel.
81    pub fn acquire(
82        &self,
83        stream: &Arc<CudaStream>,
84        requested_bytes: usize,
85    ) -> Result<WorkspaceLease, GpuError> {
86        let class = size_class(requested_bytes.max(1));
87
88        // Check the pool first.
89        let pooled = {
90            let mut g = self.inner.lock();
91            if let Some(bucket) = g.free.get_mut(&class) {
92                let popped = bucket.pop();
93                if let Some(ref s) = popped {
94                    g.bytes_pooled = g.bytes_pooled.saturating_sub(s.len());
95                }
96                popped
97            } else {
98                None
99            }
100        };
101
102        let slab = match pooled {
103            Some(s) => s,
104            None => {
105                let s = unsafe { stream.alloc::<u8>(class) }.map_err(|e| {
106                    GpuError::OutOfMemory(format!("cublaslt workspace alloc {class}B: {e}"))
107                })?;
108                Arc::new(s)
109            }
110        };
111
112        Ok(WorkspaceLease {
113            slab: Some(slab),
114            class,
115            pool: self.inner.clone(),
116        })
117    }
118
119    /// Number of slabs currently in the free list across all size
120    /// classes. Diagnostic only.
121    pub fn pooled_slabs(&self) -> usize {
122        let g = self.inner.lock();
123        g.free.values().map(|v| v.len()).sum()
124    }
125
126    /// Total bytes currently in the free list.
127    pub fn pooled_bytes(&self) -> usize {
128        self.inner.lock().bytes_pooled
129    }
130}
131
132impl Default for WorkspacePool {
133    fn default() -> Self {
134        Self::new()
135    }
136}
137
138/// RAII lease — returns the slab to the pool on Drop. Callers can
139/// take a shared reference to the inner slice for kernel launch via
140/// [`WorkspaceLease::slice`].
141pub struct WorkspaceLease {
142    slab: Option<Arc<CudaSlice<u8>>>,
143    class: usize,
144    pool: Arc<Mutex<WorkspacePoolInner>>,
145}
146
147impl WorkspaceLease {
148    /// Reference to the underlying device slice (`CudaSlice<u8>`).
149    pub fn slice(&self) -> &Arc<CudaSlice<u8>> {
150        self.slab
151            .as_ref()
152            .expect("WorkspaceLease::slice after Drop")
153    }
154
155    /// Size class (rounded-up bytes) of the leased slab.
156    pub fn size(&self) -> usize {
157        self.class
158    }
159}
160
161impl Drop for WorkspaceLease {
162    fn drop(&mut self) {
163        let Some(slab) = self.slab.take() else {
164            return;
165        };
166        let mut g = self.pool.lock();
167        let cap = g.per_class_capacity;
168        let bucket = g.free.entry(self.class).or_default();
169        if bucket.len() < cap {
170            let bytes = slab.len();
171            bucket.push(slab);
172            g.bytes_pooled = g.bytes_pooled.saturating_add(bytes);
173        }
174        // else: drop the slab; CudaSlice's Drop frees the device memory.
175    }
176}
177
#[cfg(test)]
mod tests {
    //! GPU-free tests. A real `CudaSlice<u8>` cannot be constructed without
    //! a CUDA context. These tests therefore cover `size_class` and the
    //! pool's bookkeeping (byte accounting, bucket keys, capacity clamping)
    //! by editing pool internals directly, rather than going through
    //! `acquire` / `WorkspaceLease`.
    use super::*;

    /// Seed the pool's bookkeeping for `class` without touching the GPU.
    ///
    /// This does NOT mirror the real `WorkspaceLease` Drop path. No slab is
    /// pushed and the per-class capacity is not consulted. It only adds
    /// `slab_bytes` to `bytes_pooled` and makes sure a (still empty) bucket
    /// exists for `class`.
    fn seed_free_slot(pool: &WorkspacePool, class: usize, slab_bytes: usize) {
        let mut g = pool.inner.lock();
        g.bytes_pooled = g.bytes_pooled.saturating_add(slab_bytes);
        g.free.entry(class).or_default();
    }

    #[test]
    fn size_class_rounds_up() {
        assert_eq!(size_class(1), 1024);
        assert_eq!(size_class(1024), 1024);
        assert_eq!(size_class(1025), 2048);
        assert_eq!(size_class(4 * 1024 * 1024), 4 * 1024 * 1024);
        assert_eq!(size_class(4 * 1024 * 1024 + 1), 8 * 1024 * 1024);
        assert_eq!(size_class(0), 1024);
    }

    #[test]
    fn workspace_pool_recycles() {
        // Bookkeeping only. Real recycling needs real CudaSlices (i.e. a
        // GPU). This checks size_class + pooled_bytes accounting and that
        // the bucket map keys distinct classes separately.
        let pool = WorkspacePool::with_capacity(2);
        assert_eq!(pool.pooled_slabs(), 0);
        assert_eq!(pool.pooled_bytes(), 0);

        let small = 4 * 1024 * 1024;
        let large = 32 * 1024 * 1024;

        seed_free_slot(&pool, size_class(small), small);
        assert_eq!(pool.pooled_bytes(), small);

        seed_free_slot(&pool, size_class(large), large);
        assert_eq!(pool.pooled_bytes(), small + large);

        // Two distinct size classes, so both bucket keys exist. The guard is
        // scoped so the (non-reentrant) lock is free for `pooled_slabs`.
        {
            let g = pool.inner.lock();
            assert!(g.free.contains_key(&size_class(small)));
            assert!(g.free.contains_key(&size_class(large)));
        }

        // No real slabs were seeded, so the free lists themselves are empty.
        assert_eq!(pool.pooled_slabs(), 0);
    }

    #[test]
    fn pool_capacity_clamps_to_one() {
        let pool = WorkspacePool::with_capacity(0);
        assert_eq!(pool.inner.lock().per_class_capacity, 1);
    }

    #[test]
    fn default_pool_uses_default_capacity() {
        let pool = WorkspacePool::default();
        assert_eq!(
            pool.inner.lock().per_class_capacity,
            DEFAULT_POOL_CAPACITY_PER_CLASS
        );
    }
}