reth_stages_api/pipeline/
set.rs

1use crate::{Stage, StageId};
2use std::{
3    collections::HashMap,
4    fmt::{Debug, Formatter},
5};
6
7/// Combines multiple [`Stage`]s into a single unit.
8///
9/// A [`StageSet`] is a logical chunk of stages that depend on each other. It is up to the
10/// individual stage sets to determine what kind of configuration they expose.
11///
12/// Individual stages in the set can be added, removed and overridden using [`StageSetBuilder`].
13pub trait StageSet<Provider>: Sized {
14    /// Configures the stages in the set.
15    fn builder(self) -> StageSetBuilder<Provider>;
16
17    /// Overrides the given [`Stage`], if it is in this set.
18    ///
19    /// # Panics
20    ///
21    /// Panics if the [`Stage`] is not in this set.
22    fn set<S: Stage<Provider> + 'static>(self, stage: S) -> StageSetBuilder<Provider> {
23        self.builder().set(stage)
24    }
25}
26
27struct StageEntry<Provider> {
28    stage: Box<dyn Stage<Provider>>,
29    enabled: bool,
30}
31
32impl<Provider> Debug for StageEntry<Provider> {
33    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
34        f.debug_struct("StageEntry")
35            .field("stage", &self.stage.id())
36            .field("enabled", &self.enabled)
37            .finish()
38    }
39}
40
41/// Helper to create and configure a [`StageSet`].
42///
43/// The builder provides ordering helpers to ensure that stages that depend on each other are added
44/// to the final sync pipeline before/after their dependencies.
45///
46/// Stages inside the set can be disabled, enabled, overridden and reordered.
47pub struct StageSetBuilder<Provider> {
48    stages: HashMap<StageId, StageEntry<Provider>>,
49    order: Vec<StageId>,
50}
51
52impl<Provider> Default for StageSetBuilder<Provider> {
53    fn default() -> Self {
54        Self { stages: HashMap::default(), order: Vec::new() }
55    }
56}
57
58impl<Provider> Debug for StageSetBuilder<Provider> {
59    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
60        f.debug_struct("StageSetBuilder")
61            .field("stages", &self.stages)
62            .field("order", &self.order)
63            .finish()
64    }
65}
66
67impl<Provider> StageSetBuilder<Provider> {
68    fn index_of(&self, stage_id: StageId) -> usize {
69        let index = self.order.iter().position(|&id| id == stage_id);
70
71        index.unwrap_or_else(|| panic!("Stage does not exist in set: {stage_id}"))
72    }
73
74    fn upsert_stage_state(&mut self, stage: Box<dyn Stage<Provider>>, added_at_index: usize) {
75        let stage_id = stage.id();
76        if self.stages.insert(stage.id(), StageEntry { stage, enabled: true }).is_some() {
77            if let Some(to_remove) = self
78                .order
79                .iter()
80                .enumerate()
81                .find(|(i, id)| *i != added_at_index && **id == stage_id)
82                .map(|(i, _)| i)
83            {
84                self.order.remove(to_remove);
85            }
86        }
87    }
88
89    /// Overrides the given [`Stage`], if it is in this set.
90    ///
91    /// # Panics
92    ///
93    /// Panics if the [`Stage`] is not in this set.
94    pub fn set<S: Stage<Provider> + 'static>(mut self, stage: S) -> Self {
95        let entry = self
96            .stages
97            .get_mut(&stage.id())
98            .unwrap_or_else(|| panic!("Stage does not exist in set: {}", stage.id()));
99        entry.stage = Box::new(stage);
100        self
101    }
102
103    /// Returns iterator over the stages in this set,
104    /// In the same order they would be executed in the pipeline.
105    pub fn stages(&self) -> impl Iterator<Item = StageId> + '_ {
106        self.order.iter().copied()
107    }
108
109    /// Replaces a stage with the given ID with a new stage.
110    ///
111    /// If the new stage has a different ID,
112    /// it will maintain the original stage's position in the execution order.
113    pub fn replace<S: Stage<Provider> + 'static>(mut self, stage_id: StageId, stage: S) -> Self {
114        self.stages
115            .get(&stage_id)
116            .unwrap_or_else(|| panic!("Stage does not exist in set: {stage_id}"));
117
118        if stage.id() == stage_id {
119            return self.set(stage);
120        }
121        let index = self.index_of(stage_id);
122        self.stages.remove(&stage_id);
123        self.order[index] = stage.id();
124        self.upsert_stage_state(Box::new(stage), index);
125        self
126    }
127
128    /// Adds the given [`Stage`] at the end of this set.
129    ///
130    /// If the stage was already in the group, it is removed from its previous place.
131    pub fn add_stage<S: Stage<Provider> + 'static>(mut self, stage: S) -> Self {
132        let target_index = self.order.len();
133        self.order.push(stage.id());
134        self.upsert_stage_state(Box::new(stage), target_index);
135        self
136    }
137
138    /// Adds the given [`Stage`] at the end of this set if it's [`Some`].
139    ///
140    /// If the stage was already in the group, it is removed from its previous place.
141    pub fn add_stage_opt<S: Stage<Provider> + 'static>(self, stage: Option<S>) -> Self {
142        if let Some(stage) = stage {
143            self.add_stage(stage)
144        } else {
145            self
146        }
147    }
148
149    /// Adds the given [`StageSet`] to the end of this set.
150    ///
151    /// If a stage is in both sets, it is removed from its previous place in this set. Because of
152    /// this, it is advisable to merge sets first and re-order stages after if needed.
153    pub fn add_set<Set: StageSet<Provider>>(mut self, set: Set) -> Self {
154        for stage in set.builder().build() {
155            let target_index = self.order.len();
156            self.order.push(stage.id());
157            self.upsert_stage_state(stage, target_index);
158        }
159        self
160    }
161
162    /// Adds the given [`Stage`] before the stage with the given [`StageId`].
163    ///
164    /// If the stage was already in the group, it is removed from its previous place.
165    ///
166    /// # Panics
167    ///
168    /// Panics if the dependency stage is not in this set.
169    pub fn add_before<S: Stage<Provider> + 'static>(mut self, stage: S, before: StageId) -> Self {
170        let target_index = self.index_of(before);
171        self.order.insert(target_index, stage.id());
172        self.upsert_stage_state(Box::new(stage), target_index);
173        self
174    }
175
176    /// Adds the given [`Stage`] after the stage with the given [`StageId`].
177    ///
178    /// If the stage was already in the group, it is removed from its previous place.
179    ///
180    /// # Panics
181    ///
182    /// Panics if the dependency stage is not in this set.
183    pub fn add_after<S: Stage<Provider> + 'static>(mut self, stage: S, after: StageId) -> Self {
184        let target_index = self.index_of(after) + 1;
185        self.order.insert(target_index, stage.id());
186        self.upsert_stage_state(Box::new(stage), target_index);
187        self
188    }
189
190    /// Enables the given stage.
191    ///
192    /// All stages within a [`StageSet`] are enabled by default.
193    ///
194    /// # Panics
195    ///
196    /// Panics if the stage is not in this set.
197    pub fn enable(mut self, stage_id: StageId) -> Self {
198        let entry =
199            self.stages.get_mut(&stage_id).expect("Cannot enable a stage that is not in the set.");
200        entry.enabled = true;
201        self
202    }
203
204    /// Disables the given stage.
205    ///
206    /// The disabled [`Stage`] keeps its place in the set, so it can be used for ordering with
207    /// [`StageSetBuilder::add_before`] or [`StageSetBuilder::add_after`], or it can be re-enabled.
208    ///
209    /// All stages within a [`StageSet`] are enabled by default.
210    ///
211    /// # Panics
212    ///
213    /// Panics if the stage is not in this set.
214    #[track_caller]
215    pub fn disable(mut self, stage_id: StageId) -> Self {
216        let entry = self
217            .stages
218            .get_mut(&stage_id)
219            .unwrap_or_else(|| panic!("Cannot disable a stage that is not in the set: {stage_id}"));
220        entry.enabled = false;
221        self
222    }
223
224    /// Disables all given stages. See [`disable`](Self::disable).
225    ///
226    /// If any of the stages is not in this set, it is ignored.
227    pub fn disable_all(mut self, stages: &[StageId]) -> Self {
228        for stage_id in stages {
229            let Some(entry) = self.stages.get_mut(stage_id) else { continue };
230            entry.enabled = false;
231        }
232        self
233    }
234
235    /// Disables the given stage if the given closure returns true.
236    ///
237    /// See [`Self::disable`]
238    #[track_caller]
239    pub fn disable_if<F>(self, stage_id: StageId, f: F) -> Self
240    where
241        F: FnOnce() -> bool,
242    {
243        if f() {
244            return self.disable(stage_id)
245        }
246        self
247    }
248
249    /// Disables all given stages if the given closure returns true.
250    ///
251    /// See [`Self::disable`]
252    #[track_caller]
253    pub fn disable_all_if<F>(self, stages: &[StageId], f: F) -> Self
254    where
255        F: FnOnce() -> bool,
256    {
257        if f() {
258            return self.disable_all(stages)
259        }
260        self
261    }
262
263    /// Consumes the builder and returns the contained [`Stage`]s in the order specified.
264    pub fn build(mut self) -> Vec<Box<dyn Stage<Provider>>> {
265        let mut stages = Vec::new();
266        for id in &self.order {
267            if let Some(entry) = self.stages.remove(id) {
268                if entry.enabled {
269                    stages.push(entry.stage);
270                }
271            }
272        }
273        stages
274    }
275}
276
277impl<Provider> StageSet<Provider> for StageSetBuilder<Provider> {
278    fn builder(self) -> Self {
279        self
280    }
281}