atomr_accel_cuda/placement/sharded.rs
1//! `atomr-cluster-sharding` adapter for [`super::PlacementActor`].
2//!
3//! Bridges the GPU-fleet placement layer to atomr's typed sharding
4//! primitives. Callers wrap their [`crate::device::DeviceMsg`] with a
//! [`RoutedDeviceMsg`] envelope (fields `entity_id` and `msg`), and the adapter
6//! exposes an [`EntityRef<DeviceExtractor>`] whose `tell` routes the
7//! message to the device owning that entity id.
8//!
9//! This enables:
10//! - Cluster-wide shard handoff if/when a remote forwarder is wired in
11//! via [`ShardRegion::set_remote_forwarder`].
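//! - Consistent placement: identical `entity_id`s always land on the
//!   same device, even across restarts of the calling code, provided the
//!   device list (length and order) is unchanged and the binary is built
//!   with the same Rust toolchain (`DefaultHasher`'s algorithm is not
//!   guaranteed stable across Rust releases).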
14//!
//! The current implementation uses a simple `DefaultHasher`-mod-N
//! routing policy (no live-load awareness, and not a consistent-hash
//! ring: changing N remaps most entities). A follow-up can install a
//! custom [`ShardCoordinator`]-driven allocation strategy that polls
//! the underlying [`super::PlacementActor`]'s load snapshot.
19
20use std::collections::hash_map::DefaultHasher;
21use std::hash::{Hash, Hasher};
22use std::sync::Arc;
23
24use atomr_cluster_sharding::{EntityRef, MessageExtractor, ShardCoordinator, ShardRegion};
25use atomr_core::actor::ActorRef;
26
27use crate::device::DeviceMsg;
28
29/// Envelope used by the sharding adapter. Wraps an underlying
30/// `DeviceMsg` with the entity id the caller wants the message to be
31/// routed by. The adapter's `MessageExtractor` reads `entity_id` to
32/// pick the destination shard / device.
33pub struct RoutedDeviceMsg {
34 pub entity_id: String,
35 pub msg: DeviceMsg,
36}
37
/// [`MessageExtractor`] for [`RoutedDeviceMsg`]: buckets each message's
/// `entity_id` into one of `shard_count` shards by hashing it.
pub struct DeviceExtractor {
    /// Number of shard buckets; never zero (see [`DeviceExtractor::new`]).
    shard_count: usize,
}

impl DeviceExtractor {
    /// Create an extractor with `shard_count` buckets.
    ///
    /// A `shard_count` of zero is clamped to one so the modulo performed
    /// when computing shard ids can never divide by zero.
    pub fn new(shard_count: usize) -> Self {
        let shard_count = if shard_count == 0 { 1 } else { shard_count };
        Self { shard_count }
    }
}
51
52impl MessageExtractor for DeviceExtractor {
53 type Message = RoutedDeviceMsg;
54
55 fn entity_id(&self, message: &Self::Message) -> String {
56 message.entity_id.clone()
57 }
58
59 fn shard_id(&self, message: &Self::Message) -> String {
60 let mut h = DefaultHasher::new();
61 message.entity_id.hash(&mut h);
62 let n = h.finish() as usize % self.shard_count;
63 format!("shard-{n}")
64 }
65}
66
67/// Adapter that publishes a [`ShardRegion<DeviceExtractor>`] backed by
68/// a fixed pool of pre-spawned [`DeviceActor`](crate::device::DeviceActor)
69/// refs. Each shard maps to one device via `shard_index % devices.len()`,
70/// so identical `entity_id`s always reach the same device.
71pub struct PlacementShardingAdapter {
72 region: Arc<ShardRegion<DeviceExtractor>>,
73}
74
75impl PlacementShardingAdapter {
76 /// Build the adapter from a fleet of device refs. `region_id` is
77 /// the cluster-visible name of this region (mirrors akka.net's
78 /// type-name parameter to `ClusterSharding.Start(...)`).
79 ///
80 /// `shard_count` controls the routing granularity; larger values
81 /// distribute consecutive entity ids more evenly. Defaults to the
82 /// number of devices when zero.
83 pub fn start(
84 region_id: impl Into<String>,
85 devices: Vec<ActorRef<DeviceMsg>>,
86 shard_count: usize,
87 ) -> Self {
88 let n_devices = devices.len().max(1);
89 let n_shards = if shard_count == 0 {
90 n_devices
91 } else {
92 shard_count
93 };
94 let extractor = Arc::new(DeviceExtractor::new(n_shards));
95 let coord = Arc::new(ShardCoordinator::new());
96 // Devices captured in a shared Arc so the handler closures
97 // (one per shard) all see the same fleet.
98 let devices = Arc::new(devices);
99 let devices_for_factory = devices.clone();
100 let region = ShardRegion::new(
101 region_id,
102 extractor,
103 coord,
104 Arc::new(move || {
105 // Each shard creates its own EntityHandler; capture
106 // the shared device pool so all handlers route into
107 // the same fleet. Dispatch is consistent-hash by
108 // entity_id mod n_devices.
109 let devices = devices_for_factory.clone();
110 Box::new(move |entity_id: &str, msg: RoutedDeviceMsg| {
111 if devices.is_empty() {
112 return;
113 }
114 let mut h = DefaultHasher::new();
115 entity_id.hash(&mut h);
116 let idx = (h.finish() as usize) % devices.len();
117 devices[idx].tell(msg.msg);
118 })
119 }),
120 );
121 Self { region }
122 }
123
124 /// Build a typed handle to a particular entity.
125 pub fn entity(&self, entity_id: impl Into<String>) -> EntityRef<DeviceExtractor> {
126 EntityRef::new(self.region.clone(), entity_id.into())
127 }
128
129 /// Direct access to the underlying [`ShardRegion`] — useful for
130 /// installing a remote forwarder or inspecting shard counts.
131 pub fn region(&self) -> Arc<ShardRegion<DeviceExtractor>> {
132 self.region.clone()
133 }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::{DeviceActor, DeviceConfig};
    use atomr_config::Config;
    use atomr_core::actor::ActorSystem;
    use std::time::Duration;
    use tokio::sync::oneshot;

    /// Builds an envelope for `entity_id`. The wrapped `Allocate` is only a
    /// payload for routing tests; its reply channel is never read.
    fn envelope(entity_id: &str) -> RoutedDeviceMsg {
        let (tx, _rx) = oneshot::channel();
        RoutedDeviceMsg {
            entity_id: entity_id.into(),
            msg: DeviceMsg::Allocate { len: 16, reply: tx },
        }
    }

    #[test]
    fn extractor_is_deterministic_and_in_range() {
        let ex = DeviceExtractor::new(16);
        for id in ["user-1", "user-42", ""] {
            let first = ex.shard_id(&envelope(id));
            let second = ex.shard_id(&envelope(id));
            assert_eq!(first, second, "shard id must be stable for {id:?}");
            let bucket: usize = first
                .strip_prefix("shard-")
                .expect("shard ids carry a `shard-` prefix")
                .parse()
                .expect("shard ids end in a number");
            assert!(bucket < 16, "bucket {bucket} out of range for {id:?}");
            assert_eq!(ex.entity_id(&envelope(id)), id);
        }
    }

    #[test]
    fn zero_shard_count_is_clamped_to_one() {
        let ex = DeviceExtractor::new(0);
        assert_eq!(ex.shard_id(&envelope("anything")), "shard-0");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn entity_ref_routes_to_one_of_the_devices() {
        let sys = ActorSystem::create("sharding-adapter", Config::empty())
            .await
            .unwrap();
        let d0 = sys
            .actor_of(DeviceActor::props(DeviceConfig::mock(0)), "d0")
            .unwrap();
        let d1 = sys
            .actor_of(DeviceActor::props(DeviceConfig::mock(1)), "d1")
            .unwrap();
        let adapter = PlacementShardingAdapter::start("gpu", vec![d0, d1], 16);

        // Determinism of the hash itself is covered by the extractor tests
        // above; this test checks end-to-end delivery through the region.
        let entity = adapter.entity("user-42");
        let (tx, rx) = oneshot::channel();
        entity.tell(RoutedDeviceMsg {
            entity_id: "user-42".into(),
            msg: DeviceMsg::Allocate { len: 16, reply: tx },
        });
        // We don't verify which device served it — only that the route
        // delivered (the reply arrives, even if it's an
        // Unrecoverable from mock mode).
        let _ = tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .expect("Allocate reply should arrive within timeout");

        sys.terminate().await;
    }
}