1use std::collections::VecDeque;
19use std::ffi::c_void;
20use std::marker::PhantomData;
21
22use async_trait::async_trait;
23use atomr_core::actor::{Actor, ActorRef, Context, Props};
24use tokio::sync::{mpsc, oneshot};
25use tracing::{debug, warn};
26
27use crate::error::GpuError;
28
/// Tunables for a [`PinnedBufferPool`].
#[derive(Debug, Clone, Copy)]
pub struct PinnedBufferPoolConfig {
    /// Initial buffer count.
    ///
    /// NOTE(review): in the code visible here this value is only reported in
    /// the pool's startup log; confirm whether eager pre-allocation is
    /// intended.
    pub initial_buffers: usize,
    /// Minting a new regular slot is refused once the pool's minted-slot
    /// count has reached this value.
    pub max_buffers: usize,
    /// Size in bytes of every regular (pooled) slot. Requests larger than
    /// this are treated as oversize.
    pub buffer_capacity_bytes: usize,
    /// When `true`, an oversize request gets its own one-off allocation;
    /// when `false`, it is rejected with an out-of-memory error.
    pub allow_oversize: bool,
}
39
40impl Default for PinnedBufferPoolConfig {
41 fn default() -> Self {
42 Self {
43 initial_buffers: 4,
44 max_buffers: 32,
45 buffer_capacity_bytes: 4 * 1024 * 1024,
46 allow_oversize: true,
47 }
48 }
49}
50
/// Point-in-time counters reported by the pool.
#[derive(Debug, Clone, Copy)]
pub struct PinnedPoolStats {
    /// Slots currently handed out and not yet returned.
    pub in_use: usize,
    /// Slots sitting in the pool's free list.
    pub free: usize,
    /// Number of slots the pool currently counts as minted.
    pub total: usize,
    /// Bytes the pool currently counts as allocated.
    pub bytes_allocated: usize,
}
58
/// One host allocation owned by the pool (or by a handle that borrowed it).
pub struct PinnedSlot {
    /// Base address of the allocation; null once the slot has been released.
    ptr: *mut c_void,
    /// Size of the allocation in bytes.
    capacity_bytes: usize,
    /// `true` for a one-off allocation made for a request larger than the
    /// pool's regular slot size; such slots are freed on return instead of
    /// being recycled.
    oversize: bool,
}
67
68unsafe impl Send for PinnedSlot {}
72unsafe impl Sync for PinnedSlot {}
73
74impl PinnedSlot {
75 fn new(capacity_bytes: usize, oversize: bool) -> Result<Self, GpuError> {
76 let ptr = unsafe { cudarc::driver::result::malloc_host(capacity_bytes, 0) }
80 .map_err(|e| GpuError::OutOfMemory(format!("pinned alloc {capacity_bytes}B: {e}")))?;
81 Ok(Self {
82 ptr,
83 capacity_bytes,
84 oversize,
85 })
86 }
87
88 fn free(self) {
89 drop(self);
91 }
92}
93
94impl Drop for PinnedSlot {
95 fn drop(&mut self) {
96 if !self.ptr.is_null() {
97 unsafe {
98 let _ = cudarc::driver::result::free_host(self.ptr);
99 }
100 self.ptr = std::ptr::null_mut();
101 }
102 }
103}
104
/// Counter the pool stamps on every handle it issues; a returned slot whose
/// stamp differs from the pool's current value is treated as stale.
type PinnedGeneration = u64;
109
110pub enum PinnedPoolMsg {
112 Acquire {
113 len_bytes: usize,
114 reply: oneshot::Sender<Result<PinnedBufHandle, GpuError>>,
115 },
116 InternalReturn {
119 slot: PinnedSlot,
120 generation: PinnedGeneration,
121 },
122 Stats {
123 reply: oneshot::Sender<PinnedPoolStats>,
124 },
125}
126
127pub struct PinnedBufHandle {
130 slot: Option<PinnedSlot>,
131 generation: PinnedGeneration,
132 return_tx: mpsc::UnboundedSender<PinnedPoolMsg>,
133}
134
135impl PinnedBufHandle {
136 pub fn capacity_bytes(&self) -> usize {
137 self.slot.as_ref().map(|s| s.capacity_bytes).unwrap_or(0)
138 }
139
140 pub fn into_typed<T>(mut self, len: usize) -> Result<PinnedBuf<T>, GpuError> {
144 let needed = len.checked_mul(std::mem::size_of::<T>()).ok_or_else(|| {
145 GpuError::Unrecoverable("pinned buf: len * size_of overflowed".into())
146 })?;
147 if needed > self.capacity_bytes() {
148 return Err(GpuError::Unrecoverable(format!(
149 "pinned buf: requested {len} elements ({needed} B) exceeds capacity {} B",
150 self.capacity_bytes()
151 )));
152 }
153 let slot = self.slot.take().expect("PinnedBufHandle slot was None");
154 let ptr = slot.ptr as *mut T;
155 Ok(PinnedBuf {
156 inner: Some(PinnedBufInner {
157 slot,
158 len,
159 return_tx: self.return_tx.clone(),
160 generation: self.generation,
161 }),
162 ptr,
163 len,
164 _marker: PhantomData,
165 })
166 }
167}
168
169impl Drop for PinnedBufHandle {
170 fn drop(&mut self) {
171 if let Some(slot) = self.slot.take() {
173 let _ = self.return_tx.send(PinnedPoolMsg::InternalReturn {
174 slot,
175 generation: self.generation,
176 });
177 }
178 }
179}
180
181pub struct PinnedBuf<T> {
187 inner: Option<PinnedBufInner>,
188 ptr: *mut T,
189 len: usize,
190 _marker: PhantomData<T>,
191}
192
193struct PinnedBufInner {
194 slot: PinnedSlot,
195 #[allow(dead_code)]
196 len: usize,
197 return_tx: mpsc::UnboundedSender<PinnedPoolMsg>,
198 generation: PinnedGeneration,
199}
200
201unsafe impl<T: Send> Send for PinnedBuf<T> {}
202unsafe impl<T: Sync> Sync for PinnedBuf<T> {}
203
204impl<T> PinnedBuf<T> {
205 pub fn len(&self) -> usize {
206 self.len
207 }
208
209 pub fn is_empty(&self) -> bool {
210 self.len == 0
211 }
212
213 pub fn as_ptr(&self) -> *const T {
214 self.ptr
215 }
216
217 pub fn as_mut_ptr(&mut self) -> *mut T {
218 self.ptr
219 }
220
221 pub fn as_slice(&self) -> &[T] {
225 unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
229 }
230
231 pub fn as_mut_slice(&mut self) -> &mut [T] {
232 unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
233 }
234}
235
236impl<T: std::fmt::Debug> std::fmt::Debug for PinnedBuf<T> {
237 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238 f.debug_struct("PinnedBuf").field("len", &self.len).finish()
239 }
240}
241
242impl<T> Drop for PinnedBuf<T> {
243 fn drop(&mut self) {
244 if let Some(inner) = self.inner.take() {
245 let _ = inner.return_tx.send(PinnedPoolMsg::InternalReturn {
246 slot: inner.slot,
247 generation: inner.generation,
248 });
249 }
250 }
251}
252
253pub struct PinnedBufferPool {
255 config: PinnedBufferPoolConfig,
256 free: VecDeque<PinnedSlot>,
257 in_use: usize,
258 total_minted: usize,
259 bytes_allocated: usize,
260 generation: PinnedGeneration,
263 return_tx: mpsc::UnboundedSender<PinnedPoolMsg>,
268 return_rx_observer: Option<ActorRef<PinnedPoolMsg>>,
269}
270
271impl PinnedBufferPool {
272 pub fn props(config: PinnedBufferPoolConfig) -> Props<Self> {
273 Props::create(move || {
274 let (tx, _rx) = mpsc::unbounded_channel();
275 PinnedBufferPool {
279 config,
280 free: VecDeque::new(),
281 in_use: 0,
282 total_minted: 0,
283 bytes_allocated: 0,
284 generation: 0,
285 return_tx: tx,
286 return_rx_observer: None,
287 }
288 })
289 }
290
291 pub fn stats(&self) -> PinnedPoolStats {
294 PinnedPoolStats {
295 in_use: self.in_use,
296 free: self.free.len(),
297 total: self.total_minted,
298 bytes_allocated: self.bytes_allocated,
299 }
300 }
301
302 fn try_acquire(&mut self, len_bytes: usize) -> Result<PinnedBufHandle, GpuError> {
303 let cap = self.config.buffer_capacity_bytes;
304 let oversize = len_bytes > cap;
305
306 let slot = if oversize {
307 if !self.config.allow_oversize {
308 return Err(GpuError::OutOfMemory(format!(
309 "pinned pool: oversize request {len_bytes}B exceeds slot capacity {cap}B"
310 )));
311 }
312 self.bytes_allocated += len_bytes;
314 self.total_minted += 1;
315 PinnedSlot::new(len_bytes, true)?
316 } else if let Some(slot) = self.free.pop_front() {
317 slot
318 } else {
319 if self.total_minted >= self.config.max_buffers {
320 return Err(GpuError::OutOfMemory(format!(
321 "pinned pool: max_buffers={} reached",
322 self.config.max_buffers
323 )));
324 }
325 self.bytes_allocated += cap;
326 self.total_minted += 1;
327 PinnedSlot::new(cap, false)?
328 };
329
330 self.in_use += 1;
331 Ok(PinnedBufHandle {
332 slot: Some(slot),
333 generation: self.generation,
334 return_tx: self.return_tx.clone(),
335 })
336 }
337
338 fn return_slot(&mut self, slot: PinnedSlot, generation: PinnedGeneration) {
339 if generation != self.generation {
340 self.bytes_allocated = self.bytes_allocated.saturating_sub(slot.capacity_bytes);
343 self.total_minted = self.total_minted.saturating_sub(1);
344 slot.free();
345 return;
346 }
347 self.in_use = self.in_use.saturating_sub(1);
348 if slot.oversize {
349 self.bytes_allocated = self.bytes_allocated.saturating_sub(slot.capacity_bytes);
351 self.total_minted = self.total_minted.saturating_sub(1);
352 slot.free();
353 } else {
354 self.free.push_back(slot);
355 }
356 }
357}
358
359#[async_trait]
360impl Actor for PinnedBufferPool {
361 type Msg = PinnedPoolMsg;
362
363 async fn pre_start(&mut self, ctx: &mut Context<Self>) {
364 let (tx, mut rx) = mpsc::unbounded_channel::<PinnedPoolMsg>();
369 self.return_tx = tx;
370 let self_ref = ctx.self_ref().clone();
371 self.return_rx_observer = Some(self_ref.clone());
372 tokio::spawn(async move {
373 while let Some(msg) = rx.recv().await {
374 self_ref.tell(msg);
375 }
376 });
377 debug!(
378 initial = self.config.initial_buffers,
379 max = self.config.max_buffers,
380 cap = self.config.buffer_capacity_bytes,
381 "PinnedBufferPool started"
382 );
383 }
384
385 async fn handle(&mut self, _ctx: &mut Context<Self>, msg: PinnedPoolMsg) {
386 match msg {
387 PinnedPoolMsg::Acquire { len_bytes, reply } => {
388 let r = self.try_acquire(len_bytes);
389 let _ = reply.send(r);
390 }
391 PinnedPoolMsg::InternalReturn { slot, generation } => {
392 self.return_slot(slot, generation);
393 }
394 PinnedPoolMsg::Stats { reply } => {
395 let _ = reply.send(self.stats());
396 }
397 }
398 }
399
400 async fn pre_restart(&mut self, _ctx: &mut Context<Self>, err: &str) {
401 warn!(%err, "PinnedBufferPool pre_restart — dropping all in-flight buffers");
402 self.free.clear();
406 self.generation += 1;
407 self.in_use = 0;
408 self.total_minted = 0;
409 self.bytes_allocated = 0;
410 }
411
412 async fn post_stop(&mut self, _ctx: &mut Context<Self>) {
413 debug!("PinnedBufferPool post_stop");
414 }
417}