Skip to main content

atomr_accel_cuda/pipeline/
sink.rs

1//! Bounded-mpsc Source/Sink adapters around `PipelineExecutorN`.
2//!
3//! Provides an async, ergonomic layer on top of the synchronous
4//! executor: a producer (`PipelineSink<I>`) holds the head of a
5//! bounded `tokio::mpsc::Sender<I>`; a driver task pulls items off
6//! the channel one at a time, runs them through the executor, and
7//! pushes results onto a tail `tokio::mpsc::Sender<Result<O>>`.
8//! The consumer (`PipelineSource<O>`) wraps the tail receiver as
9//! a `Stream<Result<O, GpuError>>`.
10//!
11//! Backpressure: the head channel's bound caps how many items can
12//! be queued while the executor is busy. Real atomr-streams
13//! integration (with `OverflowStrategy::Backpressure`) is a
14//! drop-in once that crate is added as a workspace dependency.
15
16use std::sync::Arc;
17
18use cudarc::driver::CudaStream;
19use tokio::sync::mpsc;
20use tokio_stream::wrappers::ReceiverStream;
21
22use crate::completion::CompletionStrategy;
23use crate::error::GpuError;
24use crate::pipeline::executor::PipelineExecutorN;
25
26/// Producer end. `submit` blocks (awaits) when the channel is full
27/// — that's the backpressure signal.
28#[derive(Clone)]
29pub struct PipelineSink<I: Send + 'static> {
30    tx: mpsc::Sender<I>,
31}
32
33impl<I: Send + 'static> PipelineSink<I> {
34    pub async fn submit(&self, item: I) -> Result<(), GpuError> {
35        self.tx
36            .send(item)
37            .await
38            .map_err(|_| GpuError::Unrecoverable("PipelineSink: driver dropped".into()))
39    }
40
41    pub fn try_submit(&self, item: I) -> Result<(), GpuError> {
42        self.tx
43            .try_send(item)
44            .map_err(|e| GpuError::Unrecoverable(format!("PipelineSink try_submit: {e}")))
45    }
46}
47
48/// Consumer end. Returns a `ReceiverStream<Result<O, GpuError>>`
49/// that yields one item per processed input.
50pub struct PipelineSource<O: Send + 'static> {
51    rx: mpsc::Receiver<Result<O, GpuError>>,
52}
53
54impl<O: Send + 'static> PipelineSource<O> {
55    pub fn into_stream(self) -> ReceiverStream<Result<O, GpuError>> {
56        ReceiverStream::new(self.rx)
57    }
58}
59
60/// Spawn a backpressured async pipeline driver around an executor.
61/// Returns `(PipelineSink<I>, PipelineSource<O>)`. The driver runs
62/// on the ambient tokio runtime.
63pub fn spawn_pipeline<I, O>(
64    mut executor: PipelineExecutorN,
65    streams: Vec<Arc<CudaStream>>,
66    completion: Arc<dyn CompletionStrategy>,
67    head_capacity: usize,
68    tail_capacity: usize,
69) -> (PipelineSink<I>, PipelineSource<O>)
70where
71    I: Send + 'static,
72    O: Send + 'static,
73{
74    let (in_tx, mut in_rx) = mpsc::channel::<I>(head_capacity.max(1));
75    let (out_tx, out_rx) = mpsc::channel::<Result<O, GpuError>>(tail_capacity.max(1));
76    tokio::spawn(async move {
77        while let Some(item) = in_rx.recv().await {
78            let result = executor.run::<I, O>(&streams, &completion, item).await;
79            if out_tx.send(result).await.is_err() {
80                break;
81            }
82        }
83    });
84    (PipelineSink { tx: in_tx }, PipelineSource { rx: out_rx })
85}