reth_eth_wire/
p2pstream.rs

1use crate::{
2    capability::SharedCapabilities,
3    disconnect::CanDisconnect,
4    errors::{P2PHandshakeError, P2PStreamError},
5    pinger::{Pinger, PingerEvent},
6    DisconnectReason, HelloMessage, HelloMessageWithProtocols,
7};
8use alloy_primitives::{
9    bytes::{Buf, BufMut, Bytes, BytesMut},
10    hex,
11};
12use alloy_rlp::{Decodable, Encodable, Error as RlpError, EMPTY_LIST_CODE};
13use futures::{Sink, SinkExt, StreamExt};
14use pin_project::pin_project;
15use reth_codecs::add_arbitrary_tests;
16use reth_metrics::metrics::counter;
17use reth_primitives_traits::GotExpected;
18use std::{
19    collections::VecDeque,
20    io,
21    pin::Pin,
22    task::{ready, Context, Poll},
23    time::Duration,
24};
25use tokio_stream::Stream;
26use tracing::{debug, trace};
27
28#[cfg(feature = "serde")]
29use serde::{Deserialize, Serialize};
30
/// Upper bound (16 MiB) on the size of an uncompressed message payload, as specified by
/// [EIP-706](https://eips.ethereum.org/EIPS/eip-706).
const MAX_PAYLOAD_SIZE: usize = 16 << 20;

/// Highest message id reserved for the `p2p` capability itself.
///
/// Any incoming message whose id is above this value belongs to one of the negotiated
/// subprotocols.
pub const MAX_RESERVED_MESSAGE_ID: u8 = 0x0f;
38
39/// [`MAX_P2P_MESSAGE_ID`] is the maximum message ID in use for the `p2p` subprotocol.
40const MAX_P2P_MESSAGE_ID: u8 = P2PMessageID::Pong as u8;
41
/// How long to wait for the peer's reply before the `p2p` handshake is considered timed out.
pub(crate) const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);

/// How long to wait for an answer before a `p2p` ping is considered timed out.
const PING_TIMEOUT: Duration = Duration::from_secs(15);

/// Delay between two consecutive `p2p` pings while the peer keeps responding.
const PING_INTERVAL: Duration = Duration::from_secs(60);

