Skip to main content

atomr_accel_cuda/placement/
mod.rs

//! `PlacementActor` — picks the best-fit `DeviceActor` for each
//! request based on a configurable [`PlacementPolicy`].
//!
//! Sends each device a `DeviceMsg::Stats` request periodically (default
//! 250 ms) to maintain a cached load snapshot. Callers send
//! `PlacementMsg::Pick` to receive a `DeviceChoice` holding the selected
//! `ActorRef<DeviceMsg>`.
7
8#[cfg(feature = "cluster")]
9pub mod sharded;
10
11use std::sync::Arc;
12use std::time::Duration;
13
14use async_trait::async_trait;
15use atomr_core::actor::{Actor, ActorRef, Context, Props};
16use parking_lot::Mutex;
17use tokio::sync::oneshot;
18
19use crate::device::{DeviceLoad, DeviceMsg};
20use crate::error::GpuError;
21use crate::stream::Priority;
22
23#[derive(Debug, Clone, Copy, Default)]
24pub struct PlacementHints {
25    pub min_free_bytes: usize,
26    pub min_compute_cap: Option<(i32, i32)>,
27    pub priority: Option<Priority>,
28}
29
30pub struct DeviceChoice {
31    pub device_id: u32,
32    pub device: ActorRef<DeviceMsg>,
33    pub load: DeviceLoad,
34}
35
36pub trait PlacementPolicy: Send + Sync + 'static {
37    fn choose(&self, hints: &PlacementHints, candidates: &[(u32, &DeviceLoad)]) -> Option<u32>;
38}
39
40/// Round-robin policy. Ignores `hints`.
41pub struct RoundRobinPolicy {
42    cursor: Mutex<usize>,
43}
44
45impl Default for RoundRobinPolicy {
46    fn default() -> Self {
47        Self {
48            cursor: Mutex::new(0),
49        }
50    }
51}
52
53impl PlacementPolicy for RoundRobinPolicy {
54    fn choose(&self, _hints: &PlacementHints, candidates: &[(u32, &DeviceLoad)]) -> Option<u32> {
55        if candidates.is_empty() {
56            return None;
57        }
58        let mut c = self.cursor.lock();
59        let idx = *c % candidates.len();
60        *c = c.wrapping_add(1);
61        Some(candidates[idx].0)
62    }
63}
64
65/// Least-loaded by `queue_depth + active_streams` heuristic. Filters
66/// out devices below `min_free_bytes` / `min_compute_cap`.
67pub struct LeastLoadedPolicy;
68
69impl PlacementPolicy for LeastLoadedPolicy {
70    fn choose(&self, hints: &PlacementHints, candidates: &[(u32, &DeviceLoad)]) -> Option<u32> {
71        let mut best: Option<(u32, u64)> = None;
72        for (id, load) in candidates {
73            if load.free_bytes < hints.min_free_bytes {
74                continue;
75            }
76            if let Some((mj, mn)) = hints.min_compute_cap {
77                if load.compute_cap.0 < mj || (load.compute_cap.0 == mj && load.compute_cap.1 < mn)
78                {
79                    continue;
80                }
81            }
82            let score = load.queue_depth as u64 + load.active_streams as u64;
83            match best {
84                None => best = Some((*id, score)),
85                Some((_, s)) if score < s => best = Some((*id, score)),
86                _ => {}
87            }
88        }
89        best.map(|(id, _)| id)
90    }
91}
92
93pub enum PlacementMsg {
94    Pick {
95        hints: PlacementHints,
96        reply: oneshot::Sender<Result<DeviceChoice, GpuError>>,
97    },
98    /// Internal: timer fires the per-device stats poll.
99    PollStats,
100    /// Internal: a single device's stats reply arrived, update the
101    /// cached load snapshot.
102    StatsUpdate { slot: usize, load: DeviceLoad },
103}
104
105pub struct PlacementActor {
106    devices: Vec<(u32, ActorRef<DeviceMsg>)>,
107    loads: Vec<DeviceLoad>,
108    policy: Arc<dyn PlacementPolicy>,
109    poll_interval: Duration,
110}
111
112impl PlacementActor {
113    pub fn props(
114        devices: Vec<(u32, ActorRef<DeviceMsg>)>,
115        policy: Arc<dyn PlacementPolicy>,
116    ) -> Props<Self> {
117        let n = devices.len();
118        Props::create(move || PlacementActor {
119            devices: devices.clone(),
120            loads: (0..n)
121                .map(|_| DeviceLoad {
122                    free_bytes: 0,
123                    total_bytes: 0,
124                    active_streams: 0,
125                    queue_depth: 0,
126                    compute_cap: (0, 0),
127                })
128                .collect(),
129            policy: policy.clone(),
130            poll_interval: Duration::from_millis(250),
131        })
132    }
133
134    fn schedule_poll(&self, ctx: &Context<Self>) {
135        let self_ref = ctx.self_ref().clone();
136        let interval = self.poll_interval;
137        tokio::spawn(async move {
138            tokio::time::sleep(interval).await;
139            self_ref.tell(PlacementMsg::PollStats);
140        });
141    }
142}
143
144#[async_trait]
145impl Actor for PlacementActor {
146    type Msg = PlacementMsg;
147
148    async fn pre_start(&mut self, ctx: &mut Context<Self>) {
149        self.schedule_poll(ctx);
150    }
151
152    async fn handle(&mut self, ctx: &mut Context<Self>, msg: PlacementMsg) {
153        match msg {
154            PlacementMsg::Pick { hints, reply } => {
155                let candidates: Vec<(u32, &DeviceLoad)> = self
156                    .devices
157                    .iter()
158                    .zip(self.loads.iter())
159                    .map(|((id, _), load)| (*id, load))
160                    .collect();
161                match self.policy.choose(&hints, &candidates) {
162                    None => {
163                        let _ = reply.send(Err(GpuError::Unrecoverable(
164                            "placement: no eligible device".into(),
165                        )));
166                    }
167                    Some(id) => {
168                        let pos = self.devices.iter().position(|(d, _)| *d == id).unwrap();
169                        let _ = reply.send(Ok(DeviceChoice {
170                            device_id: id,
171                            device: self.devices[pos].1.clone(),
172                            load: self.loads[pos],
173                        }));
174                    }
175                }
176            }
177            PlacementMsg::PollStats => {
178                // Fire one Stats request per device. Each reply
179                // arrives via a tokio task that posts a
180                // `StatsUpdate { slot, load }` back to this actor's
181                // mailbox — closing the feedback loop without
182                // needing &mut self inside the async block.
183                let self_ref = ctx.self_ref().clone();
184                for (i, (_, dev)) in self.devices.iter().enumerate() {
185                    let (tx, rx) = oneshot::channel();
186                    dev.tell(DeviceMsg::Stats { reply: tx });
187                    let self_ref2 = self_ref.clone();
188                    tokio::spawn(async move {
189                        if let Ok(load) = rx.await {
190                            self_ref2.tell(PlacementMsg::StatsUpdate { slot: i, load });
191                        }
192                    });
193                }
194                self.schedule_poll(ctx);
195            }
196            PlacementMsg::StatsUpdate { slot, load } => {
197                if let Some(s) = self.loads.get_mut(slot) {
198                    *s = load;
199                }
200            }
201        }
202    }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_config::Config;
    use atomr_core::actor::ActorSystem;

    use crate::device::DeviceActor;
    use crate::device::DeviceConfig;

    /// All-zero load; used as the base for struct-update syntax below.
    fn idle() -> DeviceLoad {
        DeviceLoad {
            free_bytes: 0,
            total_bytes: 0,
            active_streams: 0,
            queue_depth: 0,
            compute_cap: (0, 0),
        }
    }

    #[test]
    fn round_robin_policy_cycles_and_ignores_empty_calls() {
        let policy = RoundRobinPolicy::default();
        let hints = PlacementHints::default();
        let (a, b) = (idle(), idle());
        // An empty candidate list yields None and does not advance the cursor.
        assert_eq!(policy.choose(&hints, &[]), None);
        let candidates = [(7, &a), (9, &b)];
        let picks: Vec<_> = (0..4).map(|_| policy.choose(&hints, &candidates)).collect();
        assert_eq!(picks, vec![Some(7), Some(9), Some(7), Some(9)]);
    }

    #[test]
    fn least_loaded_prefers_lowest_score_and_first_on_tie() {
        let policy = LeastLoadedPolicy;
        let hints = PlacementHints::default();
        let busy = DeviceLoad {
            queue_depth: 3,
            active_streams: 1,
            ..idle()
        };
        let light = DeviceLoad {
            queue_depth: 1,
            ..idle()
        };
        assert_eq!(policy.choose(&hints, &[(0, &busy), (1, &light)]), Some(1));
        let (x, y) = (idle(), idle());
        assert_eq!(policy.choose(&hints, &[(4, &x), (5, &y)]), Some(4));
        assert_eq!(policy.choose(&hints, &[]), None);
    }

    #[test]
    fn least_loaded_applies_memory_and_compute_cap_filters() {
        let policy = LeastLoadedPolicy;
        let roomy = DeviceLoad {
            free_bytes: 4096,
            queue_depth: 5,
            compute_cap: (8, 6),
            ..idle()
        };
        let cramped = DeviceLoad {
            free_bytes: 1024,
            compute_cap: (9, 0),
            ..idle()
        };

        let mem_hint = PlacementHints {
            min_free_bytes: 2048,
            ..PlacementHints::default()
        };
        assert_eq!(policy.choose(&mem_hint, &[(0, &cramped), (1, &roomy)]), Some(1));

        // (8, 6) meets an (8, 6) minimum exactly but falls below (8, 7).
        let exact_cap = PlacementHints {
            min_compute_cap: Some((8, 6)),
            ..PlacementHints::default()
        };
        assert_eq!(policy.choose(&exact_cap, &[(0, &roomy)]), Some(0));
        let cap_hint = PlacementHints {
            min_compute_cap: Some((8, 7)),
            ..PlacementHints::default()
        };
        assert_eq!(policy.choose(&cap_hint, &[(0, &roomy), (1, &cramped)]), Some(1));

        let impossible = PlacementHints {
            min_free_bytes: 8192,
            ..PlacementHints::default()
        };
        assert_eq!(policy.choose(&impossible, &[(0, &roomy), (1, &cramped)]), None);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn round_robin_picks_alternates() {
        let sys = ActorSystem::create("placement-rr", Config::empty())
            .await
            .unwrap();
        let d0 = sys
            .actor_of(DeviceActor::props(DeviceConfig::mock(0)), "d0")
            .unwrap();
        let d1 = sys
            .actor_of(DeviceActor::props(DeviceConfig::mock(1)), "d1")
            .unwrap();
        let actor = sys
            .actor_of(
                PlacementActor::props(
                    vec![(0, d0), (1, d1)],
                    Arc::new(RoundRobinPolicy::default()),
                ),
                "placement",
            )
            .unwrap();

        let mut picks = Vec::new();
        for _ in 0..4 {
            let (tx, rx) = oneshot::channel();
            actor.tell(PlacementMsg::Pick {
                hints: PlacementHints::default(),
                reply: tx,
            });
            let c = tokio::time::timeout(Duration::from_secs(2), rx)
                .await
                .unwrap()
                .unwrap()
                .unwrap();
            picks.push(c.device_id);
        }
        // Round-robin alternates: 0,1,0,1.
        assert_eq!(picks, vec![0, 1, 0, 1]);

        sys.terminate().await;
    }
}