reth_downloaders/headers/task.rs

1use futures::{FutureExt, Stream};
2use futures_util::StreamExt;
3use pin_project::pin_project;
4use reth_network_p2p::headers::{
5    downloader::{HeaderDownloader, SyncTarget},
6    error::HeadersDownloaderResult,
7};
8use reth_primitives::SealedHeader;
9use reth_tasks::{TaskSpawner, TokioTaskExecutor};
10use std::{
11    fmt::Debug,
12    future::Future,
13    pin::Pin,
14    task::{ready, Context, Poll},
15};
16use tokio::sync::{mpsc, mpsc::UnboundedSender};
17use tokio_stream::wrappers::{ReceiverStream, UnboundedReceiverStream};
18use tokio_util::sync::PollSender;
19
/// Capacity of the bounded channel that carries header results from the spawned downloader
/// task back to the [`TaskDownloader`] handle.
///
/// While this many results sit unread in the channel, the spawned task does not poll its
/// downloader for more headers.
pub const HEADERS_TASK_BUFFER_SIZE: usize = 8;
22
23/// A [HeaderDownloader] that drives a spawned [HeaderDownloader] on a spawned task.
24#[derive(Debug)]
25#[pin_project]
26pub struct TaskDownloader<H> {
27    #[pin]
28    from_downloader: ReceiverStream<HeadersDownloaderResult<Vec<SealedHeader<H>>, H>>,
29    to_downloader: UnboundedSender<DownloaderUpdates<H>>,
30}
31
32// === impl TaskDownloader ===
33
34impl<H: Send + Sync + Unpin + 'static> TaskDownloader<H> {
35    /// Spawns the given `downloader` via [`tokio::task::spawn`] and returns a [`TaskDownloader`]
36    /// that's connected to that task.
37    ///
38    /// # Panics
39    ///
40    /// This method panics if called outside of a Tokio runtime
41    ///
42    /// # Example
43    ///
44    /// ```
45    /// # use std::sync::Arc;
46    /// # use reth_downloaders::headers::reverse_headers::ReverseHeadersDownloader;
47    /// # use reth_downloaders::headers::task::TaskDownloader;
48    /// # use reth_consensus::HeaderValidator;
49    /// # use reth_network_p2p::headers::client::HeadersClient;
50    /// # use reth_primitives_traits::BlockHeader;
51    /// # fn t<H: HeadersClient<Header: BlockHeader> + 'static>(consensus:Arc<dyn HeaderValidator<H::Header>>, client: Arc<H>) {
52    ///    let downloader = ReverseHeadersDownloader::<H>::builder().build(
53    ///        client,
54    ///        consensus
55    ///     );
56    ///   let downloader = TaskDownloader::spawn(downloader);
57    /// # }
58    pub fn spawn<T>(downloader: T) -> Self
59    where
60        T: HeaderDownloader<Header = H> + 'static,
61    {
62        Self::spawn_with(downloader, &TokioTaskExecutor::default())
63    }
64
65    /// Spawns the given `downloader` via the given [`TaskSpawner`] returns a [`TaskDownloader`]
66    /// that's connected to that task.
67    pub fn spawn_with<T, S>(downloader: T, spawner: &S) -> Self
68    where
69        T: HeaderDownloader<Header = H> + 'static,
70        S: TaskSpawner,
71    {
72        let (headers_tx, headers_rx) = mpsc::channel(HEADERS_TASK_BUFFER_SIZE);
73        let (to_downloader, updates_rx) = mpsc::unbounded_channel();
74
75        let downloader = SpawnedDownloader {
76            headers_tx: PollSender::new(headers_tx),
77            updates: UnboundedReceiverStream::new(updates_rx),
78            downloader,
79        };
80        spawner.spawn(downloader.boxed());
81
82        Self { from_downloader: ReceiverStream::new(headers_rx), to_downloader }
83    }
84}
85
86impl<H: Debug + Send + Sync + Unpin + 'static> HeaderDownloader for TaskDownloader<H> {
87    type Header = H;
88
89    fn update_sync_gap(&mut self, head: SealedHeader<H>, target: SyncTarget) {
90        let _ = self.to_downloader.send(DownloaderUpdates::UpdateSyncGap(head, target));
91    }
92
93    fn update_local_head(&mut self, head: SealedHeader<H>) {
94        let _ = self.to_downloader.send(DownloaderUpdates::UpdateLocalHead(head));
95    }
96
97    fn update_sync_target(&mut self, target: SyncTarget) {
98        let _ = self.to_downloader.send(DownloaderUpdates::UpdateSyncTarget(target));
99    }
100
101    fn set_batch_size(&mut self, limit: usize) {
102        let _ = self.to_downloader.send(DownloaderUpdates::SetBatchSize(limit));
103    }
104}
105
106impl<H> Stream for TaskDownloader<H> {
107    type Item = HeadersDownloaderResult<Vec<SealedHeader<H>>, H>;
108
109    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
110        self.project().from_downloader.poll_next(cx)
111    }
112}
113
114/// A [`HeaderDownloader`] that runs on its own task
115#[expect(clippy::complexity)]
116struct SpawnedDownloader<T: HeaderDownloader> {
117    updates: UnboundedReceiverStream<DownloaderUpdates<T::Header>>,
118    headers_tx: PollSender<HeadersDownloaderResult<Vec<SealedHeader<T::Header>>, T::Header>>,
119    downloader: T,
120}
121
122impl<T: HeaderDownloader> Future for SpawnedDownloader<T> {
123    type Output = ();
124
125    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
126        let this = self.get_mut();
127
128        loop {
129            loop {
130                match this.updates.poll_next_unpin(cx) {
131                    Poll::Pending => break,
132                    Poll::Ready(None) => {
133                        // channel closed, this means [TaskDownloader] was dropped, so we can also
134                        // exit
135                        return Poll::Ready(())
136                    }
137                    Poll::Ready(Some(update)) => match update {
138                        DownloaderUpdates::UpdateSyncGap(head, target) => {
139                            this.downloader.update_sync_gap(head, target);
140                        }
141                        DownloaderUpdates::UpdateLocalHead(head) => {
142                            this.downloader.update_local_head(head);
143                        }
144                        DownloaderUpdates::UpdateSyncTarget(target) => {
145                            this.downloader.update_sync_target(target);
146                        }
147                        DownloaderUpdates::SetBatchSize(limit) => {
148                            this.downloader.set_batch_size(limit);
149                        }
150                    },
151                }
152            }
153
154            match ready!(this.headers_tx.poll_reserve(cx)) {
155                Ok(()) => {
156                    match ready!(this.downloader.poll_next_unpin(cx)) {
157                        Some(headers) => {
158                            if this.headers_tx.send_item(headers).is_err() {
159                                // channel closed, this means [TaskDownloader] was dropped, so we
160                                // can also exit
161                                return Poll::Ready(())
162                            }
163                        }
164                        None => return Poll::Pending,
165                    }
166                }
167                Err(_) => {
168                    // channel closed, this means [TaskDownloader] was dropped, so
169                    // we can also exit
170                    return Poll::Ready(())
171                }
172            }
173        }
174    }
175}
176
177/// Commands delegated tot the spawned [`HeaderDownloader`]
178enum DownloaderUpdates<H> {
179    UpdateSyncGap(SealedHeader<H>, SyncTarget),
180    UpdateLocalHead(SealedHeader<H>),
181    UpdateSyncTarget(SyncTarget),
182    SetBatchSize(usize),
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use crate::headers::{
        reverse_headers::ReverseHeadersDownloaderBuilder, test_utils::child_header,
    };
    use reth_consensus::test_utils::TestConsensus;
    use reth_network_p2p::test_utils::TestHeadersClient;
    use std::sync::Arc;

    /// Streams a short header chain through a [`TaskDownloader`] one header per batch. Checks
    /// that the batches arrive starting from the sync-target tip and walking back toward the
    /// local head.
    #[tokio::test(flavor = "multi_thread")]
    async fn download_one_by_one_on_task() {
        reth_tracing::init_test_tracing();

        // Chain p3 <- p2 <- p1 <- p0: p3 is the local head, p0 is the tip.
        let p3 = SealedHeader::default();
        let p2 = child_header(&p3);
        let p1 = child_header(&p2);
        let p0 = child_header(&p1);

        let client = Arc::new(TestHeadersClient::default());
        let reverse = ReverseHeadersDownloaderBuilder::default()
            .stream_batch_size(1)
            .request_limit(1)
            .build(Arc::clone(&client), Arc::new(TestConsensus::default()));

        let mut downloader = TaskDownloader::spawn(reverse);
        downloader.update_local_head(p3.clone());
        downloader.update_sync_target(SyncTarget::Tip(p0.hash()));

        client
            .extend(vec![
                p0.as_ref().clone(),
                p1.as_ref().clone(),
                p2.as_ref().clone(),
                p3.as_ref().clone(),
            ])
            .await;

        // With a stream batch size of one, every item carries exactly one header.
        for expected in [p0, p1, p2] {
            let batch = downloader.next().await.unwrap();
            assert_eq!(batch, Ok(vec![expected]));
        }
    }
}
231}