Skip to main content

atomr_accel_cuda/stream/
pool.rs

//! `PooledAllocator` (§5.7) — bounded round-robin across N streams.
//!
//! Each `KernelActor` constructed with this allocator picks the next
//! stream from a fixed-size pool. Trade-off vs. `PerActorAllocator`:
//! capped stream count (less driver overhead, less memory) at the
//! cost of cross-actor stream contention.

8use std::sync::Arc;
9
10use cudarc::driver::{CudaContext, CudaStream};
11
12use super::{ActorHints, StreamAllocator};
13
14pub struct PooledAllocator {
15    pool: Vec<Arc<CudaStream>>,
16    cursor: parking_lot::Mutex<usize>,
17}
18
19impl PooledAllocator {
20    /// Construct a pool from a vector of pre-existing streams. All
21    /// streams must belong to the same context.
22    pub fn new(streams: Vec<Arc<CudaStream>>) -> Self {
23        assert!(
24            !streams.is_empty(),
25            "PooledAllocator requires at least one stream"
26        );
27        Self {
28            pool: streams,
29            cursor: parking_lot::Mutex::new(0),
30        }
31    }
32
33    /// Construct a pool by minting `count` fresh streams on `ctx`.
34    /// Panics with the `ContextPoisoned` tag if any stream creation
35    /// fails so the parent supervisor can restart the actor.
36    pub fn with_size(ctx: &Arc<CudaContext>, count: usize) -> Self {
37        assert!(count > 0, "PooledAllocator requires count >= 1");
38        let mut streams = Vec::with_capacity(count);
39        for _ in 0..count {
40            let s = ctx
41                .new_stream()
42                .unwrap_or_else(|e| panic!("ContextPoisoned: new_stream: {e}"));
43            streams.push(s);
44        }
45        Self::new(streams)
46    }
47
48    pub fn size(&self) -> usize {
49        self.pool.len()
50    }
51}
52
53impl StreamAllocator for PooledAllocator {
54    fn acquire(&self, _hints: ActorHints) -> Arc<CudaStream> {
55        let mut cur = self.cursor.lock();
56        let idx = *cur % self.pool.len();
57        *cur = cur.wrapping_add(1);
58        self.pool[idx].clone()
59    }
60}