/// Maximum number of outgoing messages a [`P2PStream`] buffers before its `Sink` stops
/// reporting readiness.
///
/// The value is deliberately small: a [`P2PStream`] is expected to wrap an
/// [`ECIESStream`](reth_ecies::stream::ECIESStream), which already buffers a few MB of encoded
/// data internally.
const MAX_P2P_CAPACITY: usize = 2;
61
62/// An un-authenticated [`P2PStream`]. This is consumed and returns a [`P2PStream`] after the
63/// `Hello` handshake is completed.
64#[pin_project]
65#[derive(Debug)]
66pub struct UnauthedP2PStream<S> {
67    #[pin]
68    inner: S,
69}
70
71impl<S> UnauthedP2PStream<S> {
72    /// Create a new `UnauthedP2PStream` from a type `S` which implements `Stream` and `Sink`.
73    pub const fn new(inner: S) -> Self {
74        Self { inner }
75    }
76
77    /// Returns a reference to the inner stream.
78    pub const fn inner(&self) -> &S {
79        &self.inner
80    }
81}
82
83impl<S> UnauthedP2PStream<S>
84where
85    S: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
86{
87    /// Consumes the `UnauthedP2PStream` and returns a `P2PStream` after the `Hello` handshake is
88    /// completed successfully. This also returns the `Hello` message sent by the remote peer.
89    pub async fn handshake(
90        mut self,
91        hello: HelloMessageWithProtocols,
92    ) -> Result<(P2PStream<S>, HelloMessage), P2PStreamError> {
93        trace!(?hello, "sending p2p hello to peer");
94
95        // send our hello message with the Sink
96        self.inner.send(alloy_rlp::encode(P2PMessage::Hello(hello.message())).into()).await?;
97
98        let first_message_bytes = tokio::time::timeout(HANDSHAKE_TIMEOUT, self.inner.next())
99            .await
100            .or(Err(P2PStreamError::HandshakeError(P2PHandshakeError::Timeout)))?
101            .ok_or(P2PStreamError::HandshakeError(P2PHandshakeError::NoResponse))??;
102
103        // let's check the compressed length first, we will need to check again once confirming
104        // that it contains snappy-compressed data (this will be the case for all non-p2p messages).
105        if first_message_bytes.len() > MAX_PAYLOAD_SIZE {
106            return Err(P2PStreamError::MessageTooBig {
107                message_size: first_message_bytes.len(),
108                max_size: MAX_PAYLOAD_SIZE,
109            })
110        }
111
112        // The first message sent MUST be a hello OR disconnect message
113        //
114        // If the first message is a disconnect message, we should not decode using
115        // Decodable::decode, because the first message (either Disconnect or Hello) is not snappy
116        // compressed, and the Decodable implementation assumes that non-hello messages are snappy
117        // compressed.
118        let their_hello = match P2PMessage::decode(&mut &first_message_bytes[..]) {
119            Ok(P2PMessage::Hello(hello)) => Ok(hello),
120            Ok(P2PMessage::Disconnect(reason)) => {
121                if matches!(reason, DisconnectReason::TooManyPeers) {
122                    // Too many peers is a very common disconnect reason that spams the DEBUG logs
123                    trace!(%reason, "Disconnected by peer during handshake");
124                } else {
125                    debug!(%reason, "Disconnected by peer during handshake");
126                };
127                counter!("p2pstream.disconnected_errors").increment(1);
128                Err(P2PStreamError::HandshakeError(P2PHandshakeError::Disconnected(reason)))
129            }
130            Err(err) => {
131                debug!(%err, msg=%hex::encode(&first_message_bytes), "Failed to decode first message from peer");
132                Err(P2PStreamError::HandshakeError(err.into()))
133            }
134            Ok(msg) => {
135                debug!(?msg, "expected hello message but received another message");
136                Err(P2PStreamError::HandshakeError(P2PHandshakeError::NonHelloMessageInHandshake))
137            }
138        }?;
139
140        trace!(
141            hello=?their_hello,
142            "validating incoming p2p hello from peer"
143        );
144
145        if (hello.protocol_version as u8) != their_hello.protocol_version as u8 {
146            // send a disconnect message notifying the peer of the protocol version mismatch
147            self.send_disconnect(DisconnectReason::IncompatibleP2PProtocolVersion).await?;
148            return Err(P2PStreamError::MismatchedProtocolVersion(GotExpected {
149                got: their_hello.protocol_version,
150                expected: hello.protocol_version,
151            }))
152        }
153
154        // determine shared capabilities (currently returns only one capability)
155        let capability_res =
156            SharedCapabilities::try_new(hello.protocols, their_hello.capabilities.clone());
157
158        let shared_capability = match capability_res {
159            Err(err) => {
160                // we don't share any capabilities, send a disconnect message
161                self.send_disconnect(DisconnectReason::UselessPeer).await?;
162                Err(err)
163            }
164            Ok(cap) => Ok(cap),
165        }?;
166
167        let stream = P2PStream::new(self.inner, shared_capability);
168
169        Ok((stream, their_hello))
170    }
171}
172
173impl<S> UnauthedP2PStream<S>
174where
175    S: Sink<Bytes, Error = io::Error> + Unpin,
176{
177    /// Send a disconnect message during the handshake. This is sent without snappy compression.
178    pub async fn send_disconnect(
179        &mut self,
180        reason: DisconnectReason,
181    ) -> Result<(), P2PStreamError> {
182        trace!(
183            %reason,
184            "Sending disconnect message during the handshake",
185        );
186        self.inner
187            .send(Bytes::from(alloy_rlp::encode(P2PMessage::Disconnect(reason))))
188            .await
189            .map_err(P2PStreamError::Io)
190    }
191}
192
193impl<S> CanDisconnect<Bytes> for P2PStream<S>
194where
195    S: Sink<Bytes, Error = io::Error> + Unpin + Send + Sync,
196{
197    async fn disconnect(&mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> {
198        self.disconnect(reason).await
199    }
200}
201
202/// A P2PStream wraps over any `Stream` that yields bytes and makes it compatible with `p2p`
203/// protocol messages.
204///
205/// This stream supports multiple shared capabilities, that were negotiated during the handshake.
206///
207/// ### Message-ID based multiplexing
208///
209/// > Each capability is given as much of the message-ID space as it needs. All such capabilities
210/// > must statically specify how many message IDs they require. On connection and reception of the
211/// > Hello message, both peers have equivalent information about what capabilities they share
212/// > (including versions) and are able to form consensus over the composition of message ID space.
213///
214/// > Message IDs are assumed to be compact from ID 0x10 onwards (0x00-0x0f is reserved for the
215/// > "p2p" capability) and given to each shared (equal-version, equal-name) capability in
216/// > alphabetic order. Capability names are case-sensitive. Capabilities which are not shared are
217/// > ignored. If multiple versions are shared of the same (equal name) capability, the numerically
218/// > highest wins, others are ignored.
219///
220/// See also <https://github.com/ethereum/devp2p/blob/master/rlpx.md#message-id-based-multiplexing>
221///
222/// This stream emits _non-empty_ Bytes that start with the normalized message id, so that the first
223/// byte of each message starts from 0. If this stream only supports a single capability, for
224/// example `eth` then the first byte of each message will match
225/// [EthMessageID](reth_eth_wire_types::message::EthMessageID).
226#[pin_project]
227#[derive(Debug)]
228pub struct P2PStream<S> {
229    #[pin]
230    inner: S,
231
232    /// The snappy encoder used for compressing outgoing messages
233    encoder: snap::raw::Encoder,
234
235    /// The snappy decoder used for decompressing incoming messages
236    decoder: snap::raw::Decoder,
237
238    /// The state machine used for keeping track of the peer's ping status.
239    pinger: Pinger,
240
241    /// The supported capability for this stream.
242    shared_capabilities: SharedCapabilities,
243
244    /// Outgoing messages buffered for sending to the underlying stream.
245    outgoing_messages: VecDeque<Bytes>,
246
247    /// Maximum number of messages that we can buffer here before the [Sink] impl returns
248    /// [`Poll::Pending`].
249    outgoing_message_buffer_capacity: usize,
250
251    /// Whether this stream is currently in the process of disconnecting by sending a disconnect
252    /// message.
253    disconnecting: bool,
254}
255
256impl<S> P2PStream<S> {
257    /// Create a new [`P2PStream`] from the provided stream.
258    /// New [`P2PStream`]s are assumed to have completed the `p2p` handshake successfully and are
259    /// ready to send and receive subprotocol messages.
260    pub fn new(inner: S, shared_capabilities: SharedCapabilities) -> Self {
261        Self {
262            inner,
263            encoder: snap::raw::Encoder::new(),
264            decoder: snap::raw::Decoder::new(),
265            pinger: Pinger::new(PING_INTERVAL, PING_TIMEOUT),
266            shared_capabilities,
267            outgoing_messages: VecDeque::new(),
268            outgoing_message_buffer_capacity: MAX_P2P_CAPACITY,
269            disconnecting: false,
270        }
271    }
272
273    /// Returns a reference to the inner stream.
274    pub const fn inner(&self) -> &S {
275        &self.inner
276    }
277
278    /// Sets a custom outgoing message buffer capacity.
279    ///
280    /// # Panics
281    ///
282    /// If the provided capacity is `0`.
283    pub fn set_outgoing_message_buffer_capacity(&mut self, capacity: usize) {
284        self.outgoing_message_buffer_capacity = capacity;
285    }
286
287    /// Returns the shared capabilities for this stream.
288    ///
289    /// This includes all the shared capabilities that were negotiated during the handshake and
290    /// their offsets based on the number of messages of each capability.
291    pub const fn shared_capabilities(&self) -> &SharedCapabilities {
292        &self.shared_capabilities
293    }
294
295    /// Returns `true` if the stream has outgoing capacity.
296    fn has_outgoing_capacity(&self) -> bool {
297        self.outgoing_messages.len() < self.outgoing_message_buffer_capacity
298    }
299
300    /// Queues in a _snappy_ encoded [`P2PMessage::Pong`] message.
301    fn send_pong(&mut self) {
302        self.outgoing_messages.push_back(Bytes::from(alloy_rlp::encode(P2PMessage::Pong)));
303    }
304
305    /// Queues in a _snappy_ encoded [`P2PMessage::Ping`] message.
306    pub fn send_ping(&mut self) {
307        self.outgoing_messages.push_back(Bytes::from(alloy_rlp::encode(P2PMessage::Ping)));
308    }
309}
310
311/// Gracefully disconnects the connection by sending a disconnect message and stop reading new
312/// messages.
313pub trait DisconnectP2P {
314    /// Starts to gracefully disconnect.
315    fn start_disconnect(&mut self, reason: DisconnectReason) -> Result<(), P2PStreamError>;
316
317    /// Returns `true` if the connection is about to disconnect.
318    fn is_disconnecting(&self) -> bool;
319}
320
321impl<S> DisconnectP2P for P2PStream<S> {
322    /// Starts to gracefully disconnect the connection by sending a Disconnect message and stop
323    /// reading new messages.
324    ///
325    /// Once disconnect process has started, the [`Stream`] will terminate immediately.
326    ///
327    /// # Errors
328    ///
329    /// Returns an error only if the message fails to compress.
330    fn start_disconnect(&mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> {
331        // clear any buffered messages and queue in
332        self.outgoing_messages.clear();
333        let disconnect = P2PMessage::Disconnect(reason);
334        let mut buf = Vec::with_capacity(disconnect.length());
335        disconnect.encode(&mut buf);
336
337        let mut compressed = vec![0u8; 1 + snap::raw::max_compress_len(buf.len() - 1)];
338        let compressed_size =
339            self.encoder.compress(&buf[1..], &mut compressed[1..]).map_err(|err| {
340                debug!(
341                    %err,
342                    msg=%hex::encode(&buf[1..]),
343                    "error compressing disconnect"
344                );
345                err
346            })?;
347
348        // truncate the compressed buffer to the actual compressed size (plus one for the message
349        // id)
350        compressed.truncate(compressed_size + 1);
351
352        // we do not add the capability offset because the disconnect message is a `p2p` reserved
353        // message
354        compressed[0] = buf[0];
355
356        self.outgoing_messages.push_back(compressed.into());
357        self.disconnecting = true;
358        Ok(())
359    }
360
361    fn is_disconnecting(&self) -> bool {
362        self.disconnecting
363    }
364}
365
366impl<S> P2PStream<S>
367where
368    S: Sink<Bytes, Error = io::Error> + Unpin + Send,
369{
370    /// Disconnects the connection by sending a disconnect message.
371    ///
372    /// This future resolves once the disconnect message has been sent and the stream has been
373    /// closed.
374    pub async fn disconnect(&mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> {
375        self.start_disconnect(reason)?;
376        self.close().await
377    }
378}
379
380// S must also be `Sink` because we need to be able to respond with ping messages to follow the
381// protocol
382impl<S> Stream for P2PStream<S>
383where
384    S: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
385{
386    type Item = Result<BytesMut, P2PStreamError>;
387
388    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
389        let this = self.get_mut();
390
391        if this.disconnecting {
392            // if disconnecting, stop reading messages
393            return Poll::Ready(None)
394        }
395
396        // we should loop here to ensure we don't return Poll::Pending if we have a message to
397        // return behind any pings we need to respond to
398        while let Poll::Ready(res) = this.inner.poll_next_unpin(cx) {
399            let bytes = match res {
400                Some(Ok(bytes)) => bytes,
401                Some(Err(err)) => return Poll::Ready(Some(Err(err.into()))),
402                None => return Poll::Ready(None),
403            };
404
405            if bytes.is_empty() {
406                // empty messages are not allowed
407                return Poll::Ready(Some(Err(P2PStreamError::EmptyProtocolMessage)))
408            }
409
410            // first decode disconnect reasons, because they can be encoded in a variety of forms
411            // over the wire, in both snappy compressed and uncompressed forms.
412            //
413            // see: [crate::disconnect::tests::test_decode_known_reasons]
414            let id = bytes[0];
415            if id == P2PMessageID::Disconnect as u8 {
416                // We can't handle the error here because disconnect reasons are encoded as both:
417                // * snappy compressed, AND
418                // * uncompressed
419                // over the network.
420                //
421                // If the decoding succeeds, we already checked the id and know this is a
422                // disconnect message, so we can return with the reason.
423                //
424                // If the decoding fails, we continue, and will attempt to decode it again if the
425                // message is snappy compressed. Failure handling in that step is the primary point
426                // where an error is returned if the disconnect reason is malformed.
427                if let Ok(reason) = DisconnectReason::decode(&mut &bytes[1..]) {
428                    return Poll::Ready(Some(Err(P2PStreamError::Disconnected(reason))))
429                }
430            }
431
432            // first check that the compressed message length does not exceed the max
433            // payload size
434            let decompressed_len = snap::raw::decompress_len(&bytes[1..])?;
435            if decompressed_len > MAX_PAYLOAD_SIZE {
436                return Poll::Ready(Some(Err(P2PStreamError::MessageTooBig {
437                    message_size: decompressed_len,
438                    max_size: MAX_PAYLOAD_SIZE,
439                })))
440            }
441
442            // create a buffer to hold the decompressed message, adding a byte to the length for
443            // the message ID byte, which is the first byte in this buffer
444            let mut decompress_buf = BytesMut::zeroed(decompressed_len + 1);
445
446            // each message following a successful handshake is compressed with snappy, so we need
447            // to decompress the message before we can decode it.
448            this.decoder.decompress(&bytes[1..], &mut decompress_buf[1..]).map_err(|err| {
449                debug!(
450                    %err,
451                    msg=%hex::encode(&bytes[1..]),
452                    "error decompressing p2p message"
453                );
454                err
455            })?;
456
457            match id {
458                _ if id == P2PMessageID::Ping as u8 => {
459                    trace!("Received Ping, Sending Pong");
460                    this.send_pong();
461                    // This is required because the `Sink` may not be polled externally, and if
462                    // that happens, the pong will never be sent.
463                    cx.waker().wake_by_ref();
464                }
465                _ if id == P2PMessageID::Hello as u8 => {
466                    // we have received a hello message outside of the handshake, so we will return
467                    // an error
468                    return Poll::Ready(Some(Err(P2PStreamError::HandshakeError(
469                        P2PHandshakeError::HelloNotInHandshake,
470                    ))))
471                }
472                _ if id == P2PMessageID::Pong as u8 => {
473                    // if we were waiting for a pong, this will reset the pinger state
474                    this.pinger.on_pong()?
475                }
476                _ if id == P2PMessageID::Disconnect as u8 => {
477                    // At this point, the `decompress_buf` contains the snappy decompressed
478                    // disconnect message.
479                    //
480                    // It's possible we already tried to RLP decode this, but it was snappy
481                    // compressed, so we need to RLP decode it again.
482                    let reason = DisconnectReason::decode(&mut &decompress_buf[1..]).inspect_err(|err| {
483                        debug!(
484                            %err, msg=%hex::encode(&decompress_buf[1..]), "Failed to decode disconnect message from peer"
485                        );
486                    })?;
487                    return Poll::Ready(Some(Err(P2PStreamError::Disconnected(reason))))
488                }
489                _ if id > MAX_P2P_MESSAGE_ID && id <= MAX_RESERVED_MESSAGE_ID => {
490                    // we have received an unknown reserved message
491                    return Poll::Ready(Some(Err(P2PStreamError::UnknownReservedMessageId(id))))
492                }
493                _ => {
494                    // we have received a message that is outside the `p2p` reserved message space,
495                    // so it is a subprotocol message.
496
497                    // Peers must be able to identify messages meant for different subprotocols
498                    // using a single message ID byte, and those messages must be distinct from the
499                    // lower-level `p2p` messages.
500                    //
501                    // To ensure that messages for subprotocols are distinct from messages meant
502                    // for the `p2p` capability, message IDs 0x00 - 0x0f are reserved for `p2p`
503                    // messages, so subprotocol messages must have an ID of 0x10 or higher.
504                    //
505                    // To ensure that messages for two different capabilities are distinct from
506                    // each other, all shared capabilities are first ordered lexicographically.
507                    // Message IDs are then reserved in this order, starting at 0x10, reserving a
508                    // message ID for each message the capability supports.
509                    //
510                    // For example, if the shared capabilities are `eth/67` (containing 10
511                    // messages), and "qrs/65" (containing 8 messages):
512                    //
513                    //  * The special case of `p2p`: `p2p` is reserved message IDs 0x00 - 0x0f.
514                    //  * `eth/67` is reserved message IDs 0x10 - 0x19.
515                    //  * `qrs/65` is reserved message IDs 0x1a - 0x21.
516                    //
517                    decompress_buf[0] = bytes[0] - MAX_RESERVED_MESSAGE_ID - 1;
518
519                    return Poll::Ready(Some(Ok(decompress_buf)))
520                }
521            }
522        }
523
524        Poll::Pending
525    }
526}
527
528impl<S> Sink<Bytes> for P2PStream<S>
529where
530    S: Sink<Bytes, Error = io::Error> + Unpin,
531{
532    type Error = P2PStreamError;
533
534    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
535        let mut this = self.as_mut();
536
537        // poll the pinger to determine if we should send a ping
538        match this.pinger.poll_ping(cx) {
539            Poll::Pending => {}
540            Poll::Ready(Ok(PingerEvent::Ping)) => {
541                this.send_ping();
542            }
543            _ => {
544                // encode the disconnect message
545                this.start_disconnect(DisconnectReason::PingTimeout)?;
546
547                // End the stream after ping related error
548                return Poll::Ready(Ok(()))
549            }
550        }
551
552        match this.inner.poll_ready_unpin(cx) {
553            Poll::Pending => {}
554            Poll::Ready(Err(err)) => return Poll::Ready(Err(P2PStreamError::Io(err))),
555            Poll::Ready(Ok(())) => {
556                let flushed = this.poll_flush(cx);
557                if flushed.is_ready() {
558                    return flushed
559                }
560            }
561        }
562
563        if self.has_outgoing_capacity() {
564            // still has capacity
565            Poll::Ready(Ok(()))
566        } else {
567            Poll::Pending
568        }
569    }
570
571    fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
572        if item.len() > MAX_PAYLOAD_SIZE {
573            return Err(P2PStreamError::MessageTooBig {
574                message_size: item.len(),
575                max_size: MAX_PAYLOAD_SIZE,
576            })
577        }
578
579        if item.is_empty() {
580            // empty messages are not allowed
581            return Err(P2PStreamError::EmptyProtocolMessage)
582        }
583
584        // ensure we have free capacity
585        if !self.has_outgoing_capacity() {
586            return Err(P2PStreamError::SendBufferFull)
587        }
588
589        let this = self.project();
590
591        let mut compressed = BytesMut::zeroed(1 + snap::raw::max_compress_len(item.len() - 1));
592        let compressed_size =
593            this.encoder.compress(&item[1..], &mut compressed[1..]).map_err(|err| {
594                debug!(
595                    %err,
596                    msg=%hex::encode(&item[1..]),
597                    "error compressing p2p message"
598                );
599                err
600            })?;
601
602        // truncate the compressed buffer to the actual compressed size (plus one for the message
603        // id)
604        compressed.truncate(compressed_size + 1);
605
606        // all messages sent in this stream are subprotocol messages, so we need to switch the
607        // message id based on the offset
608        compressed[0] = item[0] + MAX_RESERVED_MESSAGE_ID + 1;
609        this.outgoing_messages.push_back(compressed.freeze());
610
611        Ok(())
612    }
613
614    /// Returns `Poll::Ready(Ok(()))` when no buffered items remain.
615    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
616        let mut this = self.project();
617        let poll_res = loop {
618            match this.inner.as_mut().poll_ready(cx) {
619                Poll::Pending => break Poll::Pending,
620                Poll::Ready(Err(err)) => break Poll::Ready(Err(err.into())),
621                Poll::Ready(Ok(())) => {
622                    let Some(message) = this.outgoing_messages.pop_front() else {
623                        break Poll::Ready(Ok(()))
624                    };
625                    if let Err(err) = this.inner.as_mut().start_send(message) {
626                        break Poll::Ready(Err(err.into()))
627                    }
628                }
629            }
630        };
631
632        ready!(this.inner.as_mut().poll_flush(cx))?;
633
634        poll_res
635    }
636
637    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
638        ready!(self.as_mut().poll_flush(cx))?;
639        ready!(self.project().inner.poll_close(cx))?;
640
641        Poll::Ready(Ok(()))
642    }
643}
644
645/// This represents only the reserved `p2p` subprotocol messages.
646#[derive(Debug, Clone, PartialEq, Eq)]
647#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
648#[cfg_attr(any(test, feature = "arbitrary"), derive(arbitrary::Arbitrary))]
649#[add_arbitrary_tests(rlp)]
650pub enum P2PMessage {
651    /// The first packet sent over the connection, and sent once by both sides.
652    Hello(HelloMessage),
653
654    /// Inform the peer that a disconnection is imminent; if received, a peer should disconnect
655    /// immediately.
656    Disconnect(DisconnectReason),
657
658    /// Requests an immediate reply of [`P2PMessage::Pong`] from the peer.
659    Ping,
660
661    /// Reply to the peer's [`P2PMessage::Ping`] packet.
662    Pong,
663}
664
665impl P2PMessage {
666    /// Gets the [`P2PMessageID`] for the given message.
667    pub const fn message_id(&self) -> P2PMessageID {
668        match self {
669            Self::Hello(_) => P2PMessageID::Hello,
670            Self::Disconnect(_) => P2PMessageID::Disconnect,
671            Self::Ping => P2PMessageID::Ping,
672            Self::Pong => P2PMessageID::Pong,
673        }
674    }
675}
676
677impl Encodable for P2PMessage {
678    /// The [`Encodable`] implementation for [`P2PMessage::Ping`] and [`P2PMessage::Pong`] encodes
679    /// the message as RLP, and prepends a snappy header to the RLP bytes for all variants except
680    /// the [`P2PMessage::Hello`] variant, because the hello message is never compressed in the
681    /// `p2p` subprotocol.
682    fn encode(&self, out: &mut dyn BufMut) {
683        (self.message_id() as u8).encode(out);
684        match self {
685            Self::Hello(msg) => msg.encode(out),
686            Self::Disconnect(msg) => msg.encode(out),
687            Self::Ping => {
688                // Ping payload is _always_ snappy encoded
689                out.put_u8(0x01);
690                out.put_u8(0x00);
691                out.put_u8(EMPTY_LIST_CODE);
692            }
693            Self::Pong => {
694                // Pong payload is _always_ snappy encoded
695                out.put_u8(0x01);
696                out.put_u8(0x00);
697                out.put_u8(EMPTY_LIST_CODE);
698            }
699        }
700    }
701
702    fn length(&self) -> usize {
703        let payload_len = match self {
704            Self::Hello(msg) => msg.length(),
705            Self::Disconnect(msg) => msg.length(),
706            // id + snappy encoded payload
707            Self::Ping | Self::Pong => 3, // len([0x01, 0x00, 0xc0]) = 3
708        };
709        payload_len + 1 // (1 for length of p2p message id)
710    }
711}
712
713impl Decodable for P2PMessage {
714    /// The [`Decodable`] implementation for [`P2PMessage`] assumes that each of the message
715    /// variants are snappy compressed, except for the [`P2PMessage::Hello`] variant since the
716    /// hello message is never compressed in the `p2p` subprotocol.
717    ///
718    /// The [`Decodable`] implementation for [`P2PMessage::Ping`] and [`P2PMessage::Pong`] expects
719    /// a snappy encoded payload, see [`Encodable`] implementation.
720    fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
721        /// Removes the snappy prefix from the Ping/Pong buffer
722        fn advance_snappy_ping_pong_payload(buf: &mut &[u8]) -> alloy_rlp::Result<()> {
723            if buf.len() < 3 {
724                return Err(RlpError::InputTooShort)
725            }
726            if buf[..3] != [0x01, 0x00, EMPTY_LIST_CODE] {
727                return Err(RlpError::Custom("expected snappy payload"))
728            }
729            buf.advance(3);
730            Ok(())
731        }
732
733        let message_id = u8::decode(&mut &buf[..])?;
734        let id = P2PMessageID::try_from(message_id)
735            .or(Err(RlpError::Custom("unknown p2p message id")))?;
736        buf.advance(1);
737        match id {
738            P2PMessageID::Hello => Ok(Self::Hello(HelloMessage::decode(buf)?)),
739            P2PMessageID::Disconnect => Ok(Self::Disconnect(DisconnectReason::decode(buf)?)),
740            P2PMessageID::Ping => {
741                advance_snappy_ping_pong_payload(buf)?;
742                Ok(Self::Ping)
743            }
744            P2PMessageID::Pong => {
745                advance_snappy_ping_pong_payload(buf)?;
746                Ok(Self::Pong)
747            }
748        }
749    }
750}
751
/// Identifiers of the messages reserved by the base `p2p` subprotocol.
///
/// Each variant's discriminant is the numeric message id that [`TryFrom<u8>`] maps back to it.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum P2PMessageID {
    /// Id `0x00`, carried by [`P2PMessage::Hello`].
    Hello = 0x00,
    /// Id `0x01`, carried by [`P2PMessage::Disconnect`].
    Disconnect = 0x01,
    /// Id `0x02`, carried by [`P2PMessage::Ping`].
    Ping = 0x02,
    /// Id `0x03`, carried by [`P2PMessage::Pong`].
    Pong = 0x03,
}
767
768impl From<P2PMessage> for P2PMessageID {
769    fn from(msg: P2PMessage) -> Self {
770        match msg {
771            P2PMessage::Hello(_) => Self::Hello,
772            P2PMessage::Disconnect(_) => Self::Disconnect,
773            P2PMessage::Ping => Self::Ping,
774            P2PMessage::Pong => Self::Pong,
775        }
776    }
777}
778
779impl TryFrom<u8> for P2PMessageID {
780    type Error = P2PStreamError;
781
782    fn try_from(id: u8) -> Result<Self, Self::Error> {
783        match id {
784            0x00 => Ok(Self::Hello),
785            0x01 => Ok(Self::Disconnect),
786            0x02 => Ok(Self::Ping),
787            0x03 => Ok(Self::Pong),
788            _ => Err(P2PStreamError::UnknownReservedMessageId(id)),
789        }
790    }
791}
792
#[cfg(test)]
mod tests {
    //! Tests for the `p2p` stream. They cover the `Hello` handshake (success and protocol
    //! version mismatch), disconnect handling (including a disconnect queued without the
    //! stream's own `disconnect` path), and the snappy-framed ping/pong wire format.

