atomr_accel_cuda/placement/
mod.rs1#[cfg(feature = "cluster")]
9pub mod sharded;
10
11use std::sync::Arc;
12use std::time::Duration;
13
14use async_trait::async_trait;
15use atomr_core::actor::{Actor, ActorRef, Context, Props};
16use parking_lot::Mutex;
17use tokio::sync::oneshot;
18
19use crate::device::{DeviceLoad, DeviceMsg};
20use crate::error::GpuError;
21use crate::stream::Priority;
22
23#[derive(Debug, Clone, Copy, Default)]
24pub struct PlacementHints {
25 pub min_free_bytes: usize,
26 pub min_compute_cap: Option<(i32, i32)>,
27 pub priority: Option<Priority>,
28}
29
30pub struct DeviceChoice {
31 pub device_id: u32,
32 pub device: ActorRef<DeviceMsg>,
33 pub load: DeviceLoad,
34}
35
36pub trait PlacementPolicy: Send + Sync + 'static {
37 fn choose(&self, hints: &PlacementHints, candidates: &[(u32, &DeviceLoad)]) -> Option<u32>;
38}
39
40pub struct RoundRobinPolicy {
42 cursor: Mutex<usize>,
43}
44
45impl Default for RoundRobinPolicy {
46 fn default() -> Self {
47 Self {
48 cursor: Mutex::new(0),
49 }
50 }
51}
52
53impl PlacementPolicy for RoundRobinPolicy {
54 fn choose(&self, _hints: &PlacementHints, candidates: &[(u32, &DeviceLoad)]) -> Option<u32> {
55 if candidates.is_empty() {
56 return None;
57 }
58 let mut c = self.cursor.lock();
59 let idx = *c % candidates.len();
60 *c = c.wrapping_add(1);
61 Some(candidates[idx].0)
62 }
63}
64
/// Policy that picks the eligible device with the smallest combined queue
/// depth and active stream count.
pub struct LeastLoadedPolicy;
68
69impl PlacementPolicy for LeastLoadedPolicy {
70 fn choose(&self, hints: &PlacementHints, candidates: &[(u32, &DeviceLoad)]) -> Option<u32> {
71 let mut best: Option<(u32, u64)> = None;
72 for (id, load) in candidates {
73 if load.free_bytes < hints.min_free_bytes {
74 continue;
75 }
76 if let Some((mj, mn)) = hints.min_compute_cap {
77 if load.compute_cap.0 < mj || (load.compute_cap.0 == mj && load.compute_cap.1 < mn)
78 {
79 continue;
80 }
81 }
82 let score = load.queue_depth as u64 + load.active_streams as u64;
83 match best {
84 None => best = Some((*id, score)),
85 Some((_, s)) if score < s => best = Some((*id, score)),
86 _ => {}
87 }
88 }
89 best.map(|(id, _)| id)
90 }
91}
92
93pub enum PlacementMsg {
94 Pick {
95 hints: PlacementHints,
96 reply: oneshot::Sender<Result<DeviceChoice, GpuError>>,
97 },
98 PollStats,
100 StatsUpdate { slot: usize, load: DeviceLoad },
103}
104
105pub struct PlacementActor {
106 devices: Vec<(u32, ActorRef<DeviceMsg>)>,
107 loads: Vec<DeviceLoad>,
108 policy: Arc<dyn PlacementPolicy>,
109 poll_interval: Duration,
110}
111
112impl PlacementActor {
113 pub fn props(
114 devices: Vec<(u32, ActorRef<DeviceMsg>)>,
115 policy: Arc<dyn PlacementPolicy>,
116 ) -> Props<Self> {
117 let n = devices.len();
118 Props::create(move || PlacementActor {
119 devices: devices.clone(),
120 loads: (0..n)
121 .map(|_| DeviceLoad {
122 free_bytes: 0,
123 total_bytes: 0,
124 active_streams: 0,
125 queue_depth: 0,
126 compute_cap: (0, 0),
127 })
128 .collect(),
129 policy: policy.clone(),
130 poll_interval: Duration::from_millis(250),
131 })
132 }
133
134 fn schedule_poll(&self, ctx: &Context<Self>) {
135 let self_ref = ctx.self_ref().clone();
136 let interval = self.poll_interval;
137 tokio::spawn(async move {
138 tokio::time::sleep(interval).await;
139 self_ref.tell(PlacementMsg::PollStats);
140 });
141 }
142}
143
144#[async_trait]
145impl Actor for PlacementActor {
146 type Msg = PlacementMsg;
147
148 async fn pre_start(&mut self, ctx: &mut Context<Self>) {
149 self.schedule_poll(ctx);
150 }
151
152 async fn handle(&mut self, ctx: &mut Context<Self>, msg: PlacementMsg) {
153 match msg {
154 PlacementMsg::Pick { hints, reply } => {
155 let candidates: Vec<(u32, &DeviceLoad)> = self
156 .devices
157 .iter()
158 .zip(self.loads.iter())
159 .map(|((id, _), load)| (*id, load))
160 .collect();
161 match self.policy.choose(&hints, &candidates) {
162 None => {
163 let _ = reply.send(Err(GpuError::Unrecoverable(
164 "placement: no eligible device".into(),
165 )));
166 }
167 Some(id) => {
168 let pos = self.devices.iter().position(|(d, _)| *d == id).unwrap();
169 let _ = reply.send(Ok(DeviceChoice {
170 device_id: id,
171 device: self.devices[pos].1.clone(),
172 load: self.loads[pos],
173 }));
174 }
175 }
176 }
177 PlacementMsg::PollStats => {
178 let self_ref = ctx.self_ref().clone();
184 for (i, (_, dev)) in self.devices.iter().enumerate() {
185 let (tx, rx) = oneshot::channel();
186 dev.tell(DeviceMsg::Stats { reply: tx });
187 let self_ref2 = self_ref.clone();
188 tokio::spawn(async move {
189 if let Ok(load) = rx.await {
190 self_ref2.tell(PlacementMsg::StatsUpdate { slot: i, load });
191 }
192 });
193 }
194 self.schedule_poll(ctx);
195 }
196 PlacementMsg::StatsUpdate { slot, load } => {
197 if let Some(s) = self.loads.get_mut(slot) {
198 *s = load;
199 }
200 }
201 }
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_config::Config;
    use atomr_core::actor::ActorSystem;

    use crate::device::{DeviceActor, DeviceConfig};

    /// Four consecutive picks over two mock devices must alternate 0, 1, 0, 1.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn round_robin_picks_alternates() {
        let system = ActorSystem::create("placement-rr", Config::empty())
            .await
            .unwrap();
        let dev0 = system
            .actor_of(DeviceActor::props(DeviceConfig::mock(0)), "d0")
            .unwrap();
        let dev1 = system
            .actor_of(DeviceActor::props(DeviceConfig::mock(1)), "d1")
            .unwrap();
        let props = PlacementActor::props(
            vec![(0, dev0), (1, dev1)],
            Arc::new(RoundRobinPolicy::default()),
        );
        let placement = system.actor_of(props, "placement").unwrap();

        let mut chosen = Vec::new();
        for _ in 0..4 {
            let (reply, answer) = oneshot::channel();
            placement.tell(PlacementMsg::Pick {
                hints: PlacementHints::default(),
                reply,
            });
            // Unwrap layers: timeout elapsed, reply sender dropped, placement error.
            let choice = tokio::time::timeout(Duration::from_secs(2), answer)
                .await
                .unwrap()
                .unwrap()
                .unwrap();
            chosen.push(choice.device_id);
        }
        assert_eq!(chosen, vec![0, 1, 0, 1]);

        system.terminate().await;
    }
}