reth_beacon_consensus/engine/hooks/
controller.rs1use crate::hooks::{
2 EngineHook, EngineHookContext, EngineHookDBAccessLevel, EngineHookError, EngineHookEvent,
3 EngineHooks,
4};
5use std::{
6 collections::VecDeque,
7 task::{Context, Poll},
8};
9use tracing::debug;
10
/// Result of polling an engine hook that returned an event.
#[derive(Debug)]
pub(crate) struct PolledHook {
    /// Name of the hook, as reported by its `name()` method.
    pub(crate) name: &'static str,
    /// Event the hook returned when it was polled.
    pub(crate) event: EngineHookEvent,
    /// Database access level the hook reported via `db_access_level()`.
    pub(crate) db_access_level: EngineHookDBAccessLevel,
}
17
/// Polls engine hooks one at a time and ensures that at most one read-write hook is running
/// (holding database write access) at any moment.
pub(crate) struct EngineHooksController {
    /// Hooks waiting to be polled. `poll_next_hook` pops from the front and, unless the hook
    /// becomes the active write hook, pushes it to the back again (round-robin order).
    hooks: VecDeque<Box<dyn EngineHook>>,
    /// The read-write hook that has started and not yet reported `Finished`. While this is set,
    /// no other read-write hook is polled by `poll_next_hook`. A hook stored here is not in
    /// `hooks` at the same time.
    active_db_write_hook: Option<Box<dyn EngineHook>>,
}
32
33impl EngineHooksController {
34 pub(crate) fn new(hooks: EngineHooks) -> Self {
36 Self { hooks: hooks.inner.into(), active_db_write_hook: None }
37 }
38
39 pub(crate) fn poll_active_db_write_hook(
50 &mut self,
51 cx: &mut Context<'_>,
52 args: EngineHookContext,
53 ) -> Poll<Result<PolledHook, EngineHookError>> {
54 let Some(mut hook) = self.active_db_write_hook.take() else { return Poll::Pending };
55
56 match hook.poll(cx, args)? {
57 Poll::Ready(event) => {
58 let result = PolledHook {
59 name: hook.name(),
60 event,
61 db_access_level: hook.db_access_level(),
62 };
63
64 debug!(
65 target: "consensus::engine::hooks",
66 hook = hook.name(),
67 ?result,
68 "Polled running hook with db write access"
69 );
70
71 if result.event.is_finished() {
72 self.hooks.push_back(hook);
73 } else {
74 self.active_db_write_hook = Some(hook);
75 }
76
77 return Poll::Ready(Ok(result))
78 }
79 Poll::Pending => {
80 self.active_db_write_hook = Some(hook);
81 }
82 }
83
84 Poll::Pending
85 }
86
87 pub(crate) fn poll_next_hook(
99 &mut self,
100 cx: &mut Context<'_>,
101 args: EngineHookContext,
102 db_write_active: bool,
103 ) -> Poll<Result<PolledHook, EngineHookError>> {
104 let Some(mut hook) = self.hooks.pop_front() else { return Poll::Pending };
105
106 let result = self.poll_next_hook_inner(cx, &mut hook, args, db_write_active);
107
108 if matches!(
109 result,
110 Poll::Ready(Ok(PolledHook {
111 event: EngineHookEvent::Started,
112 db_access_level: EngineHookDBAccessLevel::ReadWrite,
113 ..
114 }))
115 ) {
116 self.active_db_write_hook = Some(hook);
118 } else {
119 self.hooks.push_back(hook);
121 }
122
123 result
124 }
125
126 fn poll_next_hook_inner(
127 &self,
128 cx: &mut Context<'_>,
129 hook: &mut Box<dyn EngineHook>,
130 args: EngineHookContext,
131 db_write_active: bool,
132 ) -> Poll<Result<PolledHook, EngineHookError>> {
133 if hook.db_access_level().is_read_write() &&
140 (self.active_db_write_hook.is_some() ||
141 db_write_active ||
142 args.finalized_block_number.is_none())
143 {
144 return Poll::Pending
145 }
146
147 if let Poll::Ready(event) = hook.poll(cx, args)? {
148 let result =
149 PolledHook { name: hook.name(), event, db_access_level: hook.db_access_level() };
150
151 debug!(
152 target: "consensus::engine::hooks",
153 hook = hook.name(),
154 ?result,
155 "Polled next hook"
156 );
157
158 return Poll::Ready(Ok(result))
159 }
160 debug!(target: "consensus::engine::hooks", hook = hook.name(), "Next hook is not ready");
161
162 Poll::Pending
163 }
164
165 pub(crate) fn active_db_write_hook(&self) -> Option<&dyn EngineHook> {
167 self.active_db_write_hook.as_ref().map(|hook| hook.as_ref())
168 }
169}
170
#[cfg(test)]
mod tests {
    use crate::hooks::{
        EngineHook, EngineHookContext, EngineHookDBAccessLevel, EngineHookEvent, EngineHooks,
        EngineHooksController,
    };
    use futures::poll;
    use reth_errors::{RethError, RethResult};
    use std::{
        collections::VecDeque,
        future::poll_fn,
        task::{Context, Poll},
    };

