reth_tasks/
lib.rs

1//! Reth task management.
2//!
3//! # Feature Flags
4//!
5//! - `rayon`: Enable rayon thread pool for blocking tasks.
6
7#![doc(
8    html_logo_url = "https://raw.githubusercontent.com/paradigmxyz/reth/main/assets/reth-docs.png",
9    html_favicon_url = "https://avatars0.githubusercontent.com/u/97369466?s=256",
10    issue_tracker_base_url = "https://github.com/SeismicSystems/seismic-reth/issues/"
11)]
12#![cfg_attr(not(test), warn(unused_crate_dependencies))]
13#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
14
15use crate::{
16    metrics::{IncCounterOnDrop, TaskExecutorMetrics},
17    shutdown::{signal, GracefulShutdown, GracefulShutdownGuard, Shutdown, Signal},
18};
19use dyn_clone::DynClone;
20use futures_util::{
21    future::{select, BoxFuture},
22    Future, FutureExt, TryFutureExt,
23};
24use std::{
25    any::Any,
26    fmt::{Display, Formatter},
27    pin::{pin, Pin},
28    sync::{
29        atomic::{AtomicUsize, Ordering},
30        Arc, OnceLock,
31    },
32    task::{ready, Context, Poll},
33};
34use tokio::{
35    runtime::Handle,
36    sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
37    task::JoinHandle,
38};
39use tracing::{debug, error};
40use tracing_futures::Instrument;
41
42pub mod metrics;
43pub mod shutdown;
44
45#[cfg(feature = "rayon")]
46pub mod pool;
47
48/// Global [`TaskExecutor`] instance that can be accessed from anywhere.
49static GLOBAL_EXECUTOR: OnceLock<TaskExecutor> = OnceLock::new();
50
51/// A type that can spawn tasks.
52///
53/// The main purpose of this type is to abstract over [`TaskExecutor`] so it's more convenient to
54/// provide default impls for testing.
55///
56///
57/// # Examples
58///
59/// Use the [`TokioTaskExecutor`] that spawns with [`tokio::task::spawn`]
60///
61/// ```
62/// # async fn t() {
63/// use reth_tasks::{TaskSpawner, TokioTaskExecutor};
64/// let executor = TokioTaskExecutor::default();
65///
66/// let task = executor.spawn(Box::pin(async {
67///     // -- snip --
68/// }));
69/// task.await.unwrap();
70/// # }
71/// ```
72///
73/// Use the [`TaskExecutor`] that spawns task directly onto the tokio runtime via the [Handle].
74///
75/// ```
76/// # use reth_tasks::TaskManager;
77/// fn t() {
78///  use reth_tasks::TaskSpawner;
79/// let rt = tokio::runtime::Runtime::new().unwrap();
80/// let manager = TaskManager::new(rt.handle().clone());
81/// let executor = manager.executor();
82/// let task = TaskSpawner::spawn(&executor, Box::pin(async {
83///     // -- snip --
84/// }));
85/// rt.block_on(task).unwrap();
86/// # }
87/// ```
88///
89/// The [`TaskSpawner`] trait is [`DynClone`] so `Box<dyn TaskSpawner>` are also `Clone`.
90#[auto_impl::auto_impl(&, Arc)]
91pub trait TaskSpawner: Send + Sync + Unpin + std::fmt::Debug + DynClone {
92    /// Spawns the task onto the runtime.
93    /// See also [`Handle::spawn`].
94    fn spawn(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()>;
95
96    /// This spawns a critical task onto the runtime.
97    fn spawn_critical(&self, name: &'static str, fut: BoxFuture<'static, ()>) -> JoinHandle<()>;
98
99    /// Spawns a blocking task onto the runtime.
100    fn spawn_blocking(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()>;
101
102    /// This spawns a critical blocking task onto the runtime.
103    fn spawn_critical_blocking(
104        &self,
105        name: &'static str,
106        fut: BoxFuture<'static, ()>,
107    ) -> JoinHandle<()>;
108}
109
110dyn_clone::clone_trait_object!(TaskSpawner);
111
112/// An [`TaskSpawner`] that uses [`tokio::task::spawn`] to execute tasks
113#[derive(Debug, Clone, Default)]
114#[non_exhaustive]
115pub struct TokioTaskExecutor;
116
117impl TokioTaskExecutor {
118    /// Converts the instance to a boxed [`TaskSpawner`].
119    pub fn boxed(self) -> Box<dyn TaskSpawner + 'static> {
120        Box::new(self)
121    }
122}
123
124impl TaskSpawner for TokioTaskExecutor {
125    fn spawn(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
126        tokio::task::spawn(fut)
127    }
128
129    fn spawn_critical(&self, _name: &'static str, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
130        tokio::task::spawn(fut)
131    }
132
133    fn spawn_blocking(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
134        tokio::task::spawn_blocking(move || tokio::runtime::Handle::current().block_on(fut))
135    }
136
137    fn spawn_critical_blocking(
138        &self,
139        _name: &'static str,
140        fut: BoxFuture<'static, ()>,
141    ) -> JoinHandle<()> {
142        tokio::task::spawn_blocking(move || tokio::runtime::Handle::current().block_on(fut))
143    }
144}
145
146/// Many reth components require to spawn tasks for long-running jobs. For example `discovery`
147/// spawns tasks to handle egress and ingress of udp traffic or `network` that spawns session tasks
148/// that handle the traffic to and from a peer.
149///
150/// To unify how tasks are created, the [`TaskManager`] provides access to the configured Tokio
151/// runtime. A [`TaskManager`] stores the [`tokio::runtime::Handle`] it is associated with. In this
152/// way it is possible to configure on which runtime a task is executed.
153///
154/// The main purpose of this type is to be able to monitor if a critical task panicked, for
155/// diagnostic purposes, since tokio task essentially fail silently. Therefore, this type is a
156/// Stream that yields the name of panicked task, See [`TaskExecutor::spawn_critical`]. In order to
157/// execute Tasks use the [`TaskExecutor`] type [`TaskManager::executor`].
158#[derive(Debug)]
159#[must_use = "TaskManager must be polled to monitor critical tasks"]
160pub struct TaskManager {
161    /// Handle to the tokio runtime this task manager is associated with.
162    ///
163    /// See [`Handle`] docs.
164    handle: Handle,
165    /// Sender half for sending task events to this type
166    task_events_tx: UnboundedSender<TaskEvent>,
167    /// Receiver for task events
168    task_events_rx: UnboundedReceiver<TaskEvent>,
169    /// The [Signal] to fire when all tasks should be shutdown.
170    ///
171    /// This is fired when dropped.
172    signal: Option<Signal>,
173    /// Receiver of the shutdown signal.
174    on_shutdown: Shutdown,
175    /// How many [`GracefulShutdown`] tasks are currently active
176    graceful_tasks: Arc<AtomicUsize>,
177}
178
179// === impl TaskManager ===
180
181impl TaskManager {
182    /// Returns a __new__ [`TaskManager`] over the currently running Runtime.
183    ///
184    /// This must be polled for the duration of the program.
185    ///
186    /// To obtain the current [`TaskExecutor`] see [`TaskExecutor::current`].
187    ///
188    /// # Panics
189    ///
190    /// This will panic if called outside the context of a Tokio runtime.
191    pub fn current() -> Self {
192        let handle = Handle::current();
193        Self::new(handle)
194    }
195
196    /// Create a new instance connected to the given handle's tokio runtime.
197    ///
198    /// This also sets the global [`TaskExecutor`].
199    pub fn new(handle: Handle) -> Self {
200        let (task_events_tx, task_events_rx) = unbounded_channel();
201        let (signal, on_shutdown) = signal();
202        let manager = Self {
203            handle,
204            task_events_tx,
205            task_events_rx,
206            signal: Some(signal),
207            on_shutdown,
208            graceful_tasks: Arc::new(AtomicUsize::new(0)),
209        };
210
211        let _ = GLOBAL_EXECUTOR
212            .set(manager.executor())
213            .inspect_err(|_| error!("Global executor already set"));
214
215        manager
216    }
217
218    /// Returns a new [`TaskExecutor`] that can spawn new tasks onto the tokio runtime this type is
219    /// connected to.
220    pub fn executor(&self) -> TaskExecutor {
221        TaskExecutor {
222            handle: self.handle.clone(),
223            on_shutdown: self.on_shutdown.clone(),
224            task_events_tx: self.task_events_tx.clone(),
225            metrics: Default::default(),
226            graceful_tasks: Arc::clone(&self.graceful_tasks),
227        }
228    }
229
230    /// Fires the shutdown signal and awaits until all tasks are shutdown.
231    pub fn graceful_shutdown(self) {
232        let _ = self.do_graceful_shutdown(None);
233    }
234
235    /// Fires the shutdown signal and awaits until all tasks are shutdown.
236    ///
237    /// Returns true if all tasks were shutdown before the timeout elapsed.
238    pub fn graceful_shutdown_with_timeout(self, timeout: std::time::Duration) -> bool {
239        self.do_graceful_shutdown(Some(timeout))
240    }
241
242    fn do_graceful_shutdown(self, timeout: Option<std::time::Duration>) -> bool {
243        drop(self.signal);
244        let when = timeout.map(|t| std::time::Instant::now() + t);
245        while self.graceful_tasks.load(Ordering::Relaxed) > 0 {
246            if when.map(|when| std::time::Instant::now() > when).unwrap_or(false) {
247                debug!("graceful shutdown timed out");
248                return false
249            }
250            std::hint::spin_loop();
251        }
252
253        debug!("gracefully shut down");
254        true
255    }
256}
257
258/// An endless future that resolves if a critical task panicked.
259///
260/// See [`TaskExecutor::spawn_critical`]
261impl Future for TaskManager {
262    type Output = Result<(), PanickedTaskError>;
263
264    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
265        match ready!(self.as_mut().get_mut().task_events_rx.poll_recv(cx)) {
266            Some(TaskEvent::Panic(err)) => Poll::Ready(Err(err)),
267            Some(TaskEvent::GracefulShutdown) | None => {
268                if let Some(signal) = self.get_mut().signal.take() {
269                    signal.fire();
270                }
271                Poll::Ready(Ok(()))
272            }
273        }
274    }
275}
276
277/// Error with the name of the task that panicked and an error downcasted to string, if possible.
278#[derive(Debug, thiserror::Error, PartialEq, Eq)]
279pub struct PanickedTaskError {
280    task_name: &'static str,
281    error: Option<String>,
282}
283
284impl Display for PanickedTaskError {
285    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
286        let task_name = self.task_name;
287        if let Some(error) = &self.error {
288            write!(f, "Critical task `{task_name}` panicked: `{error}`")
289        } else {
290            write!(f, "Critical task `{task_name}` panicked")
291        }
292    }
293}
294
295impl PanickedTaskError {
296    fn new(task_name: &'static str, error: Box<dyn Any>) -> Self {
297        let error = match error.downcast::<String>() {
298            Ok(value) => Some(*value),
299            Err(error) => match error.downcast::<&str>() {
300                Ok(value) => Some(value.to_string()),
301                Err(_) => None,
302            },
303        };
304
305        Self { task_name, error }
306    }
307}
308
309/// Represents the events that the `TaskManager`'s main future can receive.
310#[derive(Debug)]
311enum TaskEvent {
312    /// Indicates that a critical task has panicked.
313    Panic(PanickedTaskError),
314    /// A signal requesting a graceful shutdown of the `TaskManager`.
315    GracefulShutdown,
316}
317
318/// A type that can spawn new tokio tasks
319#[derive(Debug, Clone)]
320pub struct TaskExecutor {
321    /// Handle to the tokio runtime this task manager is associated with.
322    ///
323    /// See [`Handle`] docs.
324    handle: Handle,
325    /// Receiver of the shutdown signal.
326    on_shutdown: Shutdown,
327    /// Sender half for sending task events to this type
328    task_events_tx: UnboundedSender<TaskEvent>,
329    /// Task Executor Metrics
330    metrics: TaskExecutorMetrics,
331    /// How many [`GracefulShutdown`] tasks are currently active
332    graceful_tasks: Arc<AtomicUsize>,
333}
334
335// === impl TaskExecutor ===
336
337impl TaskExecutor {
338    /// Attempts to get the current `TaskExecutor` if one has been initialized.
339    ///
340    /// Returns an error if no [`TaskExecutor`] has been initialized via [`TaskManager`].
341    pub fn try_current() -> Result<Self, NoCurrentTaskExecutorError> {
342        GLOBAL_EXECUTOR.get().cloned().ok_or_else(NoCurrentTaskExecutorError::default)
343    }
344
345    /// Returns the current `TaskExecutor`.
346    ///
347    /// # Panics
348    ///
349    /// Panics if no global executor has been initialized. Use [`try_current`](Self::try_current)
350    /// for a non-panicking version.
351    pub fn current() -> Self {
352        Self::try_current().unwrap()
353    }
354
355    /// Returns the [Handle] to the tokio runtime.
356    pub const fn handle(&self) -> &Handle {
357        &self.handle
358    }
359
360    /// Returns the receiver of the shutdown signal.
361    pub const fn on_shutdown_signal(&self) -> &Shutdown {
362        &self.on_shutdown
363    }
364
365    /// Spawns a future on the tokio runtime depending on the [`TaskKind`]
366    fn spawn_on_rt<F>(&self, fut: F, task_kind: TaskKind) -> JoinHandle<()>
367    where
368        F: Future<Output = ()> + Send + 'static,
369    {
370        match task_kind {
371            TaskKind::Default => self.handle.spawn(fut),
372            TaskKind::Blocking => {
373                let handle = self.handle.clone();
374                self.handle.spawn_blocking(move || handle.block_on(fut))
375            }
376        }
377    }
378
379    /// Spawns a regular task depending on the given [`TaskKind`]
380    fn spawn_task_as<F>(&self, fut: F, task_kind: TaskKind) -> JoinHandle<()>
381    where
382        F: Future<Output = ()> + Send + 'static,
383    {
384        let on_shutdown = self.on_shutdown.clone();
385
386        // Clone only the specific counter that we need.
387        let finished_regular_tasks_total_metrics =
388            self.metrics.finished_regular_tasks_total.clone();
389        // Wrap the original future to increment the finished tasks counter upon completion
390        let task = {
391            async move {
392                // Create an instance of IncCounterOnDrop with the counter to increment
393                let _inc_counter_on_drop =
394                    IncCounterOnDrop::new(finished_regular_tasks_total_metrics);
395                let fut = pin!(fut);
396                let _ = select(on_shutdown, fut).await;
397            }
398        }
399        .in_current_span();
400
401        self.spawn_on_rt(task, task_kind)
402    }
403
404    /// Spawns the task onto the runtime.
405    /// The given future resolves as soon as the [Shutdown] signal is received.
406    ///
407    /// See also [`Handle::spawn`].
408    pub fn spawn<F>(&self, fut: F) -> JoinHandle<()>
409    where
410        F: Future<Output = ()> + Send + 'static,
411    {
412        self.spawn_task_as(fut, TaskKind::Default)
413    }
414
415    /// Spawns a blocking task onto the runtime.
416    /// The given future resolves as soon as the [Shutdown] signal is received.
417    ///
418    /// See also [`Handle::spawn_blocking`].
419    pub fn spawn_blocking<F>(&self, fut: F) -> JoinHandle<()>
420    where
421        F: Future<Output = ()> + Send + 'static,
422    {
423        self.spawn_task_as(fut, TaskKind::Blocking)
424    }
425
426    /// Spawns the task onto the runtime.
427    /// The given future resolves as soon as the [Shutdown] signal is received.
428    ///
429    /// See also [`Handle::spawn`].
430    pub fn spawn_with_signal<F>(&self, f: impl FnOnce(Shutdown) -> F) -> JoinHandle<()>
431    where
432        F: Future<Output = ()> + Send + 'static,
433    {
434        let on_shutdown = self.on_shutdown.clone();
435        let fut = f(on_shutdown);
436
437        let task = fut.in_current_span();
438
439        self.handle.spawn(task)
440    }
441
442    /// Spawns a critical task depending on the given [`TaskKind`]
443    fn spawn_critical_as<F>(
444        &self,
445        name: &'static str,
446        fut: F,
447        task_kind: TaskKind,
448    ) -> JoinHandle<()>
449    where
450        F: Future<Output = ()> + Send + 'static,
451    {
452        let panicked_tasks_tx = self.task_events_tx.clone();
453        let on_shutdown = self.on_shutdown.clone();
454
455        // wrap the task in catch unwind
456        let task = std::panic::AssertUnwindSafe(fut)
457            .catch_unwind()
458            .map_err(move |error| {
459                let task_error = PanickedTaskError::new(name, error);
460                error!("{task_error}");
461                let _ = panicked_tasks_tx.send(TaskEvent::Panic(task_error));
462            })
463            .in_current_span();
464
465        // Clone only the specific counter that we need.
466        let finished_critical_tasks_total_metrics =
467            self.metrics.finished_critical_tasks_total.clone();
468        let task = async move {
469            // Create an instance of IncCounterOnDrop with the counter to increment
470            let _inc_counter_on_drop = IncCounterOnDrop::new(finished_critical_tasks_total_metrics);
471            let task = pin!(task);
472            let _ = select(on_shutdown, task).await;
473        };
474
475        self.spawn_on_rt(task, task_kind)
476    }
477
478    /// This spawns a critical blocking task onto the runtime.
479    /// The given future resolves as soon as the [Shutdown] signal is received.
480    ///
481    /// If this task panics, the [`TaskManager`] is notified.
482    pub fn spawn_critical_blocking<F>(&self, name: &'static str, fut: F) -> JoinHandle<()>
483    where
484        F: Future<Output = ()> + Send + 'static,
485    {
486        self.spawn_critical_as(name, fut, TaskKind::Blocking)
487    }
488
489    /// This spawns a critical task onto the runtime.
490    /// The given future resolves as soon as the [Shutdown] signal is received.
491    ///
492    /// If this task panics, the [`TaskManager`] is notified.
493    pub fn spawn_critical<F>(&self, name: &'static str, fut: F) -> JoinHandle<()>
494    where
495        F: Future<Output = ()> + Send + 'static,
496    {
497        self.spawn_critical_as(name, fut, TaskKind::Default)
498    }
499
500    /// This spawns a critical task onto the runtime.
501    ///
502    /// If this task panics, the [`TaskManager`] is notified.
503    pub fn spawn_critical_with_shutdown_signal<F>(
504        &self,
505        name: &'static str,
506        f: impl FnOnce(Shutdown) -> F,
507    ) -> JoinHandle<()>
508    where
509        F: Future<Output = ()> + Send + 'static,
510    {
511        let panicked_tasks_tx = self.task_events_tx.clone();
512        let on_shutdown = self.on_shutdown.clone();
513        let fut = f(on_shutdown);
514
515        // wrap the task in catch unwind
516        let task = std::panic::AssertUnwindSafe(fut)
517            .catch_unwind()
518            .map_err(move |error| {
519                let task_error = PanickedTaskError::new(name, error);
520                error!("{task_error}");
521                let _ = panicked_tasks_tx.send(TaskEvent::Panic(task_error));
522            })
523            .map(drop)
524            .in_current_span();
525
526        self.handle.spawn(task)
527    }
528
529    /// This spawns a critical task onto the runtime.
530    ///
531    /// If this task panics, the [`TaskManager`] is notified.
532    /// The [`TaskManager`] will wait until the given future has completed before shutting down.
533    ///
534    /// # Example
535    ///
536    /// ```no_run
537    /// # async fn t(executor: reth_tasks::TaskExecutor) {
538    ///
539    /// executor.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move {
540    ///     // await the shutdown signal
541    ///     let guard = shutdown.await;
542    ///     // do work before exiting the program
543    ///     tokio::time::sleep(std::time::Duration::from_secs(1)).await;
544    ///     // allow graceful shutdown
545    ///     drop(guard);
546    /// });
547    /// # }
548    /// ```
549    pub fn spawn_critical_with_graceful_shutdown_signal<F>(
550        &self,
551        name: &'static str,
552        f: impl FnOnce(GracefulShutdown) -> F,
553    ) -> JoinHandle<()>
554    where
555        F: Future<Output = ()> + Send + 'static,
556    {
557        let panicked_tasks_tx = self.task_events_tx.clone();
558        let on_shutdown = GracefulShutdown::new(
559            self.on_shutdown.clone(),
560            GracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)),
561        );
562        let fut = f(on_shutdown);
563
564        // wrap the task in catch unwind
565        let task = std::panic::AssertUnwindSafe(fut)
566            .catch_unwind()
567            .map_err(move |error| {
568                let task_error = PanickedTaskError::new(name, error);
569                error!("{task_error}");
570                let _ = panicked_tasks_tx.send(TaskEvent::Panic(task_error));
571            })
572            .map(drop)
573            .in_current_span();
574
575        self.handle.spawn(task)
576    }
577
578    /// This spawns a regular task onto the runtime.
579    ///
580    /// The [`TaskManager`] will wait until the given future has completed before shutting down.
581    ///
582    /// # Example
583    ///
584    /// ```no_run
585    /// # async fn t(executor: reth_tasks::TaskExecutor) {
586    ///
587    /// executor.spawn_with_graceful_shutdown_signal(|shutdown| async move {
588    ///     // await the shutdown signal
589    ///     let guard = shutdown.await;
590    ///     // do work before exiting the program
591    ///     tokio::time::sleep(std::time::Duration::from_secs(1)).await;
592    ///     // allow graceful shutdown
593    ///     drop(guard);
594    /// });
595    /// # }
596    /// ```
597    pub fn spawn_with_graceful_shutdown_signal<F>(
598        &self,
599        f: impl FnOnce(GracefulShutdown) -> F,
600    ) -> JoinHandle<()>
601    where
602        F: Future<Output = ()> + Send + 'static,
603    {
604        let on_shutdown = GracefulShutdown::new(
605            self.on_shutdown.clone(),
606            GracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)),
607        );
608        let fut = f(on_shutdown);
609
610        self.handle.spawn(fut)
611    }
612
613    /// Sends a request to the `TaskManager` to initiate a graceful shutdown.
614    ///
615    /// Caution: This will terminate the entire program.
616    ///
617    /// The [`TaskManager`] upon receiving this event, will terminate and initiate the shutdown that
618    /// can be handled via the returned [`GracefulShutdown`].
619    pub fn initiate_graceful_shutdown(
620        &self,
621    ) -> Result<GracefulShutdown, tokio::sync::mpsc::error::SendError<()>> {
622        self.task_events_tx
623            .send(TaskEvent::GracefulShutdown)
624            .map_err(|_send_error_with_task_event| tokio::sync::mpsc::error::SendError(()))?;
625
626        Ok(GracefulShutdown::new(
627            self.on_shutdown.clone(),
628            GracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)),
629        ))
630    }
631}
632
633impl TaskSpawner for TaskExecutor {
634    fn spawn(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
635        self.metrics.inc_regular_tasks();
636        self.spawn(fut)
637    }
638
639    fn spawn_critical(&self, name: &'static str, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
640        self.metrics.inc_critical_tasks();
641        Self::spawn_critical(self, name, fut)
642    }
643
644    fn spawn_blocking(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
645        self.metrics.inc_regular_tasks();
646        self.spawn_blocking(fut)
647    }
648
649    fn spawn_critical_blocking(
650        &self,
651        name: &'static str,
652        fut: BoxFuture<'static, ()>,
653    ) -> JoinHandle<()> {
654        self.metrics.inc_critical_tasks();
655        Self::spawn_critical_blocking(self, name, fut)
656    }
657}
658
659/// `TaskSpawner` with extended behaviour
660#[auto_impl::auto_impl(&, Arc)]
661pub trait TaskSpawnerExt: Send + Sync + Unpin + std::fmt::Debug + DynClone {
662    /// This spawns a critical task onto the runtime.
663    ///
664    /// If this task panics, the [`TaskManager`] is notified.
665    /// The [`TaskManager`] will wait until the given future has completed before shutting down.
666    fn spawn_critical_with_graceful_shutdown_signal<F>(
667        &self,
668        name: &'static str,
669        f: impl FnOnce(GracefulShutdown) -> F,
670    ) -> JoinHandle<()>
671    where
672        F: Future<Output = ()> + Send + 'static;
673
674    /// This spawns a regular task onto the runtime.
675    ///
676    /// The [`TaskManager`] will wait until the given future has completed before shutting down.
677    fn spawn_with_graceful_shutdown_signal<F>(
678        &self,
679        f: impl FnOnce(GracefulShutdown) -> F,
680    ) -> JoinHandle<()>
681    where
682        F: Future<Output = ()> + Send + 'static;
683}
684
685impl TaskSpawnerExt for TaskExecutor {
686    fn spawn_critical_with_graceful_shutdown_signal<F>(
687        &self,
688        name: &'static str,
689        f: impl FnOnce(GracefulShutdown) -> F,
690    ) -> JoinHandle<()>
691    where
692        F: Future<Output = ()> + Send + 'static,
693    {
694        Self::spawn_critical_with_graceful_shutdown_signal(self, name, f)
695    }
696
697    fn spawn_with_graceful_shutdown_signal<F>(
698        &self,
699        f: impl FnOnce(GracefulShutdown) -> F,
700    ) -> JoinHandle<()>
701    where
702        F: Future<Output = ()> + Send + 'static,
703    {
704        Self::spawn_with_graceful_shutdown_signal(self, f)
705    }
706}
707
708/// Determines how a task is spawned
709enum TaskKind {
710    /// Spawn the task to the default executor [`Handle::spawn`]
711    Default,
712    /// Spawn the task to the blocking executor [`Handle::spawn_blocking`]
713    Blocking,
714}
715
716/// Error returned by `try_current` when no task executor has been configured.
717#[derive(Debug, Default, thiserror::Error)]
718#[error("No current task executor available.")]
719#[non_exhaustive]
720pub struct NoCurrentTaskExecutorError;
721
722#[cfg(test)]
723mod tests {
724    use super::*;
725    use std::{sync::atomic::AtomicBool, time::Duration};
726
727    #[test]
728    fn test_cloneable() {
729        #[derive(Clone)]
730        struct ExecutorWrapper {
731            _e: Box<dyn TaskSpawner>,
732        }
733
734        let executor: Box<dyn TaskSpawner> = Box::<TokioTaskExecutor>::default();
735        let _e = dyn_clone::clone_box(&*executor);
736
737        let e = ExecutorWrapper { _e };
738        let _e2 = e;
739    }
740
741    #[test]
742    fn test_critical() {
743        let runtime = tokio::runtime::Runtime::new().unwrap();
744        let handle = runtime.handle().clone();
745        let manager = TaskManager::new(handle);
746        let executor = manager.executor();
747
748        executor.spawn_critical("this is a critical task", async { panic!("intentionally panic") });
749
750        runtime.block_on(async move {
751            let err_result = manager.await;
752            assert!(err_result.is_err(), "Expected TaskManager to return an error due to panic");
753            let panicked_err = err_result.unwrap_err();
754
755            assert_eq!(panicked_err.task_name, "this is a critical task");
756            assert_eq!(panicked_err.error, Some("intentionally panic".to_string()));
757        })
758    }
759
760    // Tests that spawned tasks are terminated if the `TaskManager` drops
761    #[test]
762    fn test_manager_shutdown_critical() {
763        let runtime = tokio::runtime::Runtime::new().unwrap();
764        let handle = runtime.handle().clone();
765        let manager = TaskManager::new(handle.clone());
766        let executor = manager.executor();
767
768        let (signal, shutdown) = signal();
769
770        executor.spawn_critical("this is a critical task", async move {
771            tokio::time::sleep(Duration::from_millis(200)).await;
772            drop(signal);
773        });
774
775        drop(manager);
776
777        handle.block_on(shutdown);
778    }
779
780    // Tests that spawned tasks are terminated if the `TaskManager` drops
781    #[test]
782    fn test_manager_shutdown() {
783        let runtime = tokio::runtime::Runtime::new().unwrap();
784        let handle = runtime.handle().clone();
785        let manager = TaskManager::new(handle.clone());
786        let executor = manager.executor();
787
788        let (signal, shutdown) = signal();
789
790        executor.spawn(Box::pin(async move {
791            tokio::time::sleep(Duration::from_millis(200)).await;
792            drop(signal);
793        }));
794
795        drop(manager);
796
797        handle.block_on(shutdown);
798    }
799
800    #[test]
801    fn test_manager_graceful_shutdown() {
802        let runtime = tokio::runtime::Runtime::new().unwrap();
803        let handle = runtime.handle().clone();
804        let manager = TaskManager::new(handle);
805        let executor = manager.executor();
806
807        let val = Arc::new(AtomicBool::new(false));
808        let c = val.clone();
809        executor.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move {
810            let _guard = shutdown.await;
811            tokio::time::sleep(Duration::from_millis(200)).await;
812            c.store(true, Ordering::Relaxed);
813        });
814
815        manager.graceful_shutdown();
816        assert!(val.load(Ordering::Relaxed));
817    }
818
819    #[test]
820    fn test_manager_graceful_shutdown_many() {
821        let runtime = tokio::runtime::Runtime::new().unwrap();
822        let handle = runtime.handle().clone();
823        let manager = TaskManager::new(handle);
824        let executor = manager.executor();
825
826        let counter = Arc::new(AtomicUsize::new(0));
827        let num = 10;
828        for _ in 0..num {
829            let c = counter.clone();
830            executor.spawn_critical_with_graceful_shutdown_signal(
831                "grace",
832                move |shutdown| async move {
833                    let _guard = shutdown.await;
834                    tokio::time::sleep(Duration::from_millis(200)).await;
835                    c.fetch_add(1, Ordering::SeqCst);
836                },
837            );
838        }
839
840        manager.graceful_shutdown();
841        assert_eq!(counter.load(Ordering::Relaxed), num);
842    }
843
844    #[test]
845    fn test_manager_graceful_shutdown_timeout() {
846        let runtime = tokio::runtime::Runtime::new().unwrap();
847        let handle = runtime.handle().clone();
848        let manager = TaskManager::new(handle);
849        let executor = manager.executor();
850
851        let timeout = Duration::from_millis(500);
852        let val = Arc::new(AtomicBool::new(false));
853        let val2 = val.clone();
854        executor.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move {
855            let _guard = shutdown.await;
856            tokio::time::sleep(timeout * 3).await;
857            val2.store(true, Ordering::Relaxed);
858            unreachable!("should not be reached");
859        });
860
861        manager.graceful_shutdown_with_timeout(timeout);
862        assert!(!val.load(Ordering::Relaxed));
863    }
864
865    #[test]
866    fn can_access_global() {
867        let runtime = tokio::runtime::Runtime::new().unwrap();
868        let handle = runtime.handle().clone();
869        let _manager = TaskManager::new(handle);
870        let _executor = TaskExecutor::try_current().unwrap();
871    }
872
873    #[test]
874    fn test_graceful_shutdown_triggered_by_executor() {
875        let runtime = tokio::runtime::Runtime::new().unwrap();
876        let task_manager = TaskManager::new(runtime.handle().clone());
877        let executor = task_manager.executor();
878
879        let task_did_shutdown_flag = Arc::new(AtomicBool::new(false));
880        let flag_clone = task_did_shutdown_flag.clone();
881
882        let spawned_task_handle = executor.spawn_with_signal(|shutdown_signal| async move {
883            shutdown_signal.await;
884            flag_clone.store(true, Ordering::SeqCst);
885        });
886
887        let manager_future_handle = runtime.spawn(task_manager);
888
889        let send_result = executor.initiate_graceful_shutdown();
890        assert!(send_result.is_ok(), "Sending the graceful shutdown signal should succeed and return a GracefulShutdown future");
891
892        let manager_final_result = runtime.block_on(manager_future_handle);
893
894        assert!(manager_final_result.is_ok(), "TaskManager task should not panic");
895        assert_eq!(
896            manager_final_result.unwrap(),
897            Ok(()),
898            "TaskManager should resolve cleanly with Ok(()) after graceful shutdown request"
899        );
900
901        let task_join_result = runtime.block_on(spawned_task_handle);
902        assert!(task_join_result.is_ok(), "Spawned task should complete without panic");
903
904        assert!(
905            task_did_shutdown_flag.load(Ordering::Relaxed),
906            "Task should have received the shutdown signal and set the flag"
907        );
908    }
909}