    use super::*;
    use crate::{capability::SharedCapability, test_utils::eth_hello, EthVersion, ProtocolVersion};
    use tokio::net::{TcpListener, TcpStream};
    // brings `Decoder::framed` into scope for wrapping TCP streams in the passthrough codec
    use tokio_util::codec::Decoder;

    /// After a successful handshake, the server calls `disconnect` with
    /// [`DisconnectReason::UselessPeer`]. The client's next stream item must be a
    /// [`P2PStreamError::Disconnected`] error carrying that same reason.
    #[tokio::test]
    async fn test_can_disconnect() {
        reth_tracing::init_test_tracing();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let expected_disconnect = DisconnectReason::UselessPeer;

        let handle = tokio::spawn(async move {
            // server side: accept one connection, complete the handshake, then disconnect
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = crate::PassthroughCodec::default().framed(incoming);

            let (server_hello, _) = eth_hello();

            let (mut p2p_stream, _) =
                UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();

            p2p_stream.disconnect(expected_disconnect).await.unwrap();
        });

        // client side: connect, complete the handshake, then expect the disconnect
        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = crate::PassthroughCodec::default().framed(outgoing);

        let (client_hello, _) = eth_hello();

        let (mut p2p_stream, _) =
            UnauthedP2PStream::new(sink).handshake(client_hello).await.unwrap();

        let err = p2p_stream.next().await.unwrap().unwrap_err();
        match err {
            P2PStreamError::Disconnected(reason) => assert_eq!(reason, expected_disconnect),
            e => panic!("unexpected err: {e}"),
        }

        // propagate any panic from the server task
        handle.await.unwrap();
    }

