// reth_network_p2p/test_utils/
headers.rs1use crate::{
4 download::DownloadClient,
5 error::{DownloadError, DownloadResult, PeerRequestResult, RequestError},
6 headers::{
7 client::{HeadersClient, HeadersRequest},
8 downloader::{HeaderDownloader, SyncTarget},
9 error::HeadersDownloaderResult,
10 },
11 priority::Priority,
12};
13use alloy_consensus::Header;
14use futures::{Future, FutureExt, Stream, StreamExt};
15use reth_eth_wire_types::HeadersDirection;
16use reth_network_peers::{PeerId, WithPeerId};
17use reth_primitives_traits::SealedHeader;
18use std::{
19 fmt,
20 pin::Pin,
21 sync::{
22 atomic::{AtomicU64, Ordering},
23 Arc,
24 },
25 task::{ready, Context, Poll},
26};
27use tokio::sync::Mutex;
28
29#[derive(Debug)]
31pub struct TestHeaderDownloader {
32 client: TestHeadersClient,
33 limit: u64,
34 download: Option<TestDownload>,
35 queued_headers: Vec<SealedHeader>,
36 batch_size: usize,
37}
38
39impl TestHeaderDownloader {
40 pub const fn new(client: TestHeadersClient, limit: u64, batch_size: usize) -> Self {
42 Self { client, limit, download: None, batch_size, queued_headers: Vec::new() }
43 }
44
45 fn create_download(&self) -> TestDownload {
46 TestDownload {
47 client: self.client.clone(),
48 limit: self.limit,
49 fut: None,
50 buffer: vec![],
51 done: false,
52 }
53 }
54}
55
56impl HeaderDownloader for TestHeaderDownloader {
57 type Header = Header;
58
59 fn update_local_head(&mut self, _head: SealedHeader) {}
60
61 fn update_sync_target(&mut self, _target: SyncTarget) {}
62
63 fn set_batch_size(&mut self, limit: usize) {
64 self.batch_size = limit;
65 }
66}
67
68impl Stream for TestHeaderDownloader {
69 type Item = HeadersDownloaderResult<Vec<SealedHeader>, Header>;
70
71 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
72 let this = self.get_mut();
73 loop {
74 if this.queued_headers.len() == this.batch_size {
75 return Poll::Ready(Some(Ok(std::mem::take(&mut this.queued_headers))))
76 }
77 if this.download.is_none() {
78 this.download = Some(this.create_download());
79 }
80
81 match ready!(this.download.as_mut().unwrap().poll_next_unpin(cx)) {
82 None => return Poll::Ready(Some(Ok(std::mem::take(&mut this.queued_headers)))),
83 Some(header) => this.queued_headers.push(header.unwrap()),
84 }
85 }
86 }
87}
88
89type TestHeadersFut = Pin<Box<dyn Future<Output = PeerRequestResult<Vec<Header>>> + Sync + Send>>;
90
91struct TestDownload {
92 client: TestHeadersClient,
93 limit: u64,
94 fut: Option<TestHeadersFut>,
95 buffer: Vec<SealedHeader>,
96 done: bool,
97}
98
99impl TestDownload {
100 fn get_or_init_fut(&mut self) -> &mut TestHeadersFut {
101 if self.fut.is_none() {
102 let request = HeadersRequest {
103 limit: self.limit,
104 direction: HeadersDirection::Rising,
105 start: 0u64.into(), };
107 let client = self.client.clone();
108 self.fut = Some(Box::pin(client.get_headers(request)));
109 }
110 self.fut.as_mut().unwrap()
111 }
112}
113
114impl fmt::Debug for TestDownload {
115 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116 f.debug_struct("TestDownload")
117 .field("client", &self.client)
118 .field("limit", &self.limit)
119 .field("buffer", &self.buffer)
120 .field("done", &self.done)
121 .finish_non_exhaustive()
122 }
123}
124
125impl Stream for TestDownload {
126 type Item = DownloadResult<SealedHeader>;
127
128 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
129 let this = self.get_mut();
130
131 loop {
132 if let Some(header) = this.buffer.pop() {
133 return Poll::Ready(Some(Ok(header)))
134 } else if this.done {
135 return Poll::Ready(None)
136 }
137
138 match ready!(this.get_or_init_fut().poll_unpin(cx)) {
139 Ok(resp) => {
140 let mut headers =
142 resp.1.into_iter().skip(1).map(SealedHeader::seal_slow).collect::<Vec<_>>();
143 headers.sort_unstable_by_key(|h| h.number);
144 headers.into_iter().for_each(|h| this.buffer.push(h));
145 this.done = true;
146 }
147 Err(err) => {
148 this.done = true;
149 return Poll::Ready(Some(Err(match err {
150 RequestError::Timeout => DownloadError::Timeout,
151 _ => DownloadError::RequestError(err),
152 })))
153 }
154 }
155 }
156 }
157}
158
159#[derive(Debug, Default, Clone)]
161pub struct TestHeadersClient {
162 responses: Arc<Mutex<Vec<Header>>>,
163 error: Arc<Mutex<Option<RequestError>>>,
164 request_attempts: Arc<AtomicU64>,
165}
166
167impl TestHeadersClient {
168 pub fn request_attempts(&self) -> u64 {
170 self.request_attempts.load(Ordering::SeqCst)
171 }
172
173 pub async fn extend(&self, headers: impl IntoIterator<Item = Header>) {
175 let mut lock = self.responses.lock().await;
176 lock.extend(headers);
177 }
178
179 pub async fn clear(&self) {
181 let mut lock = self.responses.lock().await;
182 lock.clear();
183 }
184
185 pub async fn set_error(&self, err: RequestError) {
187 let mut lock = self.error.lock().await;
188 lock.replace(err);
189 }
190}
191
192impl DownloadClient for TestHeadersClient {
193 fn report_bad_message(&self, _peer_id: PeerId) {
194 }
196
197 fn num_connected_peers(&self) -> usize {
198 0
199 }
200}
201
202impl HeadersClient for TestHeadersClient {
203 type Header = Header;
204 type Output = TestHeadersFut;
205
206 fn get_headers_with_priority(
207 &self,
208 request: HeadersRequest,
209 _priority: Priority,
210 ) -> Self::Output {
211 let responses = self.responses.clone();
212 let error = self.error.clone();
213
214 self.request_attempts.fetch_add(1, Ordering::SeqCst);
215
216 Box::pin(async move {
217 if let Some(err) = &mut *error.lock().await {
218 return Err(err.clone())
219 }
220
221 let mut lock = responses.lock().await;
222 let len = lock.len().min(request.limit as usize);
223 let resp = lock.drain(..len).collect();
224 let with_peer_id = WithPeerId::from((PeerId::default(), resp));
225 Ok(with_peer_id)
226 })
227 }
228}