reth_network_p2p/test_utils/
headers.rs

1//! Testing support for headers related interfaces.
2
3use crate::{
4    download::DownloadClient,
5    error::{DownloadError, DownloadResult, PeerRequestResult, RequestError},
6    headers::{
7        client::{HeadersClient, HeadersRequest},
8        downloader::{HeaderDownloader, SyncTarget},
9        error::HeadersDownloaderResult,
10    },
11    priority::Priority,
12};
13use alloy_consensus::Header;
14use futures::{Future, FutureExt, Stream, StreamExt};
15use reth_eth_wire_types::HeadersDirection;
16use reth_network_peers::{PeerId, WithPeerId};
17use reth_primitives_traits::SealedHeader;
18use std::{
19    fmt,
20    pin::Pin,
21    sync::{
22        atomic::{AtomicU64, Ordering},
23        Arc,
24    },
25    task::{ready, Context, Poll},
26};
27use tokio::sync::Mutex;
28
29/// A test downloader which just returns the values that have been pushed to it.
30#[derive(Debug)]
31pub struct TestHeaderDownloader {
32    client: TestHeadersClient,
33    limit: u64,
34    download: Option<TestDownload>,
35    queued_headers: Vec<SealedHeader>,
36    batch_size: usize,
37}
38
39impl TestHeaderDownloader {
40    /// Instantiates the downloader with the mock responses
41    pub const fn new(client: TestHeadersClient, limit: u64, batch_size: usize) -> Self {
42        Self { client, limit, download: None, batch_size, queued_headers: Vec::new() }
43    }
44
45    fn create_download(&self) -> TestDownload {
46        TestDownload {
47            client: self.client.clone(),
48            limit: self.limit,
49            fut: None,
50            buffer: vec![],
51            done: false,
52        }
53    }
54}
55
56impl HeaderDownloader for TestHeaderDownloader {
57    type Header = Header;
58
59    fn update_local_head(&mut self, _head: SealedHeader) {}
60
61    fn update_sync_target(&mut self, _target: SyncTarget) {}
62
63    fn set_batch_size(&mut self, limit: usize) {
64        self.batch_size = limit;
65    }
66}
67
68impl Stream for TestHeaderDownloader {
69    type Item = HeadersDownloaderResult<Vec<SealedHeader>, Header>;
70
71    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
72        let this = self.get_mut();
73        loop {
74            if this.queued_headers.len() == this.batch_size {
75                return Poll::Ready(Some(Ok(std::mem::take(&mut this.queued_headers))))
76            }
77            if this.download.is_none() {
78                this.download = Some(this.create_download());
79            }
80
81            match ready!(this.download.as_mut().unwrap().poll_next_unpin(cx)) {
82                None => return Poll::Ready(Some(Ok(std::mem::take(&mut this.queued_headers)))),
83                Some(header) => this.queued_headers.push(header.unwrap()),
84            }
85        }
86    }
87}
88
89type TestHeadersFut = Pin<Box<dyn Future<Output = PeerRequestResult<Vec<Header>>> + Sync + Send>>;
90
91struct TestDownload {
92    client: TestHeadersClient,
93    limit: u64,
94    fut: Option<TestHeadersFut>,
95    buffer: Vec<SealedHeader>,
96    done: bool,
97}
98
99impl TestDownload {
100    fn get_or_init_fut(&mut self) -> &mut TestHeadersFut {
101        if self.fut.is_none() {
102            let request = HeadersRequest {
103                limit: self.limit,
104                direction: HeadersDirection::Rising,
105                start: 0u64.into(), // ignored
106            };
107            let client = self.client.clone();
108            self.fut = Some(Box::pin(client.get_headers(request)));
109        }
110        self.fut.as_mut().unwrap()
111    }
112}
113
114impl fmt::Debug for TestDownload {
115    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116        f.debug_struct("TestDownload")
117            .field("client", &self.client)
118            .field("limit", &self.limit)
119            .field("buffer", &self.buffer)
120            .field("done", &self.done)
121            .finish_non_exhaustive()
122    }
123}
124
125impl Stream for TestDownload {
126    type Item = DownloadResult<SealedHeader>;
127
128    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
129        let this = self.get_mut();
130
131        loop {
132            if let Some(header) = this.buffer.pop() {
133                return Poll::Ready(Some(Ok(header)))
134            } else if this.done {
135                return Poll::Ready(None)
136            }
137
138            match ready!(this.get_or_init_fut().poll_unpin(cx)) {
139                Ok(resp) => {
140                    // Skip head and seal headers
141                    let mut headers =
142                        resp.1.into_iter().skip(1).map(SealedHeader::seal_slow).collect::<Vec<_>>();
143                    headers.sort_unstable_by_key(|h| h.number);
144                    headers.into_iter().for_each(|h| this.buffer.push(h));
145                    this.done = true;
146                }
147                Err(err) => {
148                    this.done = true;
149                    return Poll::Ready(Some(Err(match err {
150                        RequestError::Timeout => DownloadError::Timeout,
151                        _ => DownloadError::RequestError(err),
152                    })))
153                }
154            }
155        }
156    }
157}
158
159/// A test client for fetching headers
160#[derive(Debug, Default, Clone)]
161pub struct TestHeadersClient {
162    responses: Arc<Mutex<Vec<Header>>>,
163    error: Arc<Mutex<Option<RequestError>>>,
164    request_attempts: Arc<AtomicU64>,
165}
166
167impl TestHeadersClient {
168    /// Return the number of times client was polled
169    pub fn request_attempts(&self) -> u64 {
170        self.request_attempts.load(Ordering::SeqCst)
171    }
172
173    /// Adds headers to the set.
174    pub async fn extend(&self, headers: impl IntoIterator<Item = Header>) {
175        let mut lock = self.responses.lock().await;
176        lock.extend(headers);
177    }
178
179    /// Clears the set.
180    pub async fn clear(&self) {
181        let mut lock = self.responses.lock().await;
182        lock.clear();
183    }
184
185    /// Set response error
186    pub async fn set_error(&self, err: RequestError) {
187        let mut lock = self.error.lock().await;
188        lock.replace(err);
189    }
190}
191
192impl DownloadClient for TestHeadersClient {
193    fn report_bad_message(&self, _peer_id: PeerId) {
194        // noop
195    }
196
197    fn num_connected_peers(&self) -> usize {
198        0
199    }
200}
201
202impl HeadersClient for TestHeadersClient {
203    type Header = Header;
204    type Output = TestHeadersFut;
205
206    fn get_headers_with_priority(
207        &self,
208        request: HeadersRequest,
209        _priority: Priority,
210    ) -> Self::Output {
211        let responses = self.responses.clone();
212        let error = self.error.clone();
213
214        self.request_attempts.fetch_add(1, Ordering::SeqCst);
215
216        Box::pin(async move {
217            if let Some(err) = &mut *error.lock().await {
218                return Err(err.clone())
219            }
220
221            let mut lock = responses.lock().await;
222            let len = lock.len().min(request.limit as usize);
223            let resp = lock.drain(..len).collect();
224            let with_peer_id = WithPeerId::from((PeerId::default(), resp));
225            Ok(with_peer_id)
226        })
227    }
228}