reth_eth_wire/
ethstream.rs

1use crate::{
2    capability::RawCapabilityMessage,
3    errors::{EthHandshakeError, EthStreamError},
4    message::{EthBroadcastMessage, ProtocolBroadcastMessage},
5    p2pstream::HANDSHAKE_TIMEOUT,
6    CanDisconnect, DisconnectReason, EthMessage, EthNetworkPrimitives, EthVersion, ProtocolMessage,
7    Status,
8};
9use alloy_primitives::bytes::{Bytes, BytesMut};
10use alloy_rlp::Encodable;
11use futures::{ready, Sink, SinkExt, StreamExt};
12use pin_project::pin_project;
13use reth_eth_wire_types::NetworkPrimitives;
14use reth_ethereum_forks::ForkFilter;
15use reth_primitives_traits::GotExpected;
16use std::{
17    pin::Pin,
18    task::{Context, Poll},
19    time::Duration,
20};
21use tokio::time::timeout;
22use tokio_stream::Stream;
23use tracing::{debug, trace};
24
/// Upper bound, in bytes, on the size of a single eth protocol message (10 MiB).
///
/// Incoming frames larger than this are rejected with `EthStreamError::MessageTooBig`.
// Same limit as go-ethereum:
// https://github.com/ethereum/go-ethereum/blob/30602163d5d8321fbc68afdcbbaf2362b2641bde/eth/protocols/eth/protocol.go#L50
pub const MAX_MESSAGE_SIZE: usize = 10 << 20;
28
/// Upper bound, in bytes, on the size of the peer's initial `Status` message that is accepted
/// during the handshake (500 KiB).
pub(crate) const MAX_STATUS_SIZE: usize = 500 << 10;
31
32/// An un-authenticated [`EthStream`]. This is consumed and returns a [`EthStream`] after the
33/// `Status` handshake is completed.
34#[pin_project]
35#[derive(Debug)]
36pub struct UnauthedEthStream<S> {
37    #[pin]
38    inner: S,
39}
40
41impl<S> UnauthedEthStream<S> {
42    /// Create a new `UnauthedEthStream` from a type `S` which implements `Stream` and `Sink`.
43    pub const fn new(inner: S) -> Self {
44        Self { inner }
45    }
46
47    /// Consumes the type and returns the wrapped stream
48    pub fn into_inner(self) -> S {
49        self.inner
50    }
51}
52
53impl<S, E> UnauthedEthStream<S>
54where
55    S: Stream<Item = Result<BytesMut, E>> + CanDisconnect<Bytes> + Unpin,
56    EthStreamError: From<E> + From<<S as Sink<Bytes>>::Error>,
57{
58    /// Consumes the [`UnauthedEthStream`] and returns an [`EthStream`] after the `Status`
59    /// handshake is completed successfully. This also returns the `Status` message sent by the
60    /// remote peer.
61    pub async fn handshake<N: NetworkPrimitives>(
62        self,
63        status: Status,
64        fork_filter: ForkFilter,
65    ) -> Result<(EthStream<S, N>, Status), EthStreamError> {
66        self.handshake_with_timeout(status, fork_filter, HANDSHAKE_TIMEOUT).await
67    }
68
69    /// Wrapper around handshake which enforces a timeout.
70    pub async fn handshake_with_timeout<N: NetworkPrimitives>(
71        self,
72        status: Status,
73        fork_filter: ForkFilter,
74        timeout_limit: Duration,
75    ) -> Result<(EthStream<S, N>, Status), EthStreamError> {
76        timeout(timeout_limit, Self::handshake_without_timeout(self, status, fork_filter))
77            .await
78            .map_err(|_| EthStreamError::StreamTimeout)?
79    }
80
81    /// Handshake with no timeout
82    pub async fn handshake_without_timeout<N: NetworkPrimitives>(
83        mut self,
84        status: Status,
85        fork_filter: ForkFilter,
86    ) -> Result<(EthStream<S, N>, Status), EthStreamError> {
87        trace!(
88            %status,
89            "sending eth status to peer"
90        );
91
92        // we need to encode and decode here on our own because we don't have an `EthStream` yet
93        // The max length for a status with TTD is: <msg id = 1 byte> + <rlp(status) = 88 byte>
94        self.inner
95            .send(
96                alloy_rlp::encode(ProtocolMessage::<N>::from(EthMessage::<N>::Status(status)))
97                    .into(),
98            )
99            .await?;
100
101        let their_msg_res = self.inner.next().await;
102
103        let their_msg = match their_msg_res {
104            Some(msg) => msg,
105            None => {
106                self.inner.disconnect(DisconnectReason::DisconnectRequested).await?;
107                return Err(EthStreamError::EthHandshakeError(EthHandshakeError::NoResponse))
108            }
109        }?;
110
111        if their_msg.len() > MAX_STATUS_SIZE {
112            self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
113            return Err(EthStreamError::MessageTooBig(their_msg.len()))
114        }
115
116        let version = status.version;
117        let msg = match ProtocolMessage::<N>::decode_message(version, &mut their_msg.as_ref()) {
118            Ok(m) => m,
119            Err(err) => {
120                debug!("decode error in eth handshake: msg={their_msg:x}");
121                self.inner.disconnect(DisconnectReason::DisconnectRequested).await?;
122                return Err(EthStreamError::InvalidMessage(err))
123            }
124        };
125
126        // The following checks should match the checks in go-ethereum:
127        // https://github.com/ethereum/go-ethereum/blob/9244d5cd61f3ea5a7645fdf2a1a96d53421e412f/eth/protocols/eth/handshake.go#L87-L89
128        match msg.message {
129            EthMessage::Status(resp) => {
130                trace!(
131                    status=%resp,
132                    "validating incoming eth status from peer"
133                );
134                if status.genesis != resp.genesis {
135                    self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
136                    return Err(EthHandshakeError::MismatchedGenesis(
137                        GotExpected { expected: status.genesis, got: resp.genesis }.into(),
138                    )
139                    .into())
140                }
141
142                if status.version != resp.version {
143                    self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
144                    return Err(EthHandshakeError::MismatchedProtocolVersion(GotExpected {
145                        got: resp.version,
146                        expected: status.version,
147                    })
148                    .into())
149                }
150
151                if status.chain != resp.chain {
152                    self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
153                    return Err(EthHandshakeError::MismatchedChain(GotExpected {
154                        got: resp.chain,
155                        expected: status.chain,
156                    })
157                    .into())
158                }
159
160                // TD at mainnet block #7753254 is 76 bits. If it becomes 100 million times
161                // larger, it will still fit within 100 bits
162                if status.total_difficulty.bit_len() > 100 {
163                    self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
164                    return Err(EthHandshakeError::TotalDifficultyBitLenTooLarge {
165                        got: status.total_difficulty.bit_len(),
166                        maximum: 100,
167                    }
168                    .into())
169                }
170
171                if let Err(err) =
172                    fork_filter.validate(resp.forkid).map_err(EthHandshakeError::InvalidFork)
173                {
174                    self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
175                    return Err(err.into())
176                }
177
178                // now we can create the `EthStream` because the peer has successfully completed
179                // the handshake
180                let stream = EthStream::new(version, self.inner);
181
182                Ok((stream, resp))
183            }
184            _ => {
185                self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
186                Err(EthStreamError::EthHandshakeError(
187                    EthHandshakeError::NonStatusMessageInHandshake,
188                ))
189            }
190        }
191    }
192}
193
194/// An `EthStream` wraps over any `Stream` that yields bytes and makes it
195/// compatible with eth-networking protocol messages, which get RLP encoded/decoded.
196#[pin_project]
197#[derive(Debug)]
198pub struct EthStream<S, N = EthNetworkPrimitives> {
199    /// Negotiated eth version.
200    version: EthVersion,
201    #[pin]
202    inner: S,
203
204    _pd: std::marker::PhantomData<N>,
205}
206
207impl<S, N> EthStream<S, N> {
208    /// Creates a new unauthed [`EthStream`] from a provided stream. You will need
209    /// to manually handshake a peer.
210    #[inline]
211    pub const fn new(version: EthVersion, inner: S) -> Self {
212        Self { version, inner, _pd: std::marker::PhantomData }
213    }
214
215    /// Returns the eth version.
216    #[inline]
217    pub const fn version(&self) -> EthVersion {
218        self.version
219    }
220
221    /// Returns the underlying stream.
222    #[inline]
223    pub const fn inner(&self) -> &S {
224        &self.inner
225    }
226
227    /// Returns mutable access to the underlying stream.
228    #[inline]
229    pub fn inner_mut(&mut self) -> &mut S {
230        &mut self.inner
231    }
232
233    /// Consumes this type and returns the wrapped stream.
234    #[inline]
235    pub fn into_inner(self) -> S {
236        self.inner
237    }
238}
239
240impl<S, E, N> EthStream<S, N>
241where
242    S: Sink<Bytes, Error = E> + Unpin,
243    EthStreamError: From<E>,
244    N: NetworkPrimitives,
245{
246    /// Same as [`Sink::start_send`] but accepts a [`EthBroadcastMessage`] instead.
247    pub fn start_send_broadcast(
248        &mut self,
249        item: EthBroadcastMessage<N>,
250    ) -> Result<(), EthStreamError> {
251        self.inner.start_send_unpin(Bytes::from(alloy_rlp::encode(
252            ProtocolBroadcastMessage::from(item),
253        )))?;
254
255        Ok(())
256    }
257
258    /// Sends a raw capability message directly over the stream
259    pub fn start_send_raw(&mut self, msg: RawCapabilityMessage) -> Result<(), EthStreamError> {
260        let mut bytes = Vec::new();
261        msg.id.encode(&mut bytes);
262        bytes.extend_from_slice(&msg.payload);
263
264        self.inner.start_send_unpin(bytes.into())?;
265        Ok(())
266    }
267}
268
269impl<S, E, N> Stream for EthStream<S, N>
270where
271    S: Stream<Item = Result<BytesMut, E>> + Unpin,
272    EthStreamError: From<E>,
273    N: NetworkPrimitives,
274{
275    type Item = Result<EthMessage<N>, EthStreamError>;
276
277    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
278        let this = self.project();
279        let res = ready!(this.inner.poll_next(cx));
280        let bytes = match res {
281            Some(Ok(bytes)) => bytes,
282            Some(Err(err)) => return Poll::Ready(Some(Err(err.into()))),
283            None => return Poll::Ready(None),
284        };
285
286        if bytes.len() > MAX_MESSAGE_SIZE {
287            return Poll::Ready(Some(Err(EthStreamError::MessageTooBig(bytes.len()))))
288        }
289
290        let msg = match ProtocolMessage::decode_message(*this.version, &mut bytes.as_ref()) {
291            Ok(m) => m,
292            Err(err) => {
293                let msg = if bytes.len() > 50 {
294                    format!("{:02x?}...{:x?}", &bytes[..10], &bytes[bytes.len() - 10..])
295                } else {
296                    format!("{bytes:02x?}")
297                };
298                debug!(
299                    version=?this.version,
300                    %msg,
301                    "failed to decode protocol message"
302                );
303                return Poll::Ready(Some(Err(EthStreamError::InvalidMessage(err))))
304            }
305        };
306
307        if matches!(msg.message, EthMessage::Status(_)) {
308            return Poll::Ready(Some(Err(EthStreamError::EthHandshakeError(
309                EthHandshakeError::StatusNotInHandshake,
310            ))))
311        }
312
313        Poll::Ready(Some(Ok(msg.message)))
314    }
315}
316
317impl<S, N> Sink<EthMessage<N>> for EthStream<S, N>
318where
319    S: CanDisconnect<Bytes> + Unpin,
320    EthStreamError: From<<S as Sink<Bytes>>::Error>,
321    N: NetworkPrimitives,
322{
323    type Error = EthStreamError;
324
325    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
326        self.project().inner.poll_ready(cx).map_err(Into::into)
327    }
328
329    fn start_send(self: Pin<&mut Self>, item: EthMessage<N>) -> Result<(), Self::Error> {
330        if matches!(item, EthMessage::Status(_)) {
331            // TODO: to disconnect here we would need to do something similar to P2PStream's
332            // start_disconnect, which would ideally be a part of the CanDisconnect trait, or at
333            // least similar.
334            //
335            // Other parts of reth do not yet need traits like CanDisconnect because atm they work
336            // exclusively with EthStream<P2PStream<S>>, where the inner P2PStream is accessible,
337            // allowing for its start_disconnect method to be called.
338            //
339            // self.project().inner.start_disconnect(DisconnectReason::ProtocolBreach);
340            return Err(EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake))
341        }
342
343        self.project()
344            .inner
345            .start_send(Bytes::from(alloy_rlp::encode(ProtocolMessage::from(item))))?;
346
347        Ok(())
348    }
349
350    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
351        self.project().inner.poll_flush(cx).map_err(Into::into)
352    }
353
354    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
355        self.project().inner.poll_close(cx).map_err(Into::into)
356    }
357}
358
359impl<S, N> CanDisconnect<EthMessage<N>> for EthStream<S, N>
360where
361    S: CanDisconnect<Bytes> + Send,
362    EthStreamError: From<<S as Sink<Bytes>>::Error>,
363    N: NetworkPrimitives,
364{
365    async fn disconnect(&mut self, reason: DisconnectReason) -> Result<(), EthStreamError> {
366        self.inner.disconnect(reason).await.map_err(Into::into)
367    }
368}
369
#[cfg(test)]
mod tests {
    use super::UnauthedEthStream;
    use crate::{
        broadcast::BlockHashNumber,
        errors::{EthHandshakeError, EthStreamError},
        ethstream::RawCapabilityMessage,
        hello::DEFAULT_TCP_PORT,
        p2pstream::UnauthedP2PStream,
        EthMessage, EthStream, EthVersion, HelloMessageWithProtocols, PassthroughCodec,
        ProtocolVersion, Status,
    };
    use alloy_chains::NamedChain;
    use alloy_primitives::{bytes::Bytes, B256, U256};
    use alloy_rlp::Decodable;
    use futures::{SinkExt, StreamExt};
    use reth_ecies::stream::ECIESStream;
    use reth_eth_wire_types::EthNetworkPrimitives;
    use reth_ethereum_forks::{ForkFilter, Head};
    use reth_network_peers::pk2id;
    use secp256k1::{SecretKey, SECP256K1};
    use std::time::Duration;
    use tokio::net::{TcpListener, TcpStream};
    use tokio_util::codec::Decoder;

