reth_stages_api/pipeline/
set.rs1use crate::{Stage, StageId};
2use std::{
3    collections::HashMap,
4    fmt::{Debug, Formatter},
5};
6
7pub trait StageSet<Provider>: Sized {
14    fn builder(self) -> StageSetBuilder<Provider>;
16
17    fn set<S: Stage<Provider> + 'static>(self, stage: S) -> StageSetBuilder<Provider> {
23        self.builder().set(stage)
24    }
25}
26
27struct StageEntry<Provider> {
28    stage: Box<dyn Stage<Provider>>,
29    enabled: bool,
30}
31
32impl<Provider> Debug for StageEntry<Provider> {
33    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
34        f.debug_struct("StageEntry")
35            .field("stage", &self.stage.id())
36            .field("enabled", &self.enabled)
37            .finish()
38    }
39}
40
41pub struct StageSetBuilder<Provider> {
48    stages: HashMap<StageId, StageEntry<Provider>>,
49    order: Vec<StageId>,
50}
51
52impl<Provider> Default for StageSetBuilder<Provider> {
53    fn default() -> Self {
54        Self { stages: HashMap::default(), order: Vec::new() }
55    }
56}
57
58impl<Provider> Debug for StageSetBuilder<Provider> {
59    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
60        f.debug_struct("StageSetBuilder")
61            .field("stages", &self.stages)
62            .field("order", &self.order)
63            .finish()
64    }
65}
66
67impl<Provider> StageSetBuilder<Provider> {
68    fn index_of(&self, stage_id: StageId) -> usize {
69        let index = self.order.iter().position(|&id| id == stage_id);
70
71        index.unwrap_or_else(|| panic!("Stage does not exist in set: {stage_id}"))
72    }
73
74    fn upsert_stage_state(&mut self, stage: Box<dyn Stage<Provider>>, added_at_index: usize) {
75        let stage_id = stage.id();
76        if self.stages.insert(stage.id(), StageEntry { stage, enabled: true }).is_some() {
77            if let Some(to_remove) = self
78                .order
79                .iter()
80                .enumerate()
81                .find(|(i, id)| *i != added_at_index && **id == stage_id)
82                .map(|(i, _)| i)
83            {
84                self.order.remove(to_remove);
85            }
86        }
87    }
88
89    pub fn set<S: Stage<Provider> + 'static>(mut self, stage: S) -> Self {
95        let entry = self
96            .stages
97            .get_mut(&stage.id())
98            .unwrap_or_else(|| panic!("Stage does not exist in set: {}", stage.id()));
99        entry.stage = Box::new(stage);
100        self
101    }
102
103    pub fn stages(&self) -> impl Iterator<Item = StageId> + '_ {
106        self.order.iter().copied()
107    }
108
109    pub fn replace<S: Stage<Provider> + 'static>(mut self, stage_id: StageId, stage: S) -> Self {
114        self.stages
115            .get(&stage_id)
116            .unwrap_or_else(|| panic!("Stage does not exist in set: {stage_id}"));
117
118        if stage.id() == stage_id {
119            return self.set(stage);
120        }
121        let index = self.index_of(stage_id);
122        self.stages.remove(&stage_id);
123        self.order[index] = stage.id();
124        self.upsert_stage_state(Box::new(stage), index);
125        self
126    }
127
128    pub fn add_stage<S: Stage<Provider> + 'static>(mut self, stage: S) -> Self {
132        let target_index = self.order.len();
133        self.order.push(stage.id());
134        self.upsert_stage_state(Box::new(stage), target_index);
135        self
136    }
137
138    pub fn add_stage_opt<S: Stage<Provider> + 'static>(self, stage: Option<S>) -> Self {
142        if let Some(stage) = stage {
143            self.add_stage(stage)
144        } else {
145            self
146        }
147    }
148
149    pub fn add_set<Set: StageSet<Provider>>(mut self, set: Set) -> Self {
154        for stage in set.builder().build() {
155            let target_index = self.order.len();
156            self.order.push(stage.id());
157            self.upsert_stage_state(stage, target_index);
158        }
159        self
160    }
161
162    pub fn add_before<S: Stage<Provider> + 'static>(mut self, stage: S, before: StageId) -> Self {
170        let target_index = self.index_of(before);
171        self.order.insert(target_index, stage.id());
172        self.upsert_stage_state(Box::new(stage), target_index);
173        self
174    }
175
176    pub fn add_after<S: Stage<Provider> + 'static>(mut self, stage: S, after: StageId) -> Self {
184        let target_index = self.index_of(after) + 1;
185        self.order.insert(target_index, stage.id());
186        self.upsert_stage_state(Box::new(stage), target_index);
187        self
188    }
189
190    pub fn enable(mut self, stage_id: StageId) -> Self {
198        let entry =
199            self.stages.get_mut(&stage_id).expect("Cannot enable a stage that is not in the set.");
200        entry.enabled = true;
201        self
202    }
203
204    #[track_caller]
215    pub fn disable(mut self, stage_id: StageId) -> Self {
216        let entry = self
217            .stages
218            .get_mut(&stage_id)
219            .unwrap_or_else(|| panic!("Cannot disable a stage that is not in the set: {stage_id}"));
220        entry.enabled = false;
221        self
222    }
223
224    pub fn disable_all(mut self, stages: &[StageId]) -> Self {
228        for stage_id in stages {
229            let Some(entry) = self.stages.get_mut(stage_id) else { continue };
230            entry.enabled = false;
231        }
232        self
233    }
234
235    #[track_caller]
239    pub fn disable_if<F>(self, stage_id: StageId, f: F) -> Self
240    where
241        F: FnOnce() -> bool,
242    {
243        if f() {
244            return self.disable(stage_id)
245        }
246        self
247    }
248
249    #[track_caller]
253    pub fn disable_all_if<F>(self, stages: &[StageId], f: F) -> Self
254    where
255        F: FnOnce() -> bool,
256    {
257        if f() {
258            return self.disable_all(stages)
259        }
260        self
261    }
262
263    pub fn build(mut self) -> Vec<Box<dyn Stage<Provider>>> {
265        let mut stages = Vec::new();
266        for id in &self.order {
267            if let Some(entry) = self.stages.remove(id) {
268                if entry.enabled {
269                    stages.push(entry.stage);
270                }
271            }
272        }
273        stages
274    }
275}
276
277impl<Provider> StageSet<Provider> for StageSetBuilder<Provider> {
278    fn builder(self) -> Self {
279        self
280    }
281}