reth_network_p2p/test_utils/headers.rs

1//! Testing support for headers related interfaces.
2
3use crate::{
4    download::DownloadClient,
5    error::{DownloadError, DownloadResult, PeerRequestResult, RequestError},
6    headers::{
7        client::{HeadersClient, HeadersRequest},
8        downloader::{HeaderDownloader, SyncTarget},
9        error::HeadersDownloaderResult,
10    },
11    priority::Priority,
12};
13use alloy_consensus::Header;
14use futures::{Future, FutureExt, Stream, StreamExt};
15use reth_consensus::{test_utils::TestConsensus, Consensus};
16use reth_eth_wire_types::HeadersDirection;
17use reth_network_peers::{PeerId, WithPeerId};
18use reth_primitives::SealedHeader;
19use std::{
20    fmt,
21    pin::Pin,
22    sync::{
23        atomic::{AtomicU64, Ordering},
24        Arc,
25    },
26    task::{ready, Context, Poll},
27};
28use tokio::sync::Mutex;
29
30/// A test downloader which just returns the values that have been pushed to it.
31#[derive(Debug)]
32pub struct TestHeaderDownloader {
33    client: TestHeadersClient,
34    consensus: Arc<TestConsensus>,
35    limit: u64,
36    download: Option<TestDownload>,
37    queued_headers: Vec<SealedHeader>,
38    batch_size: usize,
39}
40
41impl TestHeaderDownloader {
42    /// Instantiates the downloader with the mock responses
43    pub const fn new(
44        client: TestHeadersClient,
45        consensus: Arc<TestConsensus>,
46        limit: u64,
47        batch_size: usize,
48    ) -> Self {
49        Self { client, consensus, limit, download: None, batch_size, queued_headers: Vec::new() }
50    }
51
52    fn create_download(&self) -> TestDownload {
53        TestDownload {
54            client: self.client.clone(),
55            consensus: Arc::clone(&self.consensus),
56            limit: self.limit,
57            fut: None,
58            buffer: vec![],
59            done: false,
60        }
61    }
62}
63
64impl HeaderDownloader for TestHeaderDownloader {
65    type Header = Header;
66
67    fn update_local_head(&mut self, _head: SealedHeader) {}
68
69    fn update_sync_target(&mut self, _target: SyncTarget) {}
70
71    fn set_batch_size(&mut self, limit: usize) {
72        self.batch_size = limit;
73    }
74}
75
76impl Stream for TestHeaderDownloader {
77    type Item = HeadersDownloaderResult<Vec<SealedHeader>, Header>;
78
79    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
80        let this = self.get_mut();
81        loop {
82            if this.queued_headers.len() == this.batch_size {
83                return Poll::Ready(Some(Ok(std::mem::take(&mut this.queued_headers))))
84            }
85            if this.download.is_none() {
86                this.download = Some(this.create_download());
87            }
88
89            match ready!(this.download.as_mut().unwrap().poll_next_unpin(cx)) {
90                None => return Poll::Ready(Some(Ok(std::mem::take(&mut this.queued_headers)))),
91                Some(header) => this.queued_headers.push(header.unwrap()),
92            }
93        }
94    }
95}
96
97type TestHeadersFut = Pin<Box<dyn Future<Output = PeerRequestResult<Vec<Header>>> + Sync + Send>>;
98
99struct TestDownload {
100    client: TestHeadersClient,
101    consensus: Arc<TestConsensus>,
102    limit: u64,
103    fut: Option<TestHeadersFut>,
104    buffer: Vec<SealedHeader>,
105    done: bool,
106}
107
108impl TestDownload {
109    fn get_or_init_fut(&mut self) -> &mut TestHeadersFut {
110        if self.fut.is_none() {
111            let request = HeadersRequest {
112                limit: self.limit,
113                direction: HeadersDirection::Rising,
114                start: 0u64.into(), // ignored
115            };
116            let client = self.client.clone();
117            self.fut = Some(Box::pin(client.get_headers(request)));
118        }
119        self.fut.as_mut().unwrap()
120    }
121}
122
123impl fmt::Debug for TestDownload {
124    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
125        f.debug_struct("TestDownload")
126            .field("client", &self.client)
127            .field("consensus", &self.consensus)
128            .field("limit", &self.limit)
129            .field("buffer", &self.buffer)
130            .field("done", &self.done)
131            .finish_non_exhaustive()
132    }
133}
134
135impl Stream for TestDownload {
136    type Item = DownloadResult<SealedHeader>;
137
138    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
139        let this = self.get_mut();
140
141        loop {
142            if let Some(header) = this.buffer.pop() {
143                return Poll::Ready(Some(Ok(header)))
144            } else if this.done {
145                return Poll::Ready(None)
146            }
147
148            let empty: SealedHeader = SealedHeader::default();
149            if let Err(error) =
150                <dyn Consensus<_>>::validate_header_against_parent(&this.consensus, &empty, &empty)
151            {
152                this.done = true;
153                return Poll::Ready(Some(Err(DownloadError::HeaderValidation {
154                    hash: empty.hash(),
155                    number: empty.number,
156                    error: Box::new(error),
157                })))
158            }
159
160            match ready!(this.get_or_init_fut().poll_unpin(cx)) {
161                Ok(resp) => {
162                    // Skip head and seal headers
163                    let mut headers =
164                        resp.1.into_iter().skip(1).map(SealedHeader::seal).collect::<Vec<_>>();
165                    headers.sort_unstable_by_key(|h| h.number);
166                    headers.into_iter().for_each(|h| this.buffer.push(h));
167                    this.done = true;
168                    continue
169                }
170                Err(err) => {
171                    this.done = true;
172                    return Poll::Ready(Some(Err(match err {
173                        RequestError::Timeout => DownloadError::Timeout,
174                        _ => DownloadError::RequestError(err),
175                    })))
176                }
177            }
178        }
179    }
180}
181
182/// A test client for fetching headers
183#[derive(Debug, Default, Clone)]
184pub struct TestHeadersClient {
185    responses: Arc<Mutex<Vec<Header>>>,
186    error: Arc<Mutex<Option<RequestError>>>,
187    request_attempts: Arc<AtomicU64>,
188}
189
190impl TestHeadersClient {
191    /// Return the number of times client was polled
192    pub fn request_attempts(&self) -> u64 {
193        self.request_attempts.load(Ordering::SeqCst)
194    }
195
196    /// Adds headers to the set.
197    pub async fn extend(&self, headers: impl IntoIterator<Item = Header>) {
198        let mut lock = self.responses.lock().await;
199        lock.extend(headers);
200    }
201
202    /// Clears the set.
203    pub async fn clear(&self) {
204        let mut lock = self.responses.lock().await;
205        lock.clear();
206    }
207
208    /// Set response error
209    pub async fn set_error(&self, err: RequestError) {
210        let mut lock = self.error.lock().await;
211        lock.replace(err);
212    }
213}
214
215impl DownloadClient for TestHeadersClient {
216    fn report_bad_message(&self, _peer_id: PeerId) {
217        // noop
218    }
219
220    fn num_connected_peers(&self) -> usize {
221        0
222    }
223}
224
225impl HeadersClient for TestHeadersClient {
226    type Header = Header;
227    type Output = TestHeadersFut;
228
229    fn get_headers_with_priority(
230        &self,
231        request: HeadersRequest,
232        _priority: Priority,
233    ) -> Self::Output {
234        let responses = self.responses.clone();
235        let error = self.error.clone();
236
237        self.request_attempts.fetch_add(1, Ordering::SeqCst);
238
239        Box::pin(async move {
240            if let Some(err) = &mut *error.lock().await {
241                return Err(err.clone())
242            }
243
244            let mut lock = responses.lock().await;
245            let len = lock.len().min(request.limit as usize);
246            let resp = lock.drain(..len).collect();
247            let with_peer_id = WithPeerId::from((PeerId::default(), resp));
248            Ok(with_peer_id)
249        })
250    }
251}