reth_downloaders/bodies/
task.rs1use alloy_primitives::BlockNumber;
2use futures::Stream;
3use futures_util::{FutureExt, StreamExt};
4use pin_project::pin_project;
5use reth_network_p2p::{
6 bodies::downloader::{BodyDownloader, BodyDownloaderResult},
7 error::DownloadResult,
8};
9use reth_tasks::{TaskSpawner, TokioTaskExecutor};
10use std::{
11 fmt::Debug,
12 future::Future,
13 ops::RangeInclusive,
14 pin::Pin,
15 task::{ready, Context, Poll},
16};
17use tokio::sync::{mpsc, mpsc::UnboundedSender};
18use tokio_stream::wrappers::{ReceiverStream, UnboundedReceiverStream};
19use tokio_util::sync::PollSender;
20
/// Capacity of the bounded channel that carries [`BodyDownloaderResult`]s from the spawned
/// downloader task back to its [`TaskDownloader`] handle.
pub const BODIES_TASK_BUFFER_SIZE: usize = 4;
23
24#[derive(Debug)]
26#[pin_project]
27pub struct TaskDownloader<H, B> {
28 #[pin]
29 from_downloader: ReceiverStream<BodyDownloaderResult<H, B>>,
30 to_downloader: UnboundedSender<RangeInclusive<BlockNumber>>,
31}
32
33impl<H: Send + Sync + Unpin + 'static, B: Send + Sync + Unpin + 'static> TaskDownloader<H, B> {
36 pub fn spawn<T>(downloader: T) -> Self
66 where
67 T: BodyDownloader<Header = H, Body = B> + 'static,
68 {
69 Self::spawn_with(downloader, &TokioTaskExecutor::default())
70 }
71
72 pub fn spawn_with<T, S>(downloader: T, spawner: &S) -> Self
75 where
76 T: BodyDownloader<Header = H, Body = B> + 'static,
77 S: TaskSpawner,
78 {
79 let (bodies_tx, bodies_rx) = mpsc::channel(BODIES_TASK_BUFFER_SIZE);
80 let (to_downloader, updates_rx) = mpsc::unbounded_channel();
81
82 let downloader = SpawnedDownloader {
83 bodies_tx: PollSender::new(bodies_tx),
84 updates: UnboundedReceiverStream::new(updates_rx),
85 downloader,
86 };
87
88 spawner.spawn(downloader.boxed());
89
90 Self { from_downloader: ReceiverStream::new(bodies_rx), to_downloader }
91 }
92}
93
94impl<H: Debug + Send + Sync + Unpin + 'static, B: Debug + Send + Sync + Unpin + 'static>
95 BodyDownloader for TaskDownloader<H, B>
96{
97 type Header = H;
98 type Body = B;
99
100 fn set_download_range(&mut self, range: RangeInclusive<BlockNumber>) -> DownloadResult<()> {
101 let _ = self.to_downloader.send(range);
102 Ok(())
103 }
104}
105
106impl<H, B> Stream for TaskDownloader<H, B> {
107 type Item = BodyDownloaderResult<H, B>;
108
109 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
110 self.project().from_downloader.poll_next(cx)
111 }
112}
113
114struct SpawnedDownloader<T: BodyDownloader> {
116 updates: UnboundedReceiverStream<RangeInclusive<BlockNumber>>,
117 bodies_tx: PollSender<BodyDownloaderResult<T::Header, T::Body>>,
118 downloader: T,
119}
120
121impl<T: BodyDownloader> Future for SpawnedDownloader<T> {
122 type Output = ();
123
124 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
125 let this = self.get_mut();
126
127 loop {
128 while let Poll::Ready(update) = this.updates.poll_next_unpin(cx) {
129 if let Some(range) = update {
130 if let Err(err) = this.downloader.set_download_range(range) {
131 tracing::error!(target: "downloaders::bodies", %err, "Failed to set bodies download range");
132
133 let mut bodies_tx = this.bodies_tx.clone();
135
136 let forward_error_result = ready!(bodies_tx.poll_reserve(cx))
137 .and_then(|_| bodies_tx.send_item(Err(err)));
138 if forward_error_result.is_err() {
139 return Poll::Ready(())
142 }
143 }
144 } else {
145 return Poll::Ready(())
148 }
149 }
150
151 match ready!(this.bodies_tx.poll_reserve(cx)) {
152 Ok(()) => match ready!(this.downloader.poll_next_unpin(cx)) {
153 Some(bodies) => {
154 if this.bodies_tx.send_item(bodies).is_err() {
155 return Poll::Ready(())
158 }
159 }
160 None => return Poll::Pending,
161 },
162 Err(_) => {
163 return Poll::Ready(())
166 }
167 }
168 }
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        bodies::{
            bodies::BodiesDownloaderBuilder,
            test_utils::{insert_headers, zip_blocks},
        },
        test_utils::{generate_bodies, TestBodiesClient},
    };
    use assert_matches::assert_matches;
    use reth_consensus::test_utils::TestConsensus;
    use reth_network_p2p::error::DownloadError;
    use reth_provider::test_utils::create_test_provider_factory;
    use std::sync::Arc;

    /// Requests blocks `0..=19` through the spawned task.
    ///
    /// Checks that the first item received matches the generated blocks and that the client
    /// was asked exactly once.
    #[tokio::test(flavor = "multi_thread")]
    async fn download_one_by_one_on_task() {
        reth_tracing::init_test_tracing();

        let provider_factory = create_test_provider_factory();
        let (headers, mut bodies) = generate_bodies(0..=19);
        insert_headers(provider_factory.db_ref().db(), &headers);

        let bodies_client = Arc::new(
            TestBodiesClient::default().with_bodies(bodies.clone()).with_should_delay(true),
        );
        let inner = BodiesDownloaderBuilder::default().build(
            Arc::clone(&bodies_client),
            Arc::new(TestConsensus::default()),
            provider_factory,
        );
        let mut task_downloader = TaskDownloader::spawn(inner);
        task_downloader.set_download_range(0..=19).expect("failed to set download range");

        let first = task_downloader.next().await;
        assert_matches!(
            first,
            Some(Ok(res)) => assert_eq!(res, zip_blocks(headers.iter(), &mut bodies))
        );
        assert_eq!(bodies_client.times_requested(), 1);
    }

    /// Passing an empty range must not fail at the call site.
    ///
    /// The resulting `InvalidBodyRange` error has to come back through the stream instead.
    #[tokio::test(flavor = "multi_thread")]
    // `1..=0` is deliberately empty: it is the invalid input under test.
    #[allow(clippy::reversed_empty_ranges)]
    async fn set_download_range_error_returned() {
        reth_tracing::init_test_tracing();
        let provider_factory = create_test_provider_factory();

        let inner = BodiesDownloaderBuilder::default().build(
            Arc::new(TestBodiesClient::default()),
            Arc::new(TestConsensus::default()),
            provider_factory,
        );
        let mut task_downloader = TaskDownloader::spawn(inner);

        task_downloader.set_download_range(1..=0).expect("failed to set download range");
        let item = task_downloader.next().await;
        assert_matches!(item, Some(Err(DownloadError::InvalidBodyRange { .. })));
    }
}