reth_stages_api/pipeline/
set.rs1use crate::{Stage, StageId};
2use std::{
3 collections::HashMap,
4 fmt::{Debug, Formatter},
5};
6
7pub trait StageSet<Provider>: Sized {
14 fn builder(self) -> StageSetBuilder<Provider>;
16
17 fn set<S: Stage<Provider> + 'static>(self, stage: S) -> StageSetBuilder<Provider> {
23 self.builder().set(stage)
24 }
25}
26
27struct StageEntry<Provider> {
28 stage: Box<dyn Stage<Provider>>,
29 enabled: bool,
30}
31
32impl<Provider> Debug for StageEntry<Provider> {
33 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
34 f.debug_struct("StageEntry")
35 .field("stage", &self.stage.id())
36 .field("enabled", &self.enabled)
37 .finish()
38 }
39}
40
41pub struct StageSetBuilder<Provider> {
48 stages: HashMap<StageId, StageEntry<Provider>>,
49 order: Vec<StageId>,
50}
51
52impl<Provider> Default for StageSetBuilder<Provider> {
53 fn default() -> Self {
54 Self { stages: HashMap::default(), order: Vec::new() }
55 }
56}
57
58impl<Provider> Debug for StageSetBuilder<Provider> {
59 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
60 f.debug_struct("StageSetBuilder")
61 .field("stages", &self.stages)
62 .field("order", &self.order)
63 .finish()
64 }
65}
66
67impl<Provider> StageSetBuilder<Provider> {
68 fn index_of(&self, stage_id: StageId) -> usize {
69 let index = self.order.iter().position(|&id| id == stage_id);
70
71 index.unwrap_or_else(|| panic!("Stage does not exist in set: {stage_id}"))
72 }
73
74 fn upsert_stage_state(&mut self, stage: Box<dyn Stage<Provider>>, added_at_index: usize) {
75 let stage_id = stage.id();
76 if self.stages.insert(stage.id(), StageEntry { stage, enabled: true }).is_some() {
77 if let Some(to_remove) = self
78 .order
79 .iter()
80 .enumerate()
81 .find(|(i, id)| *i != added_at_index && **id == stage_id)
82 .map(|(i, _)| i)
83 {
84 self.order.remove(to_remove);
85 }
86 }
87 }
88
89 pub fn set<S: Stage<Provider> + 'static>(mut self, stage: S) -> Self {
95 let entry = self
96 .stages
97 .get_mut(&stage.id())
98 .unwrap_or_else(|| panic!("Stage does not exist in set: {}", stage.id()));
99 entry.stage = Box::new(stage);
100 self
101 }
102
103 pub fn stages(&self) -> impl Iterator<Item = StageId> + '_ {
106 self.order.iter().copied()
107 }
108
109 pub fn replace<S: Stage<Provider> + 'static>(mut self, stage_id: StageId, stage: S) -> Self {
114 self.stages
115 .get(&stage_id)
116 .unwrap_or_else(|| panic!("Stage does not exist in set: {stage_id}"));
117
118 if stage.id() == stage_id {
119 return self.set(stage);
120 }
121 let index = self.index_of(stage_id);
122 self.stages.remove(&stage_id);
123 self.order[index] = stage.id();
124 self.upsert_stage_state(Box::new(stage), index);
125 self
126 }
127
128 pub fn add_stage<S: Stage<Provider> + 'static>(mut self, stage: S) -> Self {
132 let target_index = self.order.len();
133 self.order.push(stage.id());
134 self.upsert_stage_state(Box::new(stage), target_index);
135 self
136 }
137
138 pub fn add_stage_opt<S: Stage<Provider> + 'static>(self, stage: Option<S>) -> Self {
142 if let Some(stage) = stage {
143 self.add_stage(stage)
144 } else {
145 self
146 }
147 }
148
149 pub fn add_set<Set: StageSet<Provider>>(mut self, set: Set) -> Self {
154 for stage in set.builder().build() {
155 let target_index = self.order.len();
156 self.order.push(stage.id());
157 self.upsert_stage_state(stage, target_index);
158 }
159 self
160 }
161
162 pub fn add_before<S: Stage<Provider> + 'static>(mut self, stage: S, before: StageId) -> Self {
170 let target_index = self.index_of(before);
171 self.order.insert(target_index, stage.id());
172 self.upsert_stage_state(Box::new(stage), target_index);
173 self
174 }
175
176 pub fn add_after<S: Stage<Provider> + 'static>(mut self, stage: S, after: StageId) -> Self {
184 let target_index = self.index_of(after) + 1;
185 self.order.insert(target_index, stage.id());
186 self.upsert_stage_state(Box::new(stage), target_index);
187 self
188 }
189
190 pub fn enable(mut self, stage_id: StageId) -> Self {
198 let entry =
199 self.stages.get_mut(&stage_id).expect("Cannot enable a stage that is not in the set.");
200 entry.enabled = true;
201 self
202 }
203
204 #[track_caller]
215 pub fn disable(mut self, stage_id: StageId) -> Self {
216 let entry = self
217 .stages
218 .get_mut(&stage_id)
219 .unwrap_or_else(|| panic!("Cannot disable a stage that is not in the set: {stage_id}"));
220 entry.enabled = false;
221 self
222 }
223
224 pub fn disable_all(mut self, stages: &[StageId]) -> Self {
228 for stage_id in stages {
229 let Some(entry) = self.stages.get_mut(stage_id) else { continue };
230 entry.enabled = false;
231 }
232 self
233 }
234
235 #[track_caller]
239 pub fn disable_if<F>(self, stage_id: StageId, f: F) -> Self
240 where
241 F: FnOnce() -> bool,
242 {
243 if f() {
244 return self.disable(stage_id)
245 }
246 self
247 }
248
249 #[track_caller]
253 pub fn disable_all_if<F>(self, stages: &[StageId], f: F) -> Self
254 where
255 F: FnOnce() -> bool,
256 {
257 if f() {
258 return self.disable_all(stages)
259 }
260 self
261 }
262
263 pub fn build(mut self) -> Vec<Box<dyn Stage<Provider>>> {
265 let mut stages = Vec::new();
266 for id in &self.order {
267 if let Some(entry) = self.stages.remove(id) {
268 if entry.enabled {
269 stages.push(entry.stage);
270 }
271 }
272 }
273 stages
274 }
275}
276
277impl<Provider> StageSet<Provider> for StageSetBuilder<Provider> {
278 fn builder(self) -> Self {
279 self
280 }
281}