atomr_accel_cuda/pipeline/
executor.rs1use std::sync::Arc;
9
10use cudarc::driver::CudaStream;
11
12use crate::completion::CompletionStrategy;
13use crate::error::GpuError;
14use crate::pipeline::stage::PipelineStage;
15
16pub async fn run_pipeline<S: PipelineStage>(
23 stages: &mut [S],
24 streams: &[Arc<CudaStream>],
25 completion: &Arc<dyn CompletionStrategy>,
26 input: S::In,
27) -> Result<S::Out, GpuError>
28where
29 S::Out: From<S::In>, {
31 if stages.is_empty() {
32 return Err(GpuError::Unrecoverable("pipeline has zero stages".into()));
33 }
34 if stages.len() != streams.len() {
35 return Err(GpuError::Unrecoverable(format!(
36 "stage count {} != stream count {}",
37 stages.len(),
38 streams.len()
39 )));
40 }
41 let stages_len = stages.len();
42 let mut input = Some(input);
43 let mut wait_event = None;
44 let mut last_out: Option<S::Out> = None;
45 for (i, stage) in stages.iter_mut().enumerate() {
46 let stream = &streams[i];
47 let in_v = input.take().expect("pipeline input consumed prematurely");
48 let (ev, out) = stage.enqueue(stream, wait_event.as_ref(), in_v)?;
49 wait_event = Some(ev);
50 last_out = Some(out);
51 if i + 1 < stages_len {
52 return Err(GpuError::Unrecoverable(
53 "run_pipeline currently supports only single-stage chains; \
54 use PipelineExecutor for multi-stage heterogeneous pipelines"
55 .into(),
56 ));
57 }
58 }
59 let tail_stream = streams.last().unwrap();
61 completion.await_completion(tail_stream).await?;
62 last_out.ok_or_else(|| GpuError::Unrecoverable("pipeline produced no output".into()))
63}
64
65pub struct PipelineExecutor<S0, S1>
67where
68 S0: PipelineStage,
69 S1: PipelineStage<In = S0::Out>,
70{
71 pub s0: S0,
72 pub s1: S1,
73}
74
75impl<S0, S1> PipelineExecutor<S0, S1>
76where
77 S0: PipelineStage,
78 S1: PipelineStage<In = S0::Out>,
79{
80 pub async fn run(
81 &mut self,
82 s0_stream: &Arc<CudaStream>,
83 s1_stream: &Arc<CudaStream>,
84 completion: &Arc<dyn CompletionStrategy>,
85 input: S0::In,
86 ) -> Result<S1::Out, GpuError> {
87 let (ev0, out0) = self.s0.enqueue(s0_stream, None, input)?;
88 let (_ev1, out1) = self.s1.enqueue(s1_stream, Some(&ev0), out0)?;
89 completion.await_completion(s1_stream).await?;
90 Ok(out1)
91 }
92}
93
/// Stage interface whose input and output are type-erased as
/// `Box<dyn Any + Send>`.
///
/// This allows stages with different concrete input/output types to be
/// stored together behind `Box<dyn BoxedStage>`.
pub trait BoxedStage: Send + 'static {
    /// Enqueues this stage on `stream`.
    ///
    /// * `wait_for` – optional event handed to the stage; `None` when there
    ///   is no preceding event.
    /// * `input` – the type-erased input value.
    ///
    /// On success returns the event produced by the stage together with the
    /// type-erased output value.
    fn enqueue_boxed(
        &mut self,
        stream: &Arc<CudaStream>,
        wait_for: Option<&cudarc::driver::CudaEvent>,
        input: Box<dyn std::any::Any + Send>,
    ) -> Result<(cudarc::driver::CudaEvent, Box<dyn std::any::Any + Send>), GpuError>;
}
110
111pub struct StageBox<S: PipelineStage> {
113 inner: S,
114}
115
116impl<S: PipelineStage> StageBox<S> {
117 pub fn new(s: S) -> Self {
118 Self { inner: s }
119 }
120}
121
122impl<S: PipelineStage> BoxedStage for StageBox<S> {
123 fn enqueue_boxed(
124 &mut self,
125 stream: &Arc<CudaStream>,
126 wait_for: Option<&cudarc::driver::CudaEvent>,
127 input: Box<dyn std::any::Any + Send>,
128 ) -> Result<(cudarc::driver::CudaEvent, Box<dyn std::any::Any + Send>), GpuError> {
129 let typed = input.downcast::<S::In>().map_err(|_| {
130 GpuError::Unrecoverable(format!(
131 "PipelineExecutorN: stage input downcast to `{}` failed",
132 std::any::type_name::<S::In>()
133 ))
134 })?;
135 let (ev, out) = self.inner.enqueue(stream, wait_for, *typed)?;
136 Ok((ev, Box::new(out) as Box<dyn std::any::Any + Send>))
137 }
138}
139
/// Pipeline of an arbitrary number of type-erased stages.
///
/// Stages are stored as `Box<dyn BoxedStage>`, so adjacent stages' input and
/// output types are only checked at run time (via downcast), not at compile
/// time.
pub struct PipelineExecutorN {
    /// Stages in the order they were pushed, which is the order they are run.
    stages: Vec<Box<dyn BoxedStage>>,
}
144
145impl PipelineExecutorN {
146 pub fn new() -> Self {
147 Self { stages: Vec::new() }
148 }
149
150 pub fn push<S: PipelineStage>(mut self, stage: S) -> Self {
151 self.stages.push(Box::new(StageBox::new(stage)));
152 self
153 }
154
155 pub async fn run<I, O>(
159 &mut self,
160 streams: &[Arc<CudaStream>],
161 completion: &Arc<dyn CompletionStrategy>,
162 input: I,
163 ) -> Result<O, GpuError>
164 where
165 I: Send + 'static,
166 O: Send + 'static,
167 {
168 if self.stages.is_empty() {
169 return Err(GpuError::Unrecoverable(
170 "PipelineExecutorN: no stages".into(),
171 ));
172 }
173 if streams.len() != self.stages.len() {
174 return Err(GpuError::Unrecoverable(format!(
175 "PipelineExecutorN: stage count {} != stream count {}",
176 self.stages.len(),
177 streams.len()
178 )));
179 }
180 let mut payload: Box<dyn std::any::Any + Send> = Box::new(input);
181 let mut wait_event: Option<cudarc::driver::CudaEvent> = None;
182 for (stage, stream) in self.stages.iter_mut().zip(streams.iter()) {
183 let (ev, next) = stage.enqueue_boxed(stream, wait_event.as_ref(), payload)?;
184 wait_event = Some(ev);
185 payload = next;
186 }
187 completion.await_completion(streams.last().unwrap()).await?;
188 let out = payload.downcast::<O>().map_err(|_| {
189 GpuError::Unrecoverable(format!(
190 "PipelineExecutorN: tail downcast to `{}` failed",
191 std::any::type_name::<O>()
192 ))
193 })?;
194 Ok(*out)
195 }
196
197 pub fn stage_count(&self) -> usize {
198 self.stages.len()
199 }
200}
201
202impl Default for PipelineExecutorN {
203 fn default() -> Self {
204 Self::new()
205 }
206}