    /// Like `test_can_disconnect`, but the server does not go through `disconnect`. Instead it
    /// queues a plainly RLP-encoded `Disconnect(SubprotocolSpecific)` message directly. The
    /// client must still report [`P2PStreamError::Disconnected`] with that reason.
    #[tokio::test]
    async fn test_can_disconnect_weird_disconnect_encoding() {
        reth_tracing::init_test_tracing();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let expected_disconnect = DisconnectReason::SubprotocolSpecific;

        let handle = tokio::spawn(async move {
            // server side: accept one connection and complete the handshake
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = crate::PassthroughCodec::default().framed(incoming);

            let (server_hello, _) = eth_hello();

            let (mut p2p_stream, _) =
                UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();

            // Unrolled `disconnect` method, without compression. The steps are:
            // drop anything already queued, queue an RLP-encoded `Disconnect` as-is,
            // mark the stream as disconnecting, then close it.
            p2p_stream.outgoing_messages.clear();

            p2p_stream.outgoing_messages.push_back(Bytes::from(alloy_rlp::encode(
                P2PMessage::Disconnect(DisconnectReason::SubprotocolSpecific),
            )));
            p2p_stream.disconnecting = true;
            p2p_stream.close().await.unwrap();
        });

        // client side: connect, complete the handshake, then expect the disconnect
        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = crate::PassthroughCodec::default().framed(outgoing);

        let (client_hello, _) = eth_hello();

        let (mut p2p_stream, _) =
            UnauthedP2PStream::new(sink).handshake(client_hello).await.unwrap();

        let err = p2p_stream.next().await.unwrap().unwrap_err();
        match err {
            P2PStreamError::Disconnected(reason) => assert_eq!(reason, expected_disconnect),
            e => panic!("unexpected err: {e}"),
        }

        // propagate any panic from the server task
        handle.await.unwrap();
    }

