Skip to main content

atomr_accel_cuda/pipeline/
executor.rs

//! Minimal pipeline executors: a single-stage homogeneous runner
//! ([`run_pipeline`]), a typed two-stage chain ([`PipelineExecutor`]), and
//! an N-stage chain over type-erased stages ([`PipelineExecutorN`]).
//! Stages run in sequence with event-based handoff.
//!
//! Intentionally simple — no Source / Sink integration and no
//! backpressure adapters. Those land in F3, once concrete patterns
//! exist that benefit from the full DSL.
7
8use std::sync::Arc;
9
10use cudarc::driver::CudaStream;
11
12use crate::completion::CompletionStrategy;
13use crate::error::GpuError;
14use crate::pipeline::stage::PipelineStage;
15
16/// Run a homogeneous sequence of stages on `streams[i]` for stage i.
17///
18/// Caller supplies one stream per stage (use [`crate::stream::PerActorAllocator`]
19/// to mint them). The executor enqueues all stages, hooking each stage's
20/// returned event into the next stage's `wait_for`, then awaits one
21/// `HostFnCompletion` on the last stream.
22pub async fn run_pipeline<S: PipelineStage>(
23    stages: &mut [S],
24    streams: &[Arc<CudaStream>],
25    completion: &Arc<dyn CompletionStrategy>,
26    input: S::In,
27) -> Result<S::Out, GpuError>
28where
29    S::Out: From<S::In>, // Stage chain identity helper for trivial through-pipelines.
30{
31    if stages.is_empty() {
32        return Err(GpuError::Unrecoverable("pipeline has zero stages".into()));
33    }
34    if stages.len() != streams.len() {
35        return Err(GpuError::Unrecoverable(format!(
36            "stage count {} != stream count {}",
37            stages.len(),
38            streams.len()
39        )));
40    }
41    let stages_len = stages.len();
42    let mut input = Some(input);
43    let mut wait_event = None;
44    let mut last_out: Option<S::Out> = None;
45    for (i, stage) in stages.iter_mut().enumerate() {
46        let stream = &streams[i];
47        let in_v = input.take().expect("pipeline input consumed prematurely");
48        let (ev, out) = stage.enqueue(stream, wait_event.as_ref(), in_v)?;
49        wait_event = Some(ev);
50        last_out = Some(out);
51        if i + 1 < stages_len {
52            return Err(GpuError::Unrecoverable(
53                "run_pipeline currently supports only single-stage chains; \
54                 use PipelineExecutor for multi-stage heterogeneous pipelines"
55                    .into(),
56            ));
57        }
58    }
59    // Tail completion.
60    let tail_stream = streams.last().unwrap();
61    completion.await_completion(tail_stream).await?;
62    last_out.ok_or_else(|| GpuError::Unrecoverable("pipeline produced no output".into()))
63}
64
65/// Two-stage type-state executor — the simplest non-trivial chain.
66pub struct PipelineExecutor<S0, S1>
67where
68    S0: PipelineStage,
69    S1: PipelineStage<In = S0::Out>,
70{
71    pub s0: S0,
72    pub s1: S1,
73}
74
75impl<S0, S1> PipelineExecutor<S0, S1>
76where
77    S0: PipelineStage,
78    S1: PipelineStage<In = S0::Out>,
79{
80    pub async fn run(
81        &mut self,
82        s0_stream: &Arc<CudaStream>,
83        s1_stream: &Arc<CudaStream>,
84        completion: &Arc<dyn CompletionStrategy>,
85        input: S0::In,
86    ) -> Result<S1::Out, GpuError> {
87        let (ev0, out0) = self.s0.enqueue(s0_stream, None, input)?;
88        let (_ev1, out1) = self.s1.enqueue(s1_stream, Some(&ev0), out0)?;
89        completion.await_completion(s1_stream).await?;
90        Ok(out1)
91    }
92}
93
94/// Heterogeneous N-stage executor.
95///
96/// Each stage's `In` and `Out` types are erased into `Box<dyn Any +
97/// Send>`. Stage adapters wrap their typed `PipelineStage` impl into
98/// a `BoxedStage` and the executor drives the chain. Dynamic typing
99/// gives up some compile-time safety in exchange for arbitrarily-long
100/// chains; type mismatches at stage boundaries surface as
101/// `GpuError::Unrecoverable("…downcast failed…")` at runtime.
102pub trait BoxedStage: Send + 'static {
103    fn enqueue_boxed(
104        &mut self,
105        stream: &Arc<CudaStream>,
106        wait_for: Option<&cudarc::driver::CudaEvent>,
107        input: Box<dyn std::any::Any + Send>,
108    ) -> Result<(cudarc::driver::CudaEvent, Box<dyn std::any::Any + Send>), GpuError>;
109}
110
111/// Adapter wrapping any typed `PipelineStage` into a `BoxedStage`.
112pub struct StageBox<S: PipelineStage> {
113    inner: S,
114}
115
116impl<S: PipelineStage> StageBox<S> {
117    pub fn new(s: S) -> Self {
118        Self { inner: s }
119    }
120}
121
122impl<S: PipelineStage> BoxedStage for StageBox<S> {
123    fn enqueue_boxed(
124        &mut self,
125        stream: &Arc<CudaStream>,
126        wait_for: Option<&cudarc::driver::CudaEvent>,
127        input: Box<dyn std::any::Any + Send>,
128    ) -> Result<(cudarc::driver::CudaEvent, Box<dyn std::any::Any + Send>), GpuError> {
129        let typed = input.downcast::<S::In>().map_err(|_| {
130            GpuError::Unrecoverable(format!(
131                "PipelineExecutorN: stage input downcast to `{}` failed",
132                std::any::type_name::<S::In>()
133            ))
134        })?;
135        let (ev, out) = self.inner.enqueue(stream, wait_for, *typed)?;
136        Ok((ev, Box::new(out) as Box<dyn std::any::Any + Send>))
137    }
138}
139
140/// N-stage heterogeneous executor.
141pub struct PipelineExecutorN {
142    stages: Vec<Box<dyn BoxedStage>>,
143}
144
145impl PipelineExecutorN {
146    pub fn new() -> Self {
147        Self { stages: Vec::new() }
148    }
149
150    pub fn push<S: PipelineStage>(mut self, stage: S) -> Self {
151        self.stages.push(Box::new(StageBox::new(stage)));
152        self
153    }
154
155    /// Run the chain across `streams` (one per stage). On success
156    /// returns the tail stage's output. On any stage failure, the
157    /// error short-circuits the chain.
158    pub async fn run<I, O>(
159        &mut self,
160        streams: &[Arc<CudaStream>],
161        completion: &Arc<dyn CompletionStrategy>,
162        input: I,
163    ) -> Result<O, GpuError>
164    where
165        I: Send + 'static,
166        O: Send + 'static,
167    {
168        if self.stages.is_empty() {
169            return Err(GpuError::Unrecoverable(
170                "PipelineExecutorN: no stages".into(),
171            ));
172        }
173        if streams.len() != self.stages.len() {
174            return Err(GpuError::Unrecoverable(format!(
175                "PipelineExecutorN: stage count {} != stream count {}",
176                self.stages.len(),
177                streams.len()
178            )));
179        }
180        let mut payload: Box<dyn std::any::Any + Send> = Box::new(input);
181        let mut wait_event: Option<cudarc::driver::CudaEvent> = None;
182        for (stage, stream) in self.stages.iter_mut().zip(streams.iter()) {
183            let (ev, next) = stage.enqueue_boxed(stream, wait_event.as_ref(), payload)?;
184            wait_event = Some(ev);
185            payload = next;
186        }
187        completion.await_completion(streams.last().unwrap()).await?;
188        let out = payload.downcast::<O>().map_err(|_| {
189            GpuError::Unrecoverable(format!(
190                "PipelineExecutorN: tail downcast to `{}` failed",
191                std::any::type_name::<O>()
192            ))
193        })?;
194        Ok(*out)
195    }
196
197    pub fn stage_count(&self) -> usize {
198        self.stages.len()
199    }
200}
201
202impl Default for PipelineExecutorN {
203    fn default() -> Self {
204        Self::new()
205    }
206}