reth_downloaders/bodies/task.rs

1use alloy_primitives::BlockNumber;
2use futures::Stream;
3use futures_util::{FutureExt, StreamExt};
4use pin_project::pin_project;
5use reth_network_p2p::{
6    bodies::downloader::{BodyDownloader, BodyDownloaderResult},
7    error::DownloadResult,
8};
9use reth_tasks::{TaskSpawner, TokioTaskExecutor};
10use std::{
11    fmt::Debug,
12    future::Future,
13    ops::RangeInclusive,
14    pin::Pin,
15    task::{ready, Context, Poll},
16};
17use tokio::sync::{mpsc, mpsc::UnboundedSender};
18use tokio_stream::wrappers::{ReceiverStream, UnboundedReceiverStream};
19use tokio_util::sync::PollSender;
20
/// Capacity of the bounded channel that carries [`BodyDownloaderResult`]s from the spawned
/// downloader task back to its handle; at most this many results are buffered at once.
pub const BODIES_TASK_BUFFER_SIZE: usize = 4;
23
24/// A [BodyDownloader] that drives a spawned [BodyDownloader] on a spawned task.
25#[derive(Debug)]
26#[pin_project]
27pub struct TaskDownloader<H, B> {
28    #[pin]
29    from_downloader: ReceiverStream<BodyDownloaderResult<H, B>>,
30    to_downloader: UnboundedSender<RangeInclusive<BlockNumber>>,
31}
32
33// === impl TaskDownloader ===
34
35impl<H: Send + Sync + Unpin + 'static, B: Send + Sync + Unpin + 'static> TaskDownloader<H, B> {
36    /// Spawns the given `downloader` via [`tokio::task::spawn`] returns a [`TaskDownloader`] that's
37    /// connected to that task.
38    ///
39    /// # Panics
40    ///
41    /// This method panics if called outside of a Tokio runtime
42    ///
43    /// # Example
44    ///
45    /// ```
46    /// use reth_consensus::Consensus;
47    /// use reth_downloaders::bodies::{bodies::BodiesDownloaderBuilder, task::TaskDownloader};
48    /// use reth_network_p2p::bodies::client::BodiesClient;
49    /// use reth_primitives_traits::InMemorySize;
50    /// use reth_storage_api::HeaderProvider;
51    /// use std::{fmt::Debug, sync::Arc};
52    ///
53    /// fn t<
54    ///     B: BodiesClient<Body: Debug + InMemorySize> + 'static,
55    ///     Provider: HeaderProvider<Header = alloy_consensus::Header> + Unpin + 'static,
56    /// >(
57    ///     client: Arc<B>,
58    ///     consensus: Arc<dyn Consensus<Provider::Header, B::Body>>,
59    ///     provider: Provider,
60    /// ) {
61    ///     let downloader = BodiesDownloaderBuilder::default().build(client, consensus, provider);
62    ///     let downloader = TaskDownloader::spawn(downloader);
63    /// }
64    /// ```
65    pub fn spawn<T>(downloader: T) -> Self
66    where
67        T: BodyDownloader<Header = H, Body = B> + 'static,
68    {
69        Self::spawn_with(downloader, &TokioTaskExecutor::default())
70    }
71
72    /// Spawns the given `downloader` via the given [`TaskSpawner`] returns a [`TaskDownloader`]
73    /// that's connected to that task.
74    pub fn spawn_with<T, S>(downloader: T, spawner: &S) -> Self
75    where
76        T: BodyDownloader<Header = H, Body = B> + 'static,
77        S: TaskSpawner,
78    {
79        let (bodies_tx, bodies_rx) = mpsc::channel(BODIES_TASK_BUFFER_SIZE);
80        let (to_downloader, updates_rx) = mpsc::unbounded_channel();
81
82        let downloader = SpawnedDownloader {
83            bodies_tx: PollSender::new(bodies_tx),
84            updates: UnboundedReceiverStream::new(updates_rx),
85            downloader,
86        };
87
88        spawner.spawn(downloader.boxed());
89
90        Self { from_downloader: ReceiverStream::new(bodies_rx), to_downloader }
91    }
92}
93
94impl<H: Debug + Send + Sync + Unpin + 'static, B: Debug + Send + Sync + Unpin + 'static>
95    BodyDownloader for TaskDownloader<H, B>
96{
97    type Header = H;
98    type Body = B;
99
100    fn set_download_range(&mut self, range: RangeInclusive<BlockNumber>) -> DownloadResult<()> {
101        let _ = self.to_downloader.send(range);
102        Ok(())
103    }
104}
105
106impl<H, B> Stream for TaskDownloader<H, B> {
107    type Item = BodyDownloaderResult<H, B>;
108
109    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
110        self.project().from_downloader.poll_next(cx)
111    }
112}
113
114/// A [`BodyDownloader`] that runs on its own task
115struct SpawnedDownloader<T: BodyDownloader> {
116    updates: UnboundedReceiverStream<RangeInclusive<BlockNumber>>,
117    bodies_tx: PollSender<BodyDownloaderResult<T::Header, T::Body>>,
118    downloader: T,
119}
120
121impl<T: BodyDownloader> Future for SpawnedDownloader<T> {
122    type Output = ();
123
124    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
125        let this = self.get_mut();
126
127        loop {
128            while let Poll::Ready(update) = this.updates.poll_next_unpin(cx) {
129                if let Some(range) = update {
130                    if let Err(err) = this.downloader.set_download_range(range) {
131                        tracing::error!(target: "downloaders::bodies", %err, "Failed to set bodies download range");
132
133                        // Clone the sender ensure its availability. See [PollSender::clone].
134                        let mut bodies_tx = this.bodies_tx.clone();
135
136                        let forward_error_result = ready!(bodies_tx.poll_reserve(cx))
137                            .and_then(|_| bodies_tx.send_item(Err(err)));
138                        if forward_error_result.is_err() {
139                            // channel closed, this means [TaskDownloader] was dropped,
140                            // so we can also exit
141                            return Poll::Ready(())
142                        }
143                    }
144                } else {
145                    // channel closed, this means [TaskDownloader] was dropped, so we can also
146                    // exit
147                    return Poll::Ready(())
148                }
149            }
150
151            match ready!(this.bodies_tx.poll_reserve(cx)) {
152                Ok(()) => match ready!(this.downloader.poll_next_unpin(cx)) {
153                    Some(bodies) => {
154                        if this.bodies_tx.send_item(bodies).is_err() {
155                            // channel closed, this means [TaskDownloader] was dropped, so we can
156                            // also exit
157                            return Poll::Ready(())
158                        }
159                    }
160                    None => return Poll::Pending,
161                },
162                Err(_) => {
163                    // channel closed, this means [TaskDownloader] was dropped, so we can also
164                    // exit
165                    return Poll::Ready(())
166                }
167            }
168        }
169    }
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        bodies::{
            bodies::BodiesDownloaderBuilder,
            test_utils::{insert_headers, zip_blocks},
        },
        test_utils::{generate_bodies, TestBodiesClient},
    };
    use assert_matches::assert_matches;
    use reth_consensus::test_utils::TestConsensus;
    use reth_network_p2p::error::DownloadError;
    use reth_provider::test_utils::create_test_provider_factory;
    use std::sync::Arc;

    /// Downloads a whole range through the spawned task and checks that the first result holds
    /// every block of the range, fetched with a single client request.
    #[tokio::test(flavor = "multi_thread")]
    async fn download_one_by_one_on_task() {
        reth_tracing::init_test_tracing();

        let provider_factory = create_test_provider_factory();
        let (headers, mut bodies) = generate_bodies(0..=19);

        insert_headers(provider_factory.db_ref().db(), &headers);

        let bodies_client = Arc::new(
            TestBodiesClient::default().with_bodies(bodies.clone()).with_should_delay(true),
        );
        let inner = BodiesDownloaderBuilder::default().build(
            Arc::clone(&bodies_client),
            Arc::new(TestConsensus::default()),
            provider_factory,
        );
        let mut task_downloader = TaskDownloader::spawn(inner);

        task_downloader.set_download_range(0..=19).expect("failed to set download range");

        assert_matches!(
            task_downloader.next().await,
            Some(Ok(res)) => assert_eq!(res, zip_blocks(headers.iter(), &mut bodies))
        );
        assert_eq!(bodies_client.times_requested(), 1);
    }

    /// An invalid range is rejected on the spawned task and surfaced as an `Err` stream item.
    #[tokio::test(flavor = "multi_thread")]
    // The inverted range `1..=0` below is intentional: it is the invalid input under test.
    #[allow(clippy::reversed_empty_ranges)]
    async fn set_download_range_error_returned() {
        reth_tracing::init_test_tracing();
        let provider_factory = create_test_provider_factory();

        let inner = BodiesDownloaderBuilder::default().build(
            Arc::new(TestBodiesClient::default()),
            Arc::new(TestConsensus::default()),
            provider_factory,
        );
        let mut task_downloader = TaskDownloader::spawn(inner);

        task_downloader.set_download_range(1..=0).expect("failed to set download range");
        assert_matches!(
            task_downloader.next().await,
            Some(Err(DownloadError::InvalidBodyRange { .. }))
        );
    }
}