    /// Two peers using `eth_hello()` complete the handshake. Both sides must agree that the
    /// first shared capability is `eth/67`, placed at the first message id after the
    /// reserved `p2p` range.
    #[tokio::test]
    async fn test_handshake_passthrough() {
        // create a p2p stream and server, then confirm that the two are authed
        // create tcpstream
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let handle = tokio::spawn(async move {
            // server side: accept one connection and complete the handshake
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = crate::PassthroughCodec::default().framed(incoming);

            let (server_hello, _) = eth_hello();

            let unauthed_stream = UnauthedP2PStream::new(stream);
            let (p2p_stream, _) = unauthed_stream.handshake(server_hello).await.unwrap();

            // ensure the first shared capability is eth67, offset just past the reserved
            // `p2p` message ids
            assert_eq!(
                *p2p_stream.shared_capabilities.iter_caps().next().unwrap(),
                SharedCapability::Eth {
                    version: EthVersion::Eth67,
                    offset: MAX_RESERVED_MESSAGE_ID + 1
                }
            );
        });

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = crate::PassthroughCodec::default().framed(outgoing);

        let (client_hello, _) = eth_hello();

        let unauthed_stream = UnauthedP2PStream::new(sink);
        let (p2p_stream, _) = unauthed_stream.handshake(client_hello).await.unwrap();

        // ensure the first shared capability is eth67, offset just past the reserved `p2p`
        // message ids
        assert_eq!(
            *p2p_stream.shared_capabilities.iter_caps().next().unwrap(),
            SharedCapability::Eth {
                version: EthVersion::Eth67,
                offset: MAX_RESERVED_MESSAGE_ID + 1
            }
        );

        // make sure the server receives the message and asserts before ending the test
        handle.await.unwrap();
    }

    /// The client advertises `ProtocolVersion::V4`, while the server uses the unmodified
    /// `eth_hello()`. The handshake must fail on both sides with `MismatchedProtocolVersion`.
    /// Each side's `expected` is its own hello's version, and that must differ from `got`.
    #[tokio::test]
    async fn test_handshake_disconnect() {
        // create a p2p stream and server, then confirm that the handshake fails on both sides
        // create tcpstream
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let handle = tokio::spawn(Box::pin(async move {
            // server side: accept one connection and attempt the handshake
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = crate::PassthroughCodec::default().framed(incoming);

            let (server_hello, _) = eth_hello();

            let unauthed_stream = UnauthedP2PStream::new(stream);
            match unauthed_stream.handshake(server_hello.clone()).await {
                Ok((_, hello)) => {
                    panic!("expected handshake to fail, instead got a successful Hello: {hello:?}")
                }
                Err(P2PStreamError::MismatchedProtocolVersion(GotExpected { got, expected })) => {
                    assert_ne!(expected, got);
                    assert_eq!(expected, server_hello.protocol_version);
                }
                Err(other_err) => {
                    panic!("expected mismatched protocol version error, got {other_err:?}")
                }
            }
        }));

        let outgoing = TcpStream::connect(local_addr).await.unwrap();
        let sink = crate::PassthroughCodec::default().framed(outgoing);

        let (mut client_hello, _) = eth_hello();

        // modify the hello to include an incompatible p2p protocol version
        client_hello.protocol_version = ProtocolVersion::V4;

        let unauthed_stream = UnauthedP2PStream::new(sink);
        match unauthed_stream.handshake(client_hello.clone()).await {
            Ok((_, hello)) => {
                panic!("expected handshake to fail, instead got a successful Hello: {hello:?}")
            }
            Err(P2PStreamError::MismatchedProtocolVersion(GotExpected { got, expected })) => {
                assert_ne!(expected, got);
                assert_eq!(expected, client_hello.protocol_version);
            }
            Err(other_err) => {
                panic!("expected mismatched protocol version error, got {other_err:?}")
            }
        }

        // make sure the server receives the message and asserts before ending the test
        handle.await.unwrap();
    }

    /// A ping on the wire is the id byte `0x02` followed by the snappy-framed empty list
    /// `[0x01, 0x00, 0xc0]`. Decoding and re-encoding must round-trip byte-for-byte.
    #[test]
    fn snappy_decode_encode_ping() {
        let snappy_ping = b"\x02\x01\0\xc0";
        let ping = P2PMessage::decode(&mut &snappy_ping[..]).unwrap();
        assert!(matches!(ping, P2PMessage::Ping));
        assert_eq!(alloy_rlp::encode(ping), &snappy_ping[..]);
    }

    /// A pong on the wire is the id byte `0x03` followed by the snappy-framed empty list
    /// `[0x01, 0x00, 0xc0]`. Decoding and re-encoding must round-trip byte-for-byte.
    #[test]
    fn snappy_decode_encode_pong() {
        let snappy_pong = b"\x03\x01\0\xc0";
        let pong = P2PMessage::decode(&mut &snappy_pong[..]).unwrap();
        assert!(matches!(pong, P2PMessage::Pong));
        assert_eq!(alloy_rlp::encode(pong), &snappy_pong[..]);
    }
}