reth_stages_api/pipeline/
set.rs1use crate::{Stage, StageId};
2use std::{
3 collections::HashMap,
4 fmt::{Debug, Formatter},
5};
6
7pub trait StageSet<Provider>: Sized {
14 fn builder(self) -> StageSetBuilder<Provider>;
16
17 fn set<S: Stage<Provider> + 'static>(self, stage: S) -> StageSetBuilder<Provider> {
23 self.builder().set(stage)
24 }
25}
26
27struct StageEntry<Provider> {
28 stage: Box<dyn Stage<Provider>>,
29 enabled: bool,
30}
31
32impl<Provider> Debug for StageEntry<Provider> {
33 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
34 f.debug_struct("StageEntry")
35 .field("stage", &self.stage.id())
36 .field("enabled", &self.enabled)
37 .finish()
38 }
39}
40
41pub struct StageSetBuilder<Provider> {
48 stages: HashMap<StageId, StageEntry<Provider>>,
49 order: Vec<StageId>,
50}
51
52impl<Provider> Default for StageSetBuilder<Provider> {
53 fn default() -> Self {
54 Self { stages: HashMap::default(), order: Vec::new() }
55 }
56}
57
58impl<Provider> Debug for StageSetBuilder<Provider> {
59 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
60 f.debug_struct("StageSetBuilder")
61 .field("stages", &self.stages)
62 .field("order", &self.order)
63 .finish()
64 }
65}
66
67impl<Provider> StageSetBuilder<Provider> {
68 fn index_of(&self, stage_id: StageId) -> usize {
69 let index = self.order.iter().position(|&id| id == stage_id);
70
71 index.unwrap_or_else(|| panic!("Stage does not exist in set: {stage_id}"))
72 }
73
74 fn upsert_stage_state(&mut self, stage: Box<dyn Stage<Provider>>, added_at_index: usize) {
75 let stage_id = stage.id();
76 if self.stages.insert(stage.id(), StageEntry { stage, enabled: true }).is_some() {
77 if let Some(to_remove) = self
78 .order
79 .iter()
80 .enumerate()
81 .find(|(i, id)| *i != added_at_index && **id == stage_id)
82 .map(|(i, _)| i)
83 {
84 self.order.remove(to_remove);
85 }
86 }
87 }
88
89 pub fn set<S: Stage<Provider> + 'static>(mut self, stage: S) -> Self {
95 let entry = self
96 .stages
97 .get_mut(&stage.id())
98 .unwrap_or_else(|| panic!("Stage does not exist in set: {}", stage.id()));
99 entry.stage = Box::new(stage);
100 self
101 }
102
103 pub fn add_stage<S: Stage<Provider> + 'static>(mut self, stage: S) -> Self {
107 let target_index = self.order.len();
108 self.order.push(stage.id());
109 self.upsert_stage_state(Box::new(stage), target_index);
110 self
111 }
112
113 pub fn add_stage_opt<S: Stage<Provider> + 'static>(self, stage: Option<S>) -> Self {
117 if let Some(stage) = stage {
118 self.add_stage(stage)
119 } else {
120 self
121 }
122 }
123
124 pub fn add_set<Set: StageSet<Provider>>(mut self, set: Set) -> Self {
129 for stage in set.builder().build() {
130 let target_index = self.order.len();
131 self.order.push(stage.id());
132 self.upsert_stage_state(stage, target_index);
133 }
134 self
135 }
136
137 pub fn add_before<S: Stage<Provider> + 'static>(mut self, stage: S, before: StageId) -> Self {
145 let target_index = self.index_of(before);
146 self.order.insert(target_index, stage.id());
147 self.upsert_stage_state(Box::new(stage), target_index);
148 self
149 }
150
151 pub fn add_after<S: Stage<Provider> + 'static>(mut self, stage: S, after: StageId) -> Self {
159 let target_index = self.index_of(after) + 1;
160 self.order.insert(target_index, stage.id());
161 self.upsert_stage_state(Box::new(stage), target_index);
162 self
163 }
164
165 pub fn enable(mut self, stage_id: StageId) -> Self {
173 let entry =
174 self.stages.get_mut(&stage_id).expect("Cannot enable a stage that is not in the set.");
175 entry.enabled = true;
176 self
177 }
178
179 #[track_caller]
190 pub fn disable(mut self, stage_id: StageId) -> Self {
191 let entry = self
192 .stages
193 .get_mut(&stage_id)
194 .unwrap_or_else(|| panic!("Cannot disable a stage that is not in the set: {stage_id}"));
195 entry.enabled = false;
196 self
197 }
198
199 pub fn disable_all(mut self, stages: &[StageId]) -> Self {
203 for stage_id in stages {
204 let Some(entry) = self.stages.get_mut(stage_id) else { continue };
205 entry.enabled = false;
206 }
207 self
208 }
209
210 #[track_caller]
214 pub fn disable_if<F>(self, stage_id: StageId, f: F) -> Self
215 where
216 F: FnOnce() -> bool,
217 {
218 if f() {
219 return self.disable(stage_id)
220 }
221 self
222 }
223
224 #[track_caller]
228 pub fn disable_all_if<F>(self, stages: &[StageId], f: F) -> Self
229 where
230 F: FnOnce() -> bool,
231 {
232 if f() {
233 return self.disable_all(stages)
234 }
235 self
236 }
237
238 pub fn build(mut self) -> Vec<Box<dyn Stage<Provider>>> {
240 let mut stages = Vec::new();
241 for id in &self.order {
242 if let Some(entry) = self.stages.remove(id) {
243 if entry.enabled {
244 stages.push(entry.stage);
245 }
246 }
247 }
248 stages
249 }
250}
251
252impl<Provider> StageSet<Provider> for StageSetBuilder<Provider> {
253 fn builder(self) -> Self {
254 self
255 }
256}