1use crate::{
2 capability::SharedCapabilities,
3 disconnect::CanDisconnect,
4 errors::{P2PHandshakeError, P2PStreamError},
5 pinger::{Pinger, PingerEvent},
6 DisconnectReason, HelloMessage, HelloMessageWithProtocols,
7};
8use alloy_primitives::{
9 bytes::{Buf, BufMut, Bytes, BytesMut},
10 hex,
11};
12use alloy_rlp::{Decodable, Encodable, Error as RlpError, EMPTY_LIST_CODE};
13use futures::{Sink, SinkExt, StreamExt};
14use pin_project::pin_project;
15use reth_codecs::add_arbitrary_tests;
16use reth_metrics::metrics::counter;
17use reth_primitives_traits::GotExpected;
18use std::{
19 collections::VecDeque,
20 io,
21 pin::Pin,
22 task::{ready, Context, Poll},
23 time::Duration,
24};
25use tokio_stream::Stream;
26use tracing::{debug, trace};
27
28#[cfg(feature = "serde")]
29use serde::{Deserialize, Serialize};
30
31const MAX_PAYLOAD_SIZE: usize = 16 * 1024 * 1024;
34
35pub const MAX_RESERVED_MESSAGE_ID: u8 = 0x0f;
38
39const MAX_P2P_MESSAGE_ID: u8 = P2PMessageID::Pong as u8;
41
42pub(crate) const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
45
46const PING_TIMEOUT: Duration = Duration::from_secs(15);
49
50const PING_INTERVAL: Duration = Duration::from_secs(60);
53
54const MAX_P2P_CAPACITY: usize = 2;
61
62#[pin_project]
65#[derive(Debug)]
66pub struct UnauthedP2PStream<S> {
67 #[pin]
68 inner: S,
69}
70
71impl<S> UnauthedP2PStream<S> {
72 pub const fn new(inner: S) -> Self {
74 Self { inner }
75 }
76
77 pub const fn inner(&self) -> &S {
79 &self.inner
80 }
81}
82
83impl<S> UnauthedP2PStream<S>
84where
85 S: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
86{
87 pub async fn handshake(
90 mut self,
91 hello: HelloMessageWithProtocols,
92 ) -> Result<(P2PStream<S>, HelloMessage), P2PStreamError> {
93 trace!(?hello, "sending p2p hello to peer");
94
95 self.inner.send(alloy_rlp::encode(P2PMessage::Hello(hello.message())).into()).await?;
97
98 let first_message_bytes = tokio::time::timeout(HANDSHAKE_TIMEOUT, self.inner.next())
99 .await
100 .or(Err(P2PStreamError::HandshakeError(P2PHandshakeError::Timeout)))?
101 .ok_or(P2PStreamError::HandshakeError(P2PHandshakeError::NoResponse))??;
102
103 if first_message_bytes.len() > MAX_PAYLOAD_SIZE {
106 return Err(P2PStreamError::MessageTooBig {
107 message_size: first_message_bytes.len(),
108 max_size: MAX_PAYLOAD_SIZE,
109 })
110 }
111
112 let their_hello = match P2PMessage::decode(&mut &first_message_bytes[..]) {
119 Ok(P2PMessage::Hello(hello)) => Ok(hello),
120 Ok(P2PMessage::Disconnect(reason)) => {
121 if matches!(reason, DisconnectReason::TooManyPeers) {
122 trace!(%reason, "Disconnected by peer during handshake");
124 } else {
125 debug!(%reason, "Disconnected by peer during handshake");
126 };
127 counter!("p2pstream.disconnected_errors").increment(1);
128 Err(P2PStreamError::HandshakeError(P2PHandshakeError::Disconnected(reason)))
129 }
130 Err(err) => {
131 debug!(%err, msg=%hex::encode(&first_message_bytes), "Failed to decode first message from peer");
132 Err(P2PStreamError::HandshakeError(err.into()))
133 }
134 Ok(msg) => {
135 debug!(?msg, "expected hello message but received another message");
136 Err(P2PStreamError::HandshakeError(P2PHandshakeError::NonHelloMessageInHandshake))
137 }
138 }?;
139
140 trace!(
141 hello=?their_hello,
142 "validating incoming p2p hello from peer"
143 );
144
145 if (hello.protocol_version as u8) != their_hello.protocol_version as u8 {
146 self.send_disconnect(DisconnectReason::IncompatibleP2PProtocolVersion).await?;
148 return Err(P2PStreamError::MismatchedProtocolVersion(GotExpected {
149 got: their_hello.protocol_version,
150 expected: hello.protocol_version,
151 }))
152 }
153
154 let capability_res =
156 SharedCapabilities::try_new(hello.protocols, their_hello.capabilities.clone());
157
158 let shared_capability = match capability_res {
159 Err(err) => {
160 self.send_disconnect(DisconnectReason::UselessPeer).await?;
162 Err(err)
163 }
164 Ok(cap) => Ok(cap),
165 }?;
166
167 let stream = P2PStream::new(self.inner, shared_capability);
168
169 Ok((stream, their_hello))
170 }
171}
172
173impl<S> UnauthedP2PStream<S>
174where
175 S: Sink<Bytes, Error = io::Error> + Unpin,
176{
177 pub async fn send_disconnect(
179 &mut self,
180 reason: DisconnectReason,
181 ) -> Result<(), P2PStreamError> {
182 trace!(
183 %reason,
184 "Sending disconnect message during the handshake",
185 );
186 self.inner
187 .send(Bytes::from(alloy_rlp::encode(P2PMessage::Disconnect(reason))))
188 .await
189 .map_err(P2PStreamError::Io)
190 }
191}
192
193impl<S> CanDisconnect<Bytes> for P2PStream<S>
194where
195 S: Sink<Bytes, Error = io::Error> + Unpin + Send + Sync,
196{
197 async fn disconnect(&mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> {
198 self.disconnect(reason).await
199 }
200}
201
202#[pin_project]
227#[derive(Debug)]
228pub struct P2PStream<S> {
229 #[pin]
230 inner: S,
231
232 encoder: snap::raw::Encoder,
234
235 decoder: snap::raw::Decoder,
237
238 pinger: Pinger,
240
241 shared_capabilities: SharedCapabilities,
243
244 outgoing_messages: VecDeque<Bytes>,
246
247 outgoing_message_buffer_capacity: usize,
250
251 disconnecting: bool,
254}
255
256impl<S> P2PStream<S> {
257 pub fn new(inner: S, shared_capabilities: SharedCapabilities) -> Self {
261 Self {
262 inner,
263 encoder: snap::raw::Encoder::new(),
264 decoder: snap::raw::Decoder::new(),
265 pinger: Pinger::new(PING_INTERVAL, PING_TIMEOUT),
266 shared_capabilities,
267 outgoing_messages: VecDeque::new(),
268 outgoing_message_buffer_capacity: MAX_P2P_CAPACITY,
269 disconnecting: false,
270 }
271 }
272
273 pub const fn inner(&self) -> &S {
275 &self.inner
276 }
277
278 pub fn set_outgoing_message_buffer_capacity(&mut self, capacity: usize) {
284 self.outgoing_message_buffer_capacity = capacity;
285 }
286
287 pub const fn shared_capabilities(&self) -> &SharedCapabilities {
292 &self.shared_capabilities
293 }
294
295 fn has_outgoing_capacity(&self) -> bool {
297 self.outgoing_messages.len() < self.outgoing_message_buffer_capacity
298 }
299
300 fn send_pong(&mut self) {
302 self.outgoing_messages.push_back(Bytes::from(alloy_rlp::encode(P2PMessage::Pong)));
303 }
304
305 pub fn send_ping(&mut self) {
307 self.outgoing_messages.push_back(Bytes::from(alloy_rlp::encode(P2PMessage::Ping)));
308 }
309}
310
311pub trait DisconnectP2P {
314 fn start_disconnect(&mut self, reason: DisconnectReason) -> Result<(), P2PStreamError>;
316
317 fn is_disconnecting(&self) -> bool;
319}
320
321impl<S> DisconnectP2P for P2PStream<S> {
322 fn start_disconnect(&mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> {
331 self.outgoing_messages.clear();
333 let disconnect = P2PMessage::Disconnect(reason);
334 let mut buf = Vec::with_capacity(disconnect.length());
335 disconnect.encode(&mut buf);
336
337 let mut compressed = vec![0u8; 1 + snap::raw::max_compress_len(buf.len() - 1)];
338 let compressed_size =
339 self.encoder.compress(&buf[1..], &mut compressed[1..]).map_err(|err| {
340 debug!(
341 %err,
342 msg=%hex::encode(&buf[1..]),
343 "error compressing disconnect"
344 );
345 err
346 })?;
347
348 compressed.truncate(compressed_size + 1);
351
352 compressed[0] = buf[0];
355
356 self.outgoing_messages.push_back(compressed.into());
357 self.disconnecting = true;
358 Ok(())
359 }
360
361 fn is_disconnecting(&self) -> bool {
362 self.disconnecting
363 }
364}
365
366impl<S> P2PStream<S>
367where
368 S: Sink<Bytes, Error = io::Error> + Unpin + Send,
369{
370 pub async fn disconnect(&mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> {
375 self.start_disconnect(reason)?;
376 self.close().await
377 }
378}
379
380impl<S> Stream for P2PStream<S>
383where
384 S: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
385{
386 type Item = Result<BytesMut, P2PStreamError>;
387
388 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
389 let this = self.get_mut();
390
391 if this.disconnecting {
392 return Poll::Ready(None)
394 }
395
396 while let Poll::Ready(res) = this.inner.poll_next_unpin(cx) {
399 let bytes = match res {
400 Some(Ok(bytes)) => bytes,
401 Some(Err(err)) => return Poll::Ready(Some(Err(err.into()))),
402 None => return Poll::Ready(None),
403 };
404
405 if bytes.is_empty() {
406 return Poll::Ready(Some(Err(P2PStreamError::EmptyProtocolMessage)))
408 }
409
410 let id = bytes[0];
415 if id == P2PMessageID::Disconnect as u8 {
416 if let Ok(reason) = DisconnectReason::decode(&mut &bytes[1..]) {
428 return Poll::Ready(Some(Err(P2PStreamError::Disconnected(reason))))
429 }
430 }
431
432 let decompressed_len = snap::raw::decompress_len(&bytes[1..])?;
435 if decompressed_len > MAX_PAYLOAD_SIZE {
436 return Poll::Ready(Some(Err(P2PStreamError::MessageTooBig {
437 message_size: decompressed_len,
438 max_size: MAX_PAYLOAD_SIZE,
439 })))
440 }
441
442 let mut decompress_buf = BytesMut::zeroed(decompressed_len + 1);
445
446 this.decoder.decompress(&bytes[1..], &mut decompress_buf[1..]).map_err(|err| {
449 debug!(
450 %err,
451 msg=%hex::encode(&bytes[1..]),
452 "error decompressing p2p message"
453 );
454 err
455 })?;
456
457 match id {
458 _ if id == P2PMessageID::Ping as u8 => {
459 trace!("Received Ping, Sending Pong");
460 this.send_pong();
461 cx.waker().wake_by_ref();
464 }
465 _ if id == P2PMessageID::Hello as u8 => {
466 return Poll::Ready(Some(Err(P2PStreamError::HandshakeError(
469 P2PHandshakeError::HelloNotInHandshake,
470 ))))
471 }
472 _ if id == P2PMessageID::Pong as u8 => {
473 this.pinger.on_pong()?
475 }
476 _ if id == P2PMessageID::Disconnect as u8 => {
477 let reason = DisconnectReason::decode(&mut &decompress_buf[1..]).inspect_err(|err| {
483 debug!(
484 %err, msg=%hex::encode(&decompress_buf[1..]), "Failed to decode disconnect message from peer"
485 );
486 })?;
487 return Poll::Ready(Some(Err(P2PStreamError::Disconnected(reason))))
488 }
489 _ if id > MAX_P2P_MESSAGE_ID && id <= MAX_RESERVED_MESSAGE_ID => {
490 return Poll::Ready(Some(Err(P2PStreamError::UnknownReservedMessageId(id))))
492 }
493 _ => {
494 decompress_buf[0] = bytes[0] - MAX_RESERVED_MESSAGE_ID - 1;
518
519 return Poll::Ready(Some(Ok(decompress_buf)))
520 }
521 }
522 }
523
524 Poll::Pending
525 }
526}
527
528impl<S> Sink<Bytes> for P2PStream<S>
529where
530 S: Sink<Bytes, Error = io::Error> + Unpin,
531{
532 type Error = P2PStreamError;
533
534 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
535 let mut this = self.as_mut();
536
537 match this.pinger.poll_ping(cx) {
539 Poll::Pending => {}
540 Poll::Ready(Ok(PingerEvent::Ping)) => {
541 this.send_ping();
542 }
543 _ => {
544 this.start_disconnect(DisconnectReason::PingTimeout)?;
546
547 return Poll::Ready(Ok(()))
549 }
550 }
551
552 match this.inner.poll_ready_unpin(cx) {
553 Poll::Pending => {}
554 Poll::Ready(Err(err)) => return Poll::Ready(Err(P2PStreamError::Io(err))),
555 Poll::Ready(Ok(())) => {
556 let flushed = this.poll_flush(cx);
557 if flushed.is_ready() {
558 return flushed
559 }
560 }
561 }
562
563 if self.has_outgoing_capacity() {
564 Poll::Ready(Ok(()))
566 } else {
567 Poll::Pending
568 }
569 }
570
571 fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
572 if item.len() > MAX_PAYLOAD_SIZE {
573 return Err(P2PStreamError::MessageTooBig {
574 message_size: item.len(),
575 max_size: MAX_PAYLOAD_SIZE,
576 })
577 }
578
579 if item.is_empty() {
580 return Err(P2PStreamError::EmptyProtocolMessage)
582 }
583
584 if !self.has_outgoing_capacity() {
586 return Err(P2PStreamError::SendBufferFull)
587 }
588
589 let this = self.project();
590
591 let mut compressed = BytesMut::zeroed(1 + snap::raw::max_compress_len(item.len() - 1));
592 let compressed_size =
593 this.encoder.compress(&item[1..], &mut compressed[1..]).map_err(|err| {
594 debug!(
595 %err,
596 msg=%hex::encode(&item[1..]),
597 "error compressing p2p message"
598 );
599 err
600 })?;
601
602 compressed.truncate(compressed_size + 1);
605
606 compressed[0] = item[0] + MAX_RESERVED_MESSAGE_ID + 1;
609 this.outgoing_messages.push_back(compressed.freeze());
610
611 Ok(())
612 }
613
614 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
616 let mut this = self.project();
617 let poll_res = loop {
618 match this.inner.as_mut().poll_ready(cx) {
619 Poll::Pending => break Poll::Pending,
620 Poll::Ready(Err(err)) => break Poll::Ready(Err(err.into())),
621 Poll::Ready(Ok(())) => {
622 let Some(message) = this.outgoing_messages.pop_front() else {
623 break Poll::Ready(Ok(()))
624 };
625 if let Err(err) = this.inner.as_mut().start_send(message) {
626 break Poll::Ready(Err(err.into()))
627 }
628 }
629 }
630 };
631
632 ready!(this.inner.as_mut().poll_flush(cx))?;
633
634 poll_res
635 }
636
637 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
638 ready!(self.as_mut().poll_flush(cx))?;
639 ready!(self.project().inner.poll_close(cx))?;
640
641 Poll::Ready(Ok(()))
642 }
643}
644
645#[derive(Debug, Clone, PartialEq, Eq)]
647#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
648#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
649#[add_arbitrary_tests(rlp)]
650pub enum P2PMessage {
651 Hello(HelloMessage),
653
654 Disconnect(DisconnectReason),
657
658 Ping,
660
661 Pong,
663}
664
665impl P2PMessage {
666 pub const fn message_id(&self) -> P2PMessageID {
668 match self {
669 Self::Hello(_) => P2PMessageID::Hello,
670 Self::Disconnect(_) => P2PMessageID::Disconnect,
671 Self::Ping => P2PMessageID::Ping,
672 Self::Pong => P2PMessageID::Pong,
673 }
674 }
675}
676
677impl Encodable for P2PMessage {
678 fn encode(&self, out: &mut dyn BufMut) {
683 (self.message_id() as u8).encode(out);
684 match self {
685 Self::Hello(msg) => msg.encode(out),
686 Self::Disconnect(msg) => msg.encode(out),
687 Self::Ping => {
688 out.put_u8(0x01);
690 out.put_u8(0x00);
691 out.put_u8(EMPTY_LIST_CODE);
692 }
693 Self::Pong => {
694 out.put_u8(0x01);
696 out.put_u8(0x00);
697 out.put_u8(EMPTY_LIST_CODE);
698 }
699 }
700 }
701
702 fn length(&self) -> usize {
703 let payload_len = match self {
704 Self::Hello(msg) => msg.length(),
705 Self::Disconnect(msg) => msg.length(),
706 Self::Ping | Self::Pong => 3, };
709 payload_len + 1 }
711}
712
713impl Decodable for P2PMessage {
714 fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
721 fn advance_snappy_ping_pong_payload(buf: &mut &[u8]) -> alloy_rlp::Result<()> {
723 if buf.len() < 3 {
724 return Err(RlpError::InputTooShort)
725 }
726 if buf[..3] != [0x01, 0x00, EMPTY_LIST_CODE] {
727 return Err(RlpError::Custom("expected snappy payload"))
728 }
729 buf.advance(3);
730 Ok(())
731 }
732
733 let message_id = u8::decode(&mut &buf[..])?;
734 let id = P2PMessageID::try_from(message_id)
735 .or(Err(RlpError::Custom("unknown p2p message id")))?;
736 buf.advance(1);
737 match id {
738 P2PMessageID::Hello => Ok(Self::Hello(HelloMessage::decode(buf)?)),
739 P2PMessageID::Disconnect => Ok(Self::Disconnect(DisconnectReason::decode(buf)?)),
740 P2PMessageID::Ping => {
741 advance_snappy_ping_pong_payload(buf)?;
742 Ok(Self::Ping)
743 }
744 P2PMessageID::Pong => {
745 advance_snappy_ping_pong_payload(buf)?;
746 Ok(Self::Pong)
747 }
748 }
749 }
750}
751
752#[derive(Debug, Copy, Clone, Eq, PartialEq)]
754pub enum P2PMessageID {
755 Hello = 0x00,
757
758 Disconnect = 0x01,
760
761 Ping = 0x02,
763
764 Pong = 0x03,
766}
767
768impl From<P2PMessage> for P2PMessageID {
769 fn from(msg: P2PMessage) -> Self {
770 match msg {
771 P2PMessage::Hello(_) => Self::Hello,
772 P2PMessage::Disconnect(_) => Self::Disconnect,
773 P2PMessage::Ping => Self::Ping,
774 P2PMessage::Pong => Self::Pong,
775 }
776 }
777}
778
779impl TryFrom<u8> for P2PMessageID {
780 type Error = P2PStreamError;
781
782 fn try_from(id: u8) -> Result<Self, Self::Error> {
783 match id {
784 0x00 => Ok(Self::Hello),
785 0x01 => Ok(Self::Disconnect),
786 0x02 => Ok(Self::Ping),
787 0x03 => Ok(Self::Pong),
788 _ => Err(P2PStreamError::UnknownReservedMessageId(id)),
789 }
790 }
791}
792
793#[cfg(test)]
794mod tests {
795 use super::*;
796 use crate::{capability::SharedCapability, test_utils::eth_hello, EthVersion, ProtocolVersion};
797 use tokio::net::{TcpListener, TcpStream};
798 use tokio_util::codec::Decoder;
799
800 #[tokio::test]
801 async fn test_can_disconnect() {
802 reth_tracing::init_test_tracing();
803 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
804 let local_addr = listener.local_addr().unwrap();
805
806 let expected_disconnect = DisconnectReason::UselessPeer;
807
808 let handle = tokio::spawn(async move {
809 let (incoming, _) = listener.accept().await.unwrap();
811 let stream = crate::PassthroughCodec::default().framed(incoming);
812
813 let (server_hello, _) = eth_hello();
814
815 let (mut p2p_stream, _) =
816 UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();
817
818 p2p_stream.disconnect(expected_disconnect).await.unwrap();
819 });
820
821 let outgoing = TcpStream::connect(local_addr).await.unwrap();
822 let sink = crate::PassthroughCodec::default().framed(outgoing);
823
824 let (client_hello, _) = eth_hello();
825
826 let (mut p2p_stream, _) =
827 UnauthedP2PStream::new(sink).handshake(client_hello).await.unwrap();
828
829 let err = p2p_stream.next().await.unwrap().unwrap_err();
830 match err {
831 P2PStreamError::Disconnected(reason) => assert_eq!(reason, expected_disconnect),
832 e => panic!("unexpected err: {e}"),
833 }
834
835 handle.await.unwrap();
836 }
837
838 #[tokio::test]
839 async fn test_can_disconnect_weird_disconnect_encoding() {
840 reth_tracing::init_test_tracing();
841 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
842 let local_addr = listener.local_addr().unwrap();
843
844 let expected_disconnect = DisconnectReason::SubprotocolSpecific;
845
846 let handle = tokio::spawn(async move {
847 let (incoming, _) = listener.accept().await.unwrap();
849 let stream = crate::PassthroughCodec::default().framed(incoming);
850
851 let (server_hello, _) = eth_hello();
852
853 let (mut p2p_stream, _) =
854 UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();
855
856 p2p_stream.outgoing_messages.clear();
858
859 p2p_stream.outgoing_messages.push_back(Bytes::from(alloy_rlp::encode(
860 P2PMessage::Disconnect(DisconnectReason::SubprotocolSpecific),
861 )));
862 p2p_stream.disconnecting = true;
863 p2p_stream.close().await.unwrap();
864 });
865
866 let outgoing = TcpStream::connect(local_addr).await.unwrap();
867 let sink = crate::PassthroughCodec::default().framed(outgoing);
868
869 let (client_hello, _) = eth_hello();
870
871 let (mut p2p_stream, _) =
872 UnauthedP2PStream::new(sink).handshake(client_hello).await.unwrap();
873
874 let err = p2p_stream.next().await.unwrap().unwrap_err();
875 match err {
876 P2PStreamError::Disconnected(reason) => assert_eq!(reason, expected_disconnect),
877 e => panic!("unexpected err: {e}"),
878 }
879
880 handle.await.unwrap();
881 }
882
883 #[tokio::test]
884 async fn test_handshake_passthrough() {
885 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
888 let local_addr = listener.local_addr().unwrap();
889
890 let handle = tokio::spawn(async move {
891 let (incoming, _) = listener.accept().await.unwrap();
893 let stream = crate::PassthroughCodec::default().framed(incoming);
894
895 let (server_hello, _) = eth_hello();
896
897 let unauthed_stream = UnauthedP2PStream::new(stream);
898 let (p2p_stream, _) = unauthed_stream.handshake(server_hello).await.unwrap();
899
900 assert_eq!(
902 *p2p_stream.shared_capabilities.iter_caps().next().unwrap(),
903 SharedCapability::Eth {
904 version: EthVersion::Eth67,
905 offset: MAX_RESERVED_MESSAGE_ID + 1
906 }
907 );
908 });
909
910 let outgoing = TcpStream::connect(local_addr).await.unwrap();
911 let sink = crate::PassthroughCodec::default().framed(outgoing);
912
913 let (client_hello, _) = eth_hello();
914
915 let unauthed_stream = UnauthedP2PStream::new(sink);
916 let (p2p_stream, _) = unauthed_stream.handshake(client_hello).await.unwrap();
917
918 assert_eq!(
920 *p2p_stream.shared_capabilities.iter_caps().next().unwrap(),
921 SharedCapability::Eth {
922 version: EthVersion::Eth67,
923 offset: MAX_RESERVED_MESSAGE_ID + 1
924 }
925 );
926
927 handle.await.unwrap();
929 }
930
931 #[tokio::test]
932 async fn test_handshake_disconnect() {
933 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
936 let local_addr = listener.local_addr().unwrap();
937
938 let handle = tokio::spawn(Box::pin(async move {
939 let (incoming, _) = listener.accept().await.unwrap();
941 let stream = crate::PassthroughCodec::default().framed(incoming);
942
943 let (server_hello, _) = eth_hello();
944
945 let unauthed_stream = UnauthedP2PStream::new(stream);
946 match unauthed_stream.handshake(server_hello.clone()).await {
947 Ok((_, hello)) => {
948 panic!("expected handshake to fail, instead got a successful Hello: {hello:?}")
949 }
950 Err(P2PStreamError::MismatchedProtocolVersion(GotExpected { got, expected })) => {
951 assert_ne!(expected, got);
952 assert_eq!(expected, server_hello.protocol_version);
953 }
954 Err(other_err) => {
955 panic!("expected mismatched protocol version error, got {other_err:?}")
956 }
957 }
958 }));
959
960 let outgoing = TcpStream::connect(local_addr).await.unwrap();
961 let sink = crate::PassthroughCodec::default().framed(outgoing);
962
963 let (mut client_hello, _) = eth_hello();
964
965 client_hello.protocol_version = ProtocolVersion::V4;
967
968 let unauthed_stream = UnauthedP2PStream::new(sink);
969 match unauthed_stream.handshake(client_hello.clone()).await {
970 Ok((_, hello)) => {
971 panic!("expected handshake to fail, instead got a successful Hello: {hello:?}")
972 }
973 Err(P2PStreamError::MismatchedProtocolVersion(GotExpected { got, expected })) => {
974 assert_ne!(expected, got);
975 assert_eq!(expected, client_hello.protocol_version);
976 }
977 Err(other_err) => {
978 panic!("expected mismatched protocol version error, got {other_err:?}")
979 }
980 }
981
982 handle.await.unwrap();
984 }
985
986 #[test]
987 fn snappy_decode_encode_ping() {
988 let snappy_ping = b"\x02\x01\0\xc0";
989 let ping = P2PMessage::decode(&mut &snappy_ping[..]).unwrap();
990 assert!(matches!(ping, P2PMessage::Ping));
991 assert_eq!(alloy_rlp::encode(ping), &snappy_ping[..]);
992 }
993
994 #[test]
995 fn snappy_decode_encode_pong() {
996 let snappy_pong = b"\x03\x01\0\xc0";
997 let pong = P2PMessage::decode(&mut &snappy_pong[..]).unwrap();
998 assert!(matches!(pong, P2PMessage::Pong));
999 assert_eq!(alloy_rlp::encode(pong), &snappy_pong[..]);
1000 }
1001}