atomr_accel_cuda/pipeline/
sink.rs1use std::sync::Arc;
17
18use cudarc::driver::CudaStream;
19use tokio::sync::mpsc;
20use tokio_stream::wrappers::ReceiverStream;
21
22use crate::completion::CompletionStrategy;
23use crate::error::GpuError;
24use crate::pipeline::executor::PipelineExecutorN;
25
26#[derive(Clone)]
29pub struct PipelineSink<I: Send + 'static> {
30 tx: mpsc::Sender<I>,
31}
32
33impl<I: Send + 'static> PipelineSink<I> {
34 pub async fn submit(&self, item: I) -> Result<(), GpuError> {
35 self.tx
36 .send(item)
37 .await
38 .map_err(|_| GpuError::Unrecoverable("PipelineSink: driver dropped".into()))
39 }
40
41 pub fn try_submit(&self, item: I) -> Result<(), GpuError> {
42 self.tx
43 .try_send(item)
44 .map_err(|e| GpuError::Unrecoverable(format!("PipelineSink try_submit: {e}")))
45 }
46}
47
48pub struct PipelineSource<O: Send + 'static> {
51 rx: mpsc::Receiver<Result<O, GpuError>>,
52}
53
54impl<O: Send + 'static> PipelineSource<O> {
55 pub fn into_stream(self) -> ReceiverStream<Result<O, GpuError>> {
56 ReceiverStream::new(self.rx)
57 }
58}
59
60pub fn spawn_pipeline<I, O>(
64 mut executor: PipelineExecutorN,
65 streams: Vec<Arc<CudaStream>>,
66 completion: Arc<dyn CompletionStrategy>,
67 head_capacity: usize,
68 tail_capacity: usize,
69) -> (PipelineSink<I>, PipelineSource<O>)
70where
71 I: Send + 'static,
72 O: Send + 'static,
73{
74 let (in_tx, mut in_rx) = mpsc::channel::<I>(head_capacity.max(1));
75 let (out_tx, out_rx) = mpsc::channel::<Result<O, GpuError>>(tail_capacity.max(1));
76 tokio::spawn(async move {
77 while let Some(item) = in_rx.recv().await {
78 let result = executor.run::<I, O>(&streams, &completion, item).await;
79 if out_tx.send(result).await.is_err() {
80 break;
81 }
82 }
83 });
84 (PipelineSink { tx: in_tx }, PipelineSource { rx: out_rx })
85}