1#![doc(
8 html_logo_url = "https://raw.githubusercontent.com/paradigmxyz/reth/main/assets/reth-docs.png",
9 html_favicon_url = "https://avatars0.githubusercontent.com/u/97369466?s=256",
10 issue_tracker_base_url = "https://github.com/SeismicSystems/seismic-reth/issues/"
11)]
12#![cfg_attr(not(test), warn(unused_crate_dependencies))]
13#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
14
15use crate::{
16 metrics::{IncCounterOnDrop, TaskExecutorMetrics},
17 shutdown::{signal, GracefulShutdown, GracefulShutdownGuard, Shutdown, Signal},
18};
19use dyn_clone::DynClone;
20use futures_util::{
21 future::{select, BoxFuture},
22 Future, FutureExt, TryFutureExt,
23};
24use std::{
25 any::Any,
26 fmt::{Display, Formatter},
27 pin::{pin, Pin},
28 sync::{
29 atomic::{AtomicUsize, Ordering},
30 Arc, OnceLock,
31 },
32 task::{ready, Context, Poll},
33};
34use tokio::{
35 runtime::Handle,
36 sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
37 task::JoinHandle,
38};
39use tracing::{debug, error};
40use tracing_futures::Instrument;
41
42pub mod metrics;
43pub mod shutdown;
44
45#[cfg(feature = "rayon")]
46pub mod pool;
47
48static GLOBAL_EXECUTOR: OnceLock<TaskExecutor> = OnceLock::new();
50
51#[auto_impl::auto_impl(&, Arc)]
91pub trait TaskSpawner: Send + Sync + Unpin + std::fmt::Debug + DynClone {
92 fn spawn(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()>;
95
96 fn spawn_critical(&self, name: &'static str, fut: BoxFuture<'static, ()>) -> JoinHandle<()>;
98
99 fn spawn_blocking(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()>;
101
102 fn spawn_critical_blocking(
104 &self,
105 name: &'static str,
106 fut: BoxFuture<'static, ()>,
107 ) -> JoinHandle<()>;
108}
109
110dyn_clone::clone_trait_object!(TaskSpawner);
111
112#[derive(Debug, Clone, Default)]
114#[non_exhaustive]
115pub struct TokioTaskExecutor;
116
117impl TokioTaskExecutor {
118 pub fn boxed(self) -> Box<dyn TaskSpawner + 'static> {
120 Box::new(self)
121 }
122}
123
124impl TaskSpawner for TokioTaskExecutor {
125 fn spawn(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
126 tokio::task::spawn(fut)
127 }
128
129 fn spawn_critical(&self, _name: &'static str, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
130 tokio::task::spawn(fut)
131 }
132
133 fn spawn_blocking(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
134 tokio::task::spawn_blocking(move || tokio::runtime::Handle::current().block_on(fut))
135 }
136
137 fn spawn_critical_blocking(
138 &self,
139 _name: &'static str,
140 fut: BoxFuture<'static, ()>,
141 ) -> JoinHandle<()> {
142 tokio::task::spawn_blocking(move || tokio::runtime::Handle::current().block_on(fut))
143 }
144}
145
146#[derive(Debug)]
159#[must_use = "TaskManager must be polled to monitor critical tasks"]
160pub struct TaskManager {
161 handle: Handle,
165 task_events_tx: UnboundedSender<TaskEvent>,
167 task_events_rx: UnboundedReceiver<TaskEvent>,
169 signal: Option<Signal>,
173 on_shutdown: Shutdown,
175 graceful_tasks: Arc<AtomicUsize>,
177}
178
179impl TaskManager {
182 pub fn current() -> Self {
192 let handle = Handle::current();
193 Self::new(handle)
194 }
195
196 pub fn new(handle: Handle) -> Self {
200 let (task_events_tx, task_events_rx) = unbounded_channel();
201 let (signal, on_shutdown) = signal();
202 let manager = Self {
203 handle,
204 task_events_tx,
205 task_events_rx,
206 signal: Some(signal),
207 on_shutdown,
208 graceful_tasks: Arc::new(AtomicUsize::new(0)),
209 };
210
211 let _ = GLOBAL_EXECUTOR
212 .set(manager.executor())
213 .inspect_err(|_| error!("Global executor already set"));
214
215 manager
216 }
217
218 pub fn executor(&self) -> TaskExecutor {
221 TaskExecutor {
222 handle: self.handle.clone(),
223 on_shutdown: self.on_shutdown.clone(),
224 task_events_tx: self.task_events_tx.clone(),
225 metrics: Default::default(),
226 graceful_tasks: Arc::clone(&self.graceful_tasks),
227 }
228 }
229
230 pub fn graceful_shutdown(self) {
232 let _ = self.do_graceful_shutdown(None);
233 }
234
235 pub fn graceful_shutdown_with_timeout(self, timeout: std::time::Duration) -> bool {
239 self.do_graceful_shutdown(Some(timeout))
240 }
241
242 fn do_graceful_shutdown(self, timeout: Option<std::time::Duration>) -> bool {
243 drop(self.signal);
244 let when = timeout.map(|t| std::time::Instant::now() + t);
245 while self.graceful_tasks.load(Ordering::Relaxed) > 0 {
246 if when.map(|when| std::time::Instant::now() > when).unwrap_or(false) {
247 debug!("graceful shutdown timed out");
248 return false
249 }
250 std::hint::spin_loop();
251 }
252
253 debug!("gracefully shut down");
254 true
255 }
256}
257
258impl Future for TaskManager {
262 type Output = Result<(), PanickedTaskError>;
263
264 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
265 match ready!(self.as_mut().get_mut().task_events_rx.poll_recv(cx)) {
266 Some(TaskEvent::Panic(err)) => Poll::Ready(Err(err)),
267 Some(TaskEvent::GracefulShutdown) | None => {
268 if let Some(signal) = self.get_mut().signal.take() {
269 signal.fire();
270 }
271 Poll::Ready(Ok(()))
272 }
273 }
274 }
275}
276
277#[derive(Debug, thiserror::Error, PartialEq, Eq)]
279pub struct PanickedTaskError {
280 task_name: &'static str,
281 error: Option<String>,
282}
283
284impl Display for PanickedTaskError {
285 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
286 let task_name = self.task_name;
287 if let Some(error) = &self.error {
288 write!(f, "Critical task `{task_name}` panicked: `{error}`")
289 } else {
290 write!(f, "Critical task `{task_name}` panicked")
291 }
292 }
293}
294
295impl PanickedTaskError {
296 fn new(task_name: &'static str, error: Box<dyn Any>) -> Self {
297 let error = match error.downcast::<String>() {
298 Ok(value) => Some(*value),
299 Err(error) => match error.downcast::<&str>() {
300 Ok(value) => Some(value.to_string()),
301 Err(_) => None,
302 },
303 };
304
305 Self { task_name, error }
306 }
307}
308
309#[derive(Debug)]
311enum TaskEvent {
312 Panic(PanickedTaskError),
314 GracefulShutdown,
316}
317
318#[derive(Debug, Clone)]
320pub struct TaskExecutor {
321 handle: Handle,
325 on_shutdown: Shutdown,
327 task_events_tx: UnboundedSender<TaskEvent>,
329 metrics: TaskExecutorMetrics,
331 graceful_tasks: Arc<AtomicUsize>,
333}
334
335impl TaskExecutor {
338 pub fn try_current() -> Result<Self, NoCurrentTaskExecutorError> {
342 GLOBAL_EXECUTOR.get().cloned().ok_or_else(NoCurrentTaskExecutorError::default)
343 }
344
345 pub fn current() -> Self {
352 Self::try_current().unwrap()
353 }
354
355 pub const fn handle(&self) -> &Handle {
357 &self.handle
358 }
359
360 pub const fn on_shutdown_signal(&self) -> &Shutdown {
362 &self.on_shutdown
363 }
364
365 fn spawn_on_rt<F>(&self, fut: F, task_kind: TaskKind) -> JoinHandle<()>
367 where
368 F: Future<Output = ()> + Send + 'static,
369 {
370 match task_kind {
371 TaskKind::Default => self.handle.spawn(fut),
372 TaskKind::Blocking => {
373 let handle = self.handle.clone();
374 self.handle.spawn_blocking(move || handle.block_on(fut))
375 }
376 }
377 }
378
379 fn spawn_task_as<F>(&self, fut: F, task_kind: TaskKind) -> JoinHandle<()>
381 where
382 F: Future<Output = ()> + Send + 'static,
383 {
384 let on_shutdown = self.on_shutdown.clone();
385
386 let finished_regular_tasks_total_metrics =
388 self.metrics.finished_regular_tasks_total.clone();
389 let task = {
391 async move {
392 let _inc_counter_on_drop =
394 IncCounterOnDrop::new(finished_regular_tasks_total_metrics);
395 let fut = pin!(fut);
396 let _ = select(on_shutdown, fut).await;
397 }
398 }
399 .in_current_span();
400
401 self.spawn_on_rt(task, task_kind)
402 }
403
404 pub fn spawn<F>(&self, fut: F) -> JoinHandle<()>
409 where
410 F: Future<Output = ()> + Send + 'static,
411 {
412 self.spawn_task_as(fut, TaskKind::Default)
413 }
414
415 pub fn spawn_blocking<F>(&self, fut: F) -> JoinHandle<()>
420 where
421 F: Future<Output = ()> + Send + 'static,
422 {
423 self.spawn_task_as(fut, TaskKind::Blocking)
424 }
425
426 pub fn spawn_with_signal<F>(&self, f: impl FnOnce(Shutdown) -> F) -> JoinHandle<()>
431 where
432 F: Future<Output = ()> + Send + 'static,
433 {
434 let on_shutdown = self.on_shutdown.clone();
435 let fut = f(on_shutdown);
436
437 let task = fut.in_current_span();
438
439 self.handle.spawn(task)
440 }
441
442 fn spawn_critical_as<F>(
444 &self,
445 name: &'static str,
446 fut: F,
447 task_kind: TaskKind,
448 ) -> JoinHandle<()>
449 where
450 F: Future<Output = ()> + Send + 'static,
451 {
452 let panicked_tasks_tx = self.task_events_tx.clone();
453 let on_shutdown = self.on_shutdown.clone();
454
455 let task = std::panic::AssertUnwindSafe(fut)
457 .catch_unwind()
458 .map_err(move |error| {
459 let task_error = PanickedTaskError::new(name, error);
460 error!("{task_error}");
461 let _ = panicked_tasks_tx.send(TaskEvent::Panic(task_error));
462 })
463 .in_current_span();
464
465 let finished_critical_tasks_total_metrics =
467 self.metrics.finished_critical_tasks_total.clone();
468 let task = async move {
469 let _inc_counter_on_drop = IncCounterOnDrop::new(finished_critical_tasks_total_metrics);
471 let task = pin!(task);
472 let _ = select(on_shutdown, task).await;
473 };
474
475 self.spawn_on_rt(task, task_kind)
476 }
477
478 pub fn spawn_critical_blocking<F>(&self, name: &'static str, fut: F) -> JoinHandle<()>
483 where
484 F: Future<Output = ()> + Send + 'static,
485 {
486 self.spawn_critical_as(name, fut, TaskKind::Blocking)
487 }
488
489 pub fn spawn_critical<F>(&self, name: &'static str, fut: F) -> JoinHandle<()>
494 where
495 F: Future<Output = ()> + Send + 'static,
496 {
497 self.spawn_critical_as(name, fut, TaskKind::Default)
498 }
499
500 pub fn spawn_critical_with_shutdown_signal<F>(
504 &self,
505 name: &'static str,
506 f: impl FnOnce(Shutdown) -> F,
507 ) -> JoinHandle<()>
508 where
509 F: Future<Output = ()> + Send + 'static,
510 {
511 let panicked_tasks_tx = self.task_events_tx.clone();
512 let on_shutdown = self.on_shutdown.clone();
513 let fut = f(on_shutdown);
514
515 let task = std::panic::AssertUnwindSafe(fut)
517 .catch_unwind()
518 .map_err(move |error| {
519 let task_error = PanickedTaskError::new(name, error);
520 error!("{task_error}");
521 let _ = panicked_tasks_tx.send(TaskEvent::Panic(task_error));
522 })
523 .map(drop)
524 .in_current_span();
525
526 self.handle.spawn(task)
527 }
528
529 pub fn spawn_critical_with_graceful_shutdown_signal<F>(
550 &self,
551 name: &'static str,
552 f: impl FnOnce(GracefulShutdown) -> F,
553 ) -> JoinHandle<()>
554 where
555 F: Future<Output = ()> + Send + 'static,
556 {
557 let panicked_tasks_tx = self.task_events_tx.clone();
558 let on_shutdown = GracefulShutdown::new(
559 self.on_shutdown.clone(),
560 GracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)),
561 );
562 let fut = f(on_shutdown);
563
564 let task = std::panic::AssertUnwindSafe(fut)
566 .catch_unwind()
567 .map_err(move |error| {
568 let task_error = PanickedTaskError::new(name, error);
569 error!("{task_error}");
570 let _ = panicked_tasks_tx.send(TaskEvent::Panic(task_error));
571 })
572 .map(drop)
573 .in_current_span();
574
575 self.handle.spawn(task)
576 }
577
578 pub fn spawn_with_graceful_shutdown_signal<F>(
598 &self,
599 f: impl FnOnce(GracefulShutdown) -> F,
600 ) -> JoinHandle<()>
601 where
602 F: Future<Output = ()> + Send + 'static,
603 {
604 let on_shutdown = GracefulShutdown::new(
605 self.on_shutdown.clone(),
606 GracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)),
607 );
608 let fut = f(on_shutdown);
609
610 self.handle.spawn(fut)
611 }
612
613 pub fn initiate_graceful_shutdown(
620 &self,
621 ) -> Result<GracefulShutdown, tokio::sync::mpsc::error::SendError<()>> {
622 self.task_events_tx
623 .send(TaskEvent::GracefulShutdown)
624 .map_err(|_send_error_with_task_event| tokio::sync::mpsc::error::SendError(()))?;
625
626 Ok(GracefulShutdown::new(
627 self.on_shutdown.clone(),
628 GracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)),
629 ))
630 }
631}
632
633impl TaskSpawner for TaskExecutor {
634 fn spawn(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
635 self.metrics.inc_regular_tasks();
636 self.spawn(fut)
637 }
638
639 fn spawn_critical(&self, name: &'static str, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
640 self.metrics.inc_critical_tasks();
641 Self::spawn_critical(self, name, fut)
642 }
643
644 fn spawn_blocking(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
645 self.metrics.inc_regular_tasks();
646 self.spawn_blocking(fut)
647 }
648
649 fn spawn_critical_blocking(
650 &self,
651 name: &'static str,
652 fut: BoxFuture<'static, ()>,
653 ) -> JoinHandle<()> {
654 self.metrics.inc_critical_tasks();
655 Self::spawn_critical_blocking(self, name, fut)
656 }
657}
658
659#[auto_impl::auto_impl(&, Arc)]
661pub trait TaskSpawnerExt: Send + Sync + Unpin + std::fmt::Debug + DynClone {
662 fn spawn_critical_with_graceful_shutdown_signal<F>(
667 &self,
668 name: &'static str,
669 f: impl FnOnce(GracefulShutdown) -> F,
670 ) -> JoinHandle<()>
671 where
672 F: Future<Output = ()> + Send + 'static;
673
674 fn spawn_with_graceful_shutdown_signal<F>(
678 &self,
679 f: impl FnOnce(GracefulShutdown) -> F,
680 ) -> JoinHandle<()>
681 where
682 F: Future<Output = ()> + Send + 'static;
683}
684
685impl TaskSpawnerExt for TaskExecutor {
686 fn spawn_critical_with_graceful_shutdown_signal<F>(
687 &self,
688 name: &'static str,
689 f: impl FnOnce(GracefulShutdown) -> F,
690 ) -> JoinHandle<()>
691 where
692 F: Future<Output = ()> + Send + 'static,
693 {
694 Self::spawn_critical_with_graceful_shutdown_signal(self, name, f)
695 }
696
697 fn spawn_with_graceful_shutdown_signal<F>(
698 &self,
699 f: impl FnOnce(GracefulShutdown) -> F,
700 ) -> JoinHandle<()>
701 where
702 F: Future<Output = ()> + Send + 'static,
703 {
704 Self::spawn_with_graceful_shutdown_signal(self, f)
705 }
706}
707
/// Selects how [`TaskExecutor`] places a task on the runtime.
enum TaskKind {
    /// Spawned with the runtime handle's `spawn`.
    Default,
    /// Spawned with the runtime handle's `spawn_blocking` and driven with `block_on`.
    Blocking,
}
715
716#[derive(Debug, Default, thiserror::Error)]
718#[error("No current task executor available.")]
719#[non_exhaustive]
720pub struct NoCurrentTaskExecutorError;
721
#[cfg(test)]
mod tests {
    use super::*;
    use std::{sync::atomic::AtomicBool, time::Duration};

