reth_downloaders/headers/
task.rs1use futures::{FutureExt, Stream};
2use futures_util::StreamExt;
3use pin_project::pin_project;
4use reth_network_p2p::headers::{
5 downloader::{HeaderDownloader, SyncTarget},
6 error::HeadersDownloaderResult,
7};
8use reth_primitives::SealedHeader;
9use reth_tasks::{TaskSpawner, TokioTaskExecutor};
10use std::{
11 fmt::Debug,
12 future::Future,
13 pin::Pin,
14 task::{ready, Context, Poll},
15};
16use tokio::sync::{mpsc, mpsc::UnboundedSender};
17use tokio_stream::wrappers::{ReceiverStream, UnboundedReceiverStream};
18use tokio_util::sync::PollSender;
19
/// Capacity of the bounded channel that carries downloaded header batches from the spawned
/// downloader task back to its [`TaskDownloader`] handle.
pub const HEADERS_TASK_BUFFER_SIZE: usize = 8;

/// A [`HeaderDownloader`] handle whose actual downloader runs on a separately spawned task.
///
/// Configuration calls are forwarded to that task over an unbounded channel, and the header
/// batches the task produces are read back from a bounded channel whose capacity is
/// [`HEADERS_TASK_BUFFER_SIZE`].
#[derive(Debug)]
#[pin_project]
pub struct TaskDownloader<H> {
    /// Receiving end for header batches (or errors) produced by the spawned task.
    #[pin]
    from_downloader: ReceiverStream<HeadersDownloaderResult<Vec<SealedHeader<H>>, H>>,
    /// Sending end for configuration updates consumed by the spawned task.
    to_downloader: UnboundedSender<DownloaderUpdates<H>>,
}

32impl<H: Send + Sync + Unpin + 'static> TaskDownloader<H> {
35 pub fn spawn<T>(downloader: T) -> Self
59 where
60 T: HeaderDownloader<Header = H> + 'static,
61 {
62 Self::spawn_with(downloader, &TokioTaskExecutor::default())
63 }
64
65 pub fn spawn_with<T, S>(downloader: T, spawner: &S) -> Self
68 where
69 T: HeaderDownloader<Header = H> + 'static,
70 S: TaskSpawner,
71 {
72 let (headers_tx, headers_rx) = mpsc::channel(HEADERS_TASK_BUFFER_SIZE);
73 let (to_downloader, updates_rx) = mpsc::unbounded_channel();
74
75 let downloader = SpawnedDownloader {
76 headers_tx: PollSender::new(headers_tx),
77 updates: UnboundedReceiverStream::new(updates_rx),
78 downloader,
79 };
80 spawner.spawn(downloader.boxed());
81
82 Self { from_downloader: ReceiverStream::new(headers_rx), to_downloader }
83 }
84}
85
86impl<H: Debug + Send + Sync + Unpin + 'static> HeaderDownloader for TaskDownloader<H> {
87 type Header = H;
88
89 fn update_sync_gap(&mut self, head: SealedHeader<H>, target: SyncTarget) {
90 let _ = self.to_downloader.send(DownloaderUpdates::UpdateSyncGap(head, target));
91 }
92
93 fn update_local_head(&mut self, head: SealedHeader<H>) {
94 let _ = self.to_downloader.send(DownloaderUpdates::UpdateLocalHead(head));
95 }
96
97 fn update_sync_target(&mut self, target: SyncTarget) {
98 let _ = self.to_downloader.send(DownloaderUpdates::UpdateSyncTarget(target));
99 }
100
101 fn set_batch_size(&mut self, limit: usize) {
102 let _ = self.to_downloader.send(DownloaderUpdates::SetBatchSize(limit));
103 }
104}
105
106impl<H> Stream for TaskDownloader<H> {
107 type Item = HeadersDownloaderResult<Vec<SealedHeader<H>>, H>;
108
109 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
110 self.project().from_downloader.poll_next(cx)
111 }
112}
113
/// The future that owns the real [`HeaderDownloader`] and is run on the spawned task.
///
/// It applies configuration updates received from the [`TaskDownloader`] handle and sends the
/// header batches it downloads back to that handle.
// NOTE(review): the lint expectation presumably covers `clippy::type_complexity` on the
// `headers_tx` field type — confirm.
#[expect(clippy::complexity)]
struct SpawnedDownloader<T: HeaderDownloader> {
    /// Configuration updates sent by the [`TaskDownloader`] handle.
    updates: UnboundedReceiverStream<DownloaderUpdates<T::Header>>,
    /// Sender for downloaded header batches; a slot is reserved before each batch is pulled.
    headers_tx: PollSender<HeadersDownloaderResult<Vec<SealedHeader<T::Header>>, T::Header>>,
    /// The wrapped downloader doing the actual work.
    downloader: T,
}