    /// Builds an eth/67 mainnet `Status` with a random genesis and the given total difficulty,
    /// together with the fork filter whose current fork id the status advertises.
    fn test_status(total_difficulty: U256) -> (Status, ForkFilter) {
        let genesis = B256::random();
        let fork_filter = ForkFilter::new(Head::default(), genesis, 0, Vec::new());

        let status = Status {
            version: EthVersion::Eth67,
            chain: NamedChain::Mainnet.into(),
            total_difficulty,
            blockhash: B256::random(),
            genesis,
            // Pass the current fork id.
            forkid: fork_filter.current(),
        };

        (status, fork_filter)
    }

    /// Builds a `NewBlockHashes` message announcing two random block hashes.
    fn test_new_block_hashes() -> EthMessage<EthNetworkPrimitives> {
        EthMessage::NewBlockHashes(
            vec![
                BlockHashNumber { hash: B256::random(), number: 5 },
                BlockHashNumber { hash: B256::random(), number: 6 },
            ]
            .into(),
        )
    }

    /// Runs the eth `Status` handshake between a client and a server connected over a local TCP
    /// socket with [`PassthroughCodec`] framing, where both sides advertise `status`.
    ///
    /// Returns `(client_result, server_result)`. On success each holds the `Status` received
    /// from the other side.
    async fn handshake_over_tcp(
        status: Status,
        fork_filter: ForkFilter,
    ) -> (Result<Status, EthStreamError>, Result<Status, EthStreamError>) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let server_fork_filter = fork_filter.clone();
        let server = tokio::spawn(async move {
            // roughly based off of the design of tokio::net::TcpListener
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = PassthroughCodec::default().framed(incoming);
            UnauthedEthStream::new(stream)
                .handshake::<EthNetworkPrimitives>(status, server_fork_filter)
                .await
                .map(|(_, their_status)| their_status)
        });

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = PassthroughCodec::default().framed(outgoing);
        let client_result = UnauthedEthStream::new(sink)
            .handshake::<EthNetworkPrimitives>(status, fork_filter)
            .await
            .map(|(_, their_status)| their_status);