    /// Scripted hook: each `poll` hands out the next queued result and reports `Pending` once
    /// the script is exhausted.
    struct TestHook {
        results: VecDeque<RethResult<EngineHookEvent>>,
        name: &'static str,
        access_level: EngineHookDBAccessLevel,
    }

    impl TestHook {
        /// Builds a hook with an empty script and the given access level.
        fn with_access_level(name: &'static str, access_level: EngineHookDBAccessLevel) -> Self {
            Self { results: VecDeque::new(), name, access_level }
        }

        fn new_ro(name: &'static str) -> Self {
            Self::with_access_level(name, EngineHookDBAccessLevel::ReadOnly)
        }

        fn new_rw(name: &'static str) -> Self {
            Self::with_access_level(name, EngineHookDBAccessLevel::ReadWrite)
        }

        /// Appends `result` to the script returned by successive `poll` calls.
        fn add_result(&mut self, result: RethResult<EngineHookEvent>) {
            self.results.push_back(result);
        }
    }

    impl EngineHook for TestHook {
        fn name(&self) -> &'static str {
            self.name
        }

        fn poll(
            &mut self,
            _cx: &mut Context<'_>,
            _ctx: EngineHookContext,
        ) -> Poll<RethResult<EngineHookEvent>> {
            match self.results.pop_front() {
                Some(scripted) => Poll::Ready(scripted),
                None => Poll::Pending,
            }
        }