    /// A boxed `TaskSpawner` can be cloned and stored in a `Clone` type.
    #[test]
    fn test_cloneable() {
        #[derive(Clone)]
        struct ExecutorWrapper {
            _e: Box<dyn TaskSpawner>,
        }

        let executor: Box<dyn TaskSpawner> = Box::<TokioTaskExecutor>::default();
        let _e = dyn_clone::clone_box(&*executor);

        let e = ExecutorWrapper { _e };
        let _e2 = e;
    }

    /// A panicking critical task resolves the manager with the task name and panic message.
    #[test]
    fn test_critical() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle);
        let executor = manager.executor();

        executor.spawn_critical("this is a critical task", async { panic!("intentionally panic") });

        runtime.block_on(async move {
            let err_result = manager.await;
            assert!(err_result.is_err(), "Expected TaskManager to return an error due to panic");
            let panicked_err = err_result.unwrap_err();

            assert_eq!(panicked_err.task_name, "this is a critical task");
            assert_eq!(panicked_err.error, Some("intentionally panic".to_string()));
        })
    }

    /// Dropping the manager stops a critical task; the `signal` it owns is dropped with it,
    /// which resolves `shutdown`.
    #[test]
    fn test_manager_shutdown_critical() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle.clone());
        let executor = manager.executor();

        let (signal, shutdown) = signal();

        executor.spawn_critical("this is a critical task", async move {
            tokio::time::sleep(Duration::from_millis(200)).await;
            drop(signal);
        });

        drop(manager);

        handle.block_on(shutdown);
    }

    /// Dropping the manager stops a regular task; the `signal` it owns is dropped with it,
    /// which resolves `shutdown`.
    #[test]
    fn test_manager_shutdown() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle.clone());
        let executor = manager.executor();

        let (signal, shutdown) = signal();

        executor.spawn(Box::pin(async move {
            tokio::time::sleep(Duration::from_millis(200)).await;
            drop(signal);
        }));

        drop(manager);

        handle.block_on(shutdown);
    }

    /// A graceful shutdown waits for the task holding the guard to finish.
    #[test]
    fn test_manager_graceful_shutdown() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle);
        let executor = manager.executor();

        let val = Arc::new(AtomicBool::new(false));
        let c = val.clone();
        executor.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move {
            let _guard = shutdown.await;
            tokio::time::sleep(Duration::from_millis(200)).await;
            c.store(true, Ordering::Relaxed);
        });

        manager.graceful_shutdown();
        assert!(val.load(Ordering::Relaxed));
    }

    /// A graceful shutdown waits for every task holding a guard.
    #[test]
    fn test_manager_graceful_shutdown_many() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle);
        let executor = manager.executor();

        let counter = Arc::new(AtomicUsize::new(0));
        let num = 10;
        for _ in 0..num {
            let c = counter.clone();
            executor.spawn_critical_with_graceful_shutdown_signal(
                "grace",
                move |shutdown| async move {
                    let _guard = shutdown.await;
                    tokio::time::sleep(Duration::from_millis(200)).await;
                    c.fetch_add(1, Ordering::SeqCst);
                },
            );
        }

        manager.graceful_shutdown();
        assert_eq!(counter.load(Ordering::Relaxed), num);
    }

    /// A graceful shutdown with a timeout gives up on a task that takes longer than the timeout.
    #[test]
    fn test_manager_graceful_shutdown_timeout() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle);
        let executor = manager.executor();

        let timeout = Duration::from_millis(500);
        let val = Arc::new(AtomicBool::new(false));
        let val2 = val.clone();
        executor.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move {
            let _guard = shutdown.await;
            tokio::time::sleep(timeout * 3).await;
            val2.store(true, Ordering::Relaxed);
            unreachable!("should not be reached");
        });

        // Without checking the result, an immediate `true` return (nothing waited for) would be
        // indistinguishable from a real timeout.
        let completed = manager.graceful_shutdown_with_timeout(timeout);
        assert!(!completed, "graceful shutdown should report a timeout");
        assert!(!val.load(Ordering::Relaxed));
    }

    /// Creating a manager makes a global executor available.
    #[test]
    fn can_access_global() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let _manager = TaskManager::new(handle);
        let _executor = TaskExecutor::try_current().unwrap();
    }

    /// A graceful shutdown requested through an executor resolves the manager with `Ok(())`
    /// and delivers the shutdown signal to spawned tasks.
    #[test]
    fn test_graceful_shutdown_triggered_by_executor() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let task_manager = TaskManager::new(runtime.handle().clone());
        let executor = task_manager.executor();

        let task_did_shutdown_flag = Arc::new(AtomicBool::new(false));
        let flag_clone = task_did_shutdown_flag.clone();

        let spawned_task_handle = executor.spawn_with_signal(|shutdown_signal| async move {
            shutdown_signal.await;
            flag_clone.store(true, Ordering::SeqCst);
        });

        let manager_future_handle = runtime.spawn(task_manager);

        let send_result = executor.initiate_graceful_shutdown();
        assert!(send_result.is_ok(), "Sending the graceful shutdown signal should succeed and return a GracefulShutdown future");

        let manager_final_result = runtime.block_on(manager_future_handle);

        assert!(manager_final_result.is_ok(), "TaskManager task should not panic");
        assert_eq!(
            manager_final_result.unwrap(),
            Ok(()),
            "TaskManager should resolve cleanly with Ok(()) after graceful shutdown request"
        );

        let task_join_result = runtime.block_on(spawned_task_handle);
        assert!(task_join_result.is_ok(), "Spawned task should complete without panic");

        assert!(
            task_did_shutdown_flag.load(Ordering::Relaxed),
            "Task should have received the shutdown signal and set the flag"
        );
    }
}