        // Wait for the server side to finish as well; a panic there fails the test.
        (client_result, server.await.unwrap())
    }

    #[tokio::test]
    async fn can_handshake() {
        let (status, fork_filter) = test_status(U256::ZERO);

        let (client, server) = handshake_over_tcp(status, fork_filter).await;

        // Both sides advertise the same status, so each must receive exactly what it sent.
        assert_eq!(client.unwrap(), status);
        assert_eq!(server.unwrap(), status);
    }

    #[tokio::test]
    async fn pass_handshake_on_low_td_bitlen() {
        // 2^100 - 1 has a bit length of exactly 100, the largest accepted value.
        let (status, fork_filter) =
            test_status(U256::from(2).pow(U256::from(100)) - U256::from(1));

        let (client, server) = handshake_over_tcp(status, fork_filter).await;

        assert_eq!(client.unwrap(), status);
        assert_eq!(server.unwrap(), status);
    }

    #[tokio::test]
    async fn fail_handshake_on_high_td_bitlen() {
        // 2^100 has a bit length of 101, one more than the accepted maximum.
        let (status, fork_filter) = test_status(U256::from(2).pow(U256::from(100)));

        let (client, server) = handshake_over_tcp(status, fork_filter).await;

        // Both sides receive a status whose TD is too large, so both handshakes must fail.
        assert!(
            matches!(
                client,
                Err(EthStreamError::EthHandshakeError(
                    EthHandshakeError::TotalDifficultyBitLenTooLarge { got: 101, maximum: 100 }
                ))
            ),
            "unexpected client result: {client:?}"
        );
        assert!(
            matches!(
                server,
                Err(EthStreamError::EthHandshakeError(
                    EthHandshakeError::TotalDifficultyBitLenTooLarge { got: 101, maximum: 100 }
                ))
            ),
            "unexpected server result: {server:?}"
        );
    }

    #[tokio::test]
    async fn can_write_and_read_cleartext() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let test_msg = test_new_block_hashes();

        let test_msg_clone = test_msg.clone();
        let handle = tokio::spawn(async move {
            // roughly based off of the design of tokio::net::TcpListener
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = PassthroughCodec::default().framed(incoming);
            let mut stream = EthStream::new(EthVersion::Eth67, stream);

            // use the stream to get the next message
            let message = stream.next().await.unwrap().unwrap();
            assert_eq!(message, test_msg_clone);
        });

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = PassthroughCodec::default().framed(outgoing);
        let mut client_stream = EthStream::new(EthVersion::Eth67, sink);

        client_stream.send(test_msg).await.unwrap();

        // make sure the server receives the message and asserts before ending the test
        handle.await.unwrap();
    }

    #[tokio::test]
    async fn can_write_and_read_ecies() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let server_key = SecretKey::new(&mut rand::thread_rng());
        let test_msg = test_new_block_hashes();

        let test_msg_clone = test_msg.clone();
        let handle = tokio::spawn(async move {
            // roughly based off of the design of tokio::net::TcpListener
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = ECIESStream::incoming(incoming, server_key).await.unwrap();
            let mut stream = EthStream::new(EthVersion::Eth67, stream);

            // use the stream to get the next message
            let message = stream.next().await.unwrap().unwrap();
            assert_eq!(message, test_msg_clone);
        });

        // create the server pubkey
        let server_id = pk2id(&server_key.public_key(SECP256K1));

        let client_key = SecretKey::new(&mut rand::thread_rng());

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let outgoing = ECIESStream::connect(outgoing, client_key, server_id).await.unwrap();
        let mut client_stream = EthStream::new(EthVersion::Eth67, outgoing);

        client_stream.send(test_msg).await.unwrap();

        // make sure the server receives the message and asserts before ending the test
        handle.await.unwrap();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn ethstream_over_p2p() {
        // create a p2p stream and server, then confirm that the two are authed
        // create tcpstream
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let server_key = SecretKey::new(&mut rand::thread_rng());
        let test_msg = test_new_block_hashes();

        let (status, fork_filter) = test_status(U256::ZERO);

        let status_copy = status;
        let fork_filter_clone = fork_filter.clone();
        let test_msg_clone = test_msg.clone();
        let handle = tokio::spawn(async move {
            // roughly based off of the design of tokio::net::TcpListener
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = ECIESStream::incoming(incoming, server_key).await.unwrap();

            let server_hello = HelloMessageWithProtocols {
                protocol_version: ProtocolVersion::V5,
                client_version: "bitcoind/1.0.0".to_string(),
                protocols: vec![EthVersion::Eth67.into()],
                port: DEFAULT_TCP_PORT,
                id: pk2id(&server_key.public_key(SECP256K1)),
            };

            let unauthed_stream = UnauthedP2PStream::new(stream);
            let (p2p_stream, _) = unauthed_stream.handshake(server_hello).await.unwrap();
            let (mut eth_stream, _) = UnauthedEthStream::new(p2p_stream)
                .handshake(status_copy, fork_filter_clone)
                .await
                .unwrap();

            // use the stream to get the next message
            let message = eth_stream.next().await.unwrap().unwrap();
            assert_eq!(message, test_msg_clone);
        });

        // create the server pubkey
        let server_id = pk2id(&server_key.public_key(SECP256K1));

        let client_key = SecretKey::new(&mut rand::thread_rng());

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = ECIESStream::connect(outgoing, client_key, server_id).await.unwrap();

        let client_hello = HelloMessageWithProtocols {
            protocol_version: ProtocolVersion::V5,
            client_version: "bitcoind/1.0.0".to_string(),
            protocols: vec![EthVersion::Eth67.into()],
            port: DEFAULT_TCP_PORT,
            id: pk2id(&client_key.public_key(SECP256K1)),
        };

        let unauthed_stream = UnauthedP2PStream::new(sink);
        let (p2p_stream, _) = unauthed_stream.handshake(client_hello).await.unwrap();

        let (mut client_stream, _) =
            UnauthedEthStream::new(p2p_stream).handshake(status, fork_filter).await.unwrap();

        client_stream.send(test_msg).await.unwrap();

        // make sure the server receives the message and asserts before ending the test
        handle.await.unwrap();
    }

    #[tokio::test]
    async fn handshake_should_timeout() {
        let (status, fork_filter) = test_status(U256::ZERO);

        // The listener stays bound but never accepts, so the client's `Status` is never
        // answered. The TCP connect still completes through the listen backlog.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = PassthroughCodec::default().framed(outgoing);

        // try to connect
        let handshake_result = UnauthedEthStream::new(sink)
            .handshake_with_timeout::<EthNetworkPrimitives>(
                status,
                fork_filter,
                Duration::from_secs(1),
            )
            .await;

        // Assert that a timeout error occurred
        assert!(
            matches!(handshake_result, Err(EthStreamError::StreamTimeout)),
            "unexpected handshake result: {handshake_result:?}"
        );

        // Keep the listener alive for the whole handshake attempt.
        drop(listener);
    }

    #[tokio::test]
    async fn can_write_and_read_raw_capability() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let test_msg = RawCapabilityMessage { id: 0x1234, payload: Bytes::from(vec![1, 2, 3, 4]) };

        let test_msg_clone = test_msg.clone();
        let handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = PassthroughCodec::default().framed(incoming);
            let mut stream = EthStream::<_, EthNetworkPrimitives>::new(EthVersion::Eth67, stream);

            let bytes = stream.inner_mut().next().await.unwrap().unwrap();

            // Decoding through a slice cursor advances it past the RLP-encoded id.
            let mut id_bytes = &bytes[..];
            let decoded_id = <usize as Decodable>::decode(&mut id_bytes).unwrap();
            assert_eq!(decoded_id, test_msg_clone.id);

            // Whatever remains after the id is the raw payload.
            let remaining = id_bytes;
            assert_eq!(remaining, &test_msg_clone.payload[..]);
        });

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = PassthroughCodec::default().framed(outgoing);
        let mut client_stream = EthStream::<_, EthNetworkPrimitives>::new(EthVersion::Eth67, sink);

        client_stream.start_send_raw(test_msg).unwrap();
        client_stream.inner_mut().flush().await.unwrap();

        handle.await.unwrap();
    }
}
795}