122impl<T: HeaderDownloader> Future for SpawnedDownloader<T> {
123 type Output = ();
124
125 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
126 let this = self.get_mut();
127
128 loop {
129 loop {
130 match this.updates.poll_next_unpin(cx) {
131 Poll::Pending => break,
132 Poll::Ready(None) => {
133 return Poll::Ready(())
136 }
137 Poll::Ready(Some(update)) => match update {
138 DownloaderUpdates::UpdateSyncGap(head, target) => {
139 this.downloader.update_sync_gap(head, target);
140 }
141 DownloaderUpdates::UpdateLocalHead(head) => {
142 this.downloader.update_local_head(head);
143 }
144 DownloaderUpdates::UpdateSyncTarget(target) => {
145 this.downloader.update_sync_target(target);
146 }
147 DownloaderUpdates::SetBatchSize(limit) => {
148 this.downloader.set_batch_size(limit);
149 }
150 },
151 }
152 }
153
154 match ready!(this.headers_tx.poll_reserve(cx)) {
155 Ok(()) => {
156 match ready!(this.downloader.poll_next_unpin(cx)) {
157 Some(headers) => {
158 if this.headers_tx.send_item(headers).is_err() {
159 return Poll::Ready(())
162 }
163 }
164 None => return Poll::Pending,
165 }
166 }
167 Err(_) => {
168 return Poll::Ready(())
171 }
172 }
173 }
174 }
175}
176
177enum DownloaderUpdates<H> {
179 UpdateSyncGap(SealedHeader<H>, SyncTarget),
180 UpdateLocalHead(SealedHeader<H>),
181 UpdateSyncTarget(SyncTarget),
182 SetBatchSize(usize),
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use crate::headers::{
        reverse_headers::ReverseHeadersDownloaderBuilder, test_utils::child_header,
    };
    use reth_consensus::test_utils::TestConsensus;
    use reth_network_p2p::test_utils::TestHeadersClient;
    use std::sync::Arc;

    /// Downloads a short chain through a [`TaskDownloader`] whose inner downloader emits one
    /// header per batch, and checks that the batches arrive tip-first.
    #[tokio::test(flavor = "multi_thread")]
    async fn download_one_by_one_on_task() {
        reth_tracing::init_test_tracing();

        // Chain: local_head <- p2 <- p1 <- tip.
        let local_head = SealedHeader::default();
        let p2 = child_header(&local_head);
        let p1 = child_header(&p2);
        let tip = child_header(&p1);

        let client = Arc::new(TestHeadersClient::default());
        let inner = ReverseHeadersDownloaderBuilder::default()
            .stream_batch_size(1)
            .request_limit(1)
            .build(Arc::clone(&client), Arc::new(TestConsensus::default()));

        let mut downloader = TaskDownloader::spawn(inner);
        downloader.update_local_head(local_head.clone());
        downloader.update_sync_target(SyncTarget::Tip(tip.hash()));

        client
            .extend(vec![
                tip.as_ref().clone(),
                p1.as_ref().clone(),
                p2.as_ref().clone(),
                local_head.as_ref().clone(),
            ])
            .await;

        // Every batch holds exactly one header, delivered from the tip downwards.
        for expected in [tip, p1, p2] {
            let batch = downloader.next().await.unwrap();
            assert_eq!(batch, Ok(vec![expected]));
        }
    }
}