        fn db_access_level(&self) -> EngineHookDBAccessLevel {
            self.access_level
        }
    }

    /// Name of the hook at the front of the controller's polling queue.
    fn front_name(controller: &EngineHooksController) -> Option<&'static str> {
        controller.hooks.front().map(|hook| hook.name())
    }

    #[tokio::test]
    async fn poll_active_db_write_hook() {
        let context = EngineHookContext { tip_block_number: 2, finalized_block_number: Some(1) };
        let mut controller = EngineHooksController::new(EngineHooks::new());

        // Nothing is active yet.
        let polled = poll!(poll_fn(|cx| controller.poll_active_db_write_hook(cx, context)));
        assert!(polled.is_pending());

        // An active hook with an empty script stays pending.
        controller.active_db_write_hook = Some(Box::new(TestHook::new_rw("read-write")));
        let polled = poll!(poll_fn(|cx| controller.poll_active_db_write_hook(cx, context)));
        assert!(polled.is_pending());

        // `Started` is reported and the hook keeps write access.
        let mut started = TestHook::new_rw("read-write");
        started.add_result(Ok(EngineHookEvent::Started));
        controller.active_db_write_hook = Some(Box::new(started));
        let polled = poll!(poll_fn(|cx| controller.poll_active_db_write_hook(cx, context)));
        assert!(matches!(
            &polled,
            Poll::Ready(Ok(hook))
                if hook.event.is_started() && hook.db_access_level.is_read_write()
        ));
        assert!(controller.active_db_write_hook.is_some());
        assert!(controller.hooks.is_empty());

        // `Finished` releases write access and requeues the hook.
        let mut finished = TestHook::new_rw("read-write");
        finished.add_result(Ok(EngineHookEvent::Finished(Ok(()))));
        controller.active_db_write_hook = Some(Box::new(finished));
        let polled = poll!(poll_fn(|cx| controller.poll_active_db_write_hook(cx, context)));
        assert!(matches!(
            &polled,
            Poll::Ready(Ok(hook))
                if hook.event.is_finished() && hook.db_access_level.is_read_write()
        ));
        assert!(controller.active_db_write_hook.is_none());
        assert!(controller.hooks.pop_front().is_some());
    }

    #[tokio::test]
    async fn poll_next_hook_db_write_active() {
        let context = EngineHookContext { tip_block_number: 2, finalized_block_number: Some(1) };
        let ro_name = "read-only";

        let mut rw = TestHook::new_rw("read-write");
        rw.add_result(Ok(EngineHookEvent::Started));
        let mut ro = TestHook::new_ro(ro_name);
        ro.add_result(Ok(EngineHookEvent::Started));

        let mut hooks = EngineHooks::new();
        hooks.add(rw);
        hooks.add(ro);
        let mut controller = EngineHooksController::new(hooks);

        // The read-write hook is first in line but is skipped while db writes are active.
        let polled = poll!(poll_fn(|cx| controller.poll_next_hook(cx, context, true)));
        assert!(polled.is_pending());
        assert!(controller.active_db_write_hook.is_none());

        // The read-only hook runs regardless.
        let polled = poll!(poll_fn(|cx| controller.poll_next_hook(cx, context, true)));
        assert!(matches!(
            &polled,
            Poll::Ready(Ok(hook)) if hook.name == ro_name &&
                hook.event.is_started() &&
                hook.db_access_level.is_read_only()
        ));
    }

    #[tokio::test]
    async fn poll_next_hook_db_write_inactive() {
        let context = EngineHookContext { tip_block_number: 2, finalized_block_number: Some(1) };
        let (rw_1_name, rw_2_name, ro_name) = ("read-write-1", "read-write-2", "read-only");

        let mut rw_1 = TestHook::new_rw(rw_1_name);
        rw_1.add_result(Ok(EngineHookEvent::Started));
        let mut rw_2 = TestHook::new_rw(rw_2_name);
        rw_2.add_result(Ok(EngineHookEvent::Started));
        let mut ro = TestHook::new_ro(ro_name);
        ro.add_result(Ok(EngineHookEvent::Started));
        ro.add_result(Err(RethError::msg("something went wrong")));

        let mut hooks = EngineHooks::new();
        hooks.add(rw_1);
        hooks.add(rw_2);
        hooks.add(ro);

        let mut controller = EngineHooksController::new(hooks);
        let initial_len = controller.hooks.len();

        // The first read-write hook starts and takes the write slot.
        assert_eq!(front_name(&controller), Some(rw_1_name));
        let polled = poll!(poll_fn(|cx| controller.poll_next_hook(cx, context, false)));
        assert!(matches!(
            &polled,
            Poll::Ready(Ok(hook)) if hook.name == rw_1_name &&
                hook.event.is_started() &&
                hook.db_access_level.is_read_write()
        ));
        assert_eq!(
            controller.active_db_write_hook.as_ref().map(|hook| hook.name()),
            Some(rw_1_name)
        );

        // The second read-write hook is blocked by the active one.
        assert_eq!(front_name(&controller), Some(rw_2_name));
        let polled = poll!(poll_fn(|cx| controller.poll_next_hook(cx, context, false)));
        assert!(polled.is_pending());

        // The read-only hook starts.
        assert_eq!(front_name(&controller), Some(ro_name));
        let polled = poll!(poll_fn(|cx| controller.poll_next_hook(cx, context, false)));
        assert!(matches!(
            &polled,
            Poll::Ready(Ok(hook)) if hook.name == ro_name &&
                hook.event.is_started() &&
                hook.db_access_level.is_read_only()
        ));

        // Next rotation: the second read-write hook is still blocked.
        assert_eq!(front_name(&controller), Some(rw_2_name));
        let polled = poll!(poll_fn(|cx| controller.poll_next_hook(cx, context, false)));
        assert!(polled.is_pending());

        // The read-only hook now surfaces its scripted error.
        assert_eq!(front_name(&controller), Some(ro_name));
        let polled = poll!(poll_fn(|cx| controller.poll_next_hook(cx, context, false)));
        assert!(matches!(polled, Poll::Ready(Err(_))));

        // Only the active hook is held aside; the rest, including the failed one, stay queued.
        assert!(controller.active_db_write_hook.is_some());
        assert_eq!(controller.hooks.len(), initial_len - 1);
    }
}