Skip to main content

atomr_accel_cuda/placement/
sharded.rs

1//! `atomr-cluster-sharding` adapter for [`super::PlacementActor`].
2//!
//! Bridges the GPU-fleet placement layer to atomr's typed sharding
//! primitives. Callers wrap their [`crate::device::DeviceMsg`] in a
//! [`RoutedDeviceMsg`] envelope (`{ entity_id, msg }`), and the adapter
//! exposes an [`EntityRef<DeviceExtractor>`] whose `tell` routes the
//! message to the device owning that entity id.
8//!
9//! This enables:
10//! - Cluster-wide shard handoff if/when a remote forwarder is wired in
11//!   via [`ShardRegion::set_remote_forwarder`].
12//! - Consistent placement: identical `entity_id`s always land on the
13//!   same device, even across restarts of the calling code.
14//!
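//! The current implementation uses a simple hash-mod-N routing policy
//! built on the standard library's `DefaultHasher` (no live-load
//! awareness). `DefaultHasher`'s algorithm is unspecified and may change
//! between Rust releases, so placements are only stable for binaries
//! built with the same toolchain. A follow-up can install a custom
//! [`ShardCoordinator`]-driven allocation strategy that polls the
//! underlying [`super::PlacementActor`]'s load snapshot.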
19
20use std::collections::hash_map::DefaultHasher;
21use std::hash::{Hash, Hasher};
22use std::sync::Arc;
23
24use atomr_cluster_sharding::{EntityRef, MessageExtractor, ShardCoordinator, ShardRegion};
25use atomr_core::actor::ActorRef;
26
27use crate::device::DeviceMsg;
28
29/// Envelope used by the sharding adapter. Wraps an underlying
30/// `DeviceMsg` with the entity id the caller wants the message to be
31/// routed by. The adapter's `MessageExtractor` reads `entity_id` to
32/// pick the destination shard / device.
33pub struct RoutedDeviceMsg {
34    pub entity_id: String,
35    pub msg: DeviceMsg,
36}
37
/// `MessageExtractor` implementation for [`RoutedDeviceMsg`] that buckets
/// each `entity_id` into one of `shard_count` shards by hashing it.
pub struct DeviceExtractor {
    /// Number of shards entity ids are spread over; `new` keeps it at one or more.
    shard_count: usize,
}

impl DeviceExtractor {
    /// Creates an extractor that spreads entities over `shard_count` shards.
    ///
    /// A request for zero shards is bumped up to a single shard, so the
    /// stored count is never zero.
    pub fn new(shard_count: usize) -> Self {
        let shard_count = if shard_count == 0 { 1 } else { shard_count };
        Self { shard_count }
    }
}
51
52impl MessageExtractor for DeviceExtractor {
53    type Message = RoutedDeviceMsg;
54
55    fn entity_id(&self, message: &Self::Message) -> String {
56        message.entity_id.clone()
57    }
58
59    fn shard_id(&self, message: &Self::Message) -> String {
60        let mut h = DefaultHasher::new();
61        message.entity_id.hash(&mut h);
62        let n = h.finish() as usize % self.shard_count;
63        format!("shard-{n}")
64    }
65}
66
67/// Adapter that publishes a [`ShardRegion<DeviceExtractor>`] backed by
68/// a fixed pool of pre-spawned [`DeviceActor`](crate::device::DeviceActor)
69/// refs. Each shard maps to one device via `shard_index % devices.len()`,
70/// so identical `entity_id`s always reach the same device.
71pub struct PlacementShardingAdapter {
72    region: Arc<ShardRegion<DeviceExtractor>>,
73}
74
75impl PlacementShardingAdapter {
76    /// Build the adapter from a fleet of device refs. `region_id` is
77    /// the cluster-visible name of this region (mirrors akka.net's
78    /// type-name parameter to `ClusterSharding.Start(...)`).
79    ///
80    /// `shard_count` controls the routing granularity; larger values
81    /// distribute consecutive entity ids more evenly. Defaults to the
82    /// number of devices when zero.
83    pub fn start(
84        region_id: impl Into<String>,
85        devices: Vec<ActorRef<DeviceMsg>>,
86        shard_count: usize,
87    ) -> Self {
88        let n_devices = devices.len().max(1);
89        let n_shards = if shard_count == 0 {
90            n_devices
91        } else {
92            shard_count
93        };
94        let extractor = Arc::new(DeviceExtractor::new(n_shards));
95        let coord = Arc::new(ShardCoordinator::new());
96        // Devices captured in a shared Arc so the handler closures
97        // (one per shard) all see the same fleet.
98        let devices = Arc::new(devices);
99        let devices_for_factory = devices.clone();
100        let region = ShardRegion::new(
101            region_id,
102            extractor,
103            coord,
104            Arc::new(move || {
105                // Each shard creates its own EntityHandler; capture
106                // the shared device pool so all handlers route into
107                // the same fleet. Dispatch is consistent-hash by
108                // entity_id mod n_devices.
109                let devices = devices_for_factory.clone();
110                Box::new(move |entity_id: &str, msg: RoutedDeviceMsg| {
111                    if devices.is_empty() {
112                        return;
113                    }
114                    let mut h = DefaultHasher::new();
115                    entity_id.hash(&mut h);
116                    let idx = (h.finish() as usize) % devices.len();
117                    devices[idx].tell(msg.msg);
118                })
119            }),
120        );
121        Self { region }
122    }
123
124    /// Build a typed handle to a particular entity.
125    pub fn entity(&self, entity_id: impl Into<String>) -> EntityRef<DeviceExtractor> {
126        EntityRef::new(self.region.clone(), entity_id.into())
127    }
128
129    /// Direct access to the underlying [`ShardRegion`] — useful for
130    /// installing a remote forwarder or inspecting shard counts.
131    pub fn region(&self) -> Arc<ShardRegion<DeviceExtractor>> {
132        self.region.clone()
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::{DeviceActor, DeviceConfig};
    use atomr_config::Config;
    use atomr_core::actor::ActorSystem;
    use std::time::Duration;
    use tokio::sync::oneshot;

    /// Sends one `Allocate` through an [`EntityRef`] and checks that some
    /// device in the fleet answers it.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn entity_ref_routes_to_one_of_the_devices() {
        let sys = ActorSystem::create("sharding-adapter", Config::empty())
            .await
            .unwrap();
        let d0 = sys
            .actor_of(DeviceActor::props(DeviceConfig::mock(0)), "d0")
            .unwrap();
        let d1 = sys
            .actor_of(DeviceActor::props(DeviceConfig::mock(1)), "d1")
            .unwrap();
        let adapter = PlacementShardingAdapter::start("gpu", vec![d0, d1], 16);

        let entity = adapter.entity("user-42");
        let (tx, rx) = oneshot::channel();
        entity.tell(RoutedDeviceMsg {
            entity_id: "user-42".into(),
            msg: DeviceMsg::Allocate { len: 16, reply: tx },
        });
        // We don't verify which device served it — only that the route
        // delivered (the reply arrives, even if it's an
        // Unrecoverable from mock mode). The second `expect` matters: if
        // the message were dropped on the way, the reply sender would be
        // dropped with it and `rx` would resolve to a `RecvError`, which
        // must fail the test rather than be ignored.
        let _reply = tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .expect("Allocate reply should arrive within timeout")
            .expect("reply sender was dropped without a reply; message never reached a device");

        sys.terminate().await;
    }
}