reth_eth_wire/
disconnect.rs

1//! Disconnect
2
3use std::future::Future;
4
5use futures::{Sink, SinkExt};
6use reth_ecies::stream::ECIESStream;
7use reth_eth_wire_types::DisconnectReason;
8use tokio::io::AsyncWrite;
9use tokio_util::codec::{Encoder, Framed};
10
11/// This trait is meant to allow higher level protocols like `eth` to disconnect from a peer, using
12/// lower-level disconnect functions (such as those that exist in the `p2p` protocol) if the
13/// underlying stream supports it.
14pub trait CanDisconnect<T>: Sink<T> + Unpin {
15    /// Disconnects from the underlying stream, using a [`DisconnectReason`] as disconnect
16    /// information if the stream implements a protocol that can carry the additional disconnect
17    /// metadata.
18    fn disconnect(
19        &mut self,
20        reason: DisconnectReason,
21    ) -> impl Future<Output = Result<(), <Self as Sink<T>>::Error>> + Send;
22}
23
24// basic impls for things like Framed<TcpStream, etc>
25impl<T, I, U> CanDisconnect<I> for Framed<T, U>
26where
27    T: AsyncWrite + Unpin + Send,
28    U: Encoder<I> + Send,
29{
30    async fn disconnect(
31        &mut self,
32        _reason: DisconnectReason,
33    ) -> Result<(), <Self as Sink<I>>::Error> {
34        self.close().await
35    }
36}
37
38impl<S> CanDisconnect<bytes::Bytes> for ECIESStream<S>
39where
40    S: AsyncWrite + Unpin + Send,
41{
42    async fn disconnect(&mut self, _reason: DisconnectReason) -> Result<(), std::io::Error> {
43        self.close().await
44    }
45}
46
#[cfg(test)]
mod tests {
    use crate::{p2pstream::P2PMessage, DisconnectReason};
    use alloy_primitives::hex;
    use alloy_rlp::{Decodable, Encodable};

    /// Returns the disconnect reasons used as fixtures by the encoding tests in this module.
    fn all_reasons() -> Vec<DisconnectReason> {
        Vec::from([
            DisconnectReason::DisconnectRequested,
            DisconnectReason::TcpSubsystemError,
            DisconnectReason::ProtocolBreach,
            DisconnectReason::UselessPeer,
            DisconnectReason::TooManyPeers,
            DisconnectReason::AlreadyConnected,
            DisconnectReason::IncompatibleP2PProtocolVersion,
            DisconnectReason::NullNodeIdentity,
            DisconnectReason::ClientQuitting,
            DisconnectReason::UnexpectedHandshakeIdentity,
            DisconnectReason::ConnectedToSelf,
            DisconnectReason::PingTimeout,
            DisconnectReason::SubprotocolSpecific,
        ])
    }

    /// Encoding a disconnect message and decoding the bytes again yields the same message.
    #[test]
    fn disconnect_round_trip() {
        for reason in all_reasons() {
            let original = P2PMessage::Disconnect(reason);

            let mut buf = Vec::<u8>::new();
            original.encode(&mut buf);

            let mut cursor = buf.as_slice();
            let decoded = P2PMessage::decode(&mut cursor).unwrap();

            assert_eq!(original, decoded);
        }
    }

    /// Decoding a reason from an empty buffer is an error.
    #[test]
    fn test_reason_too_short() {
        let mut input: &[u8] = &[];
        assert!(DisconnectReason::decode(&mut input).is_err());
    }

    /// Decoding a reason from three zero bytes is an error.
    #[test]
    fn test_reason_too_long() {
        let mut input: &[u8] = &[0u8; 3];
        assert!(DisconnectReason::decode(&mut input).is_err());
    }

    /// Decoding the bytes `c0 00` is rejected with a list-length error.
    #[test]
    fn test_reason_zero_length_list() {
        let encoded = hex::decode("c000").unwrap();
        let mut cursor = encoded.as_slice();

        let outcome = DisconnectReason::decode(&mut cursor);
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert_eq!(message, "unexpected list length (got 0, expected 1)");
    }

    /// The value reported by `length()` matches the number of bytes actually written by
    /// `encode()`.
    #[test]
    fn disconnect_encoding_length() {
        for reason in all_reasons() {
            let message = P2PMessage::Disconnect(reason);

            let mut buf = Vec::<u8>::new();
            message.encode(&mut buf);

            assert_eq!(buf.len(), message.length());
        }
    }

    /// Each of the known encodings decodes to a disconnect message.
    #[test]
    fn test_decode_known_reasons() {
        let encoded_messages = [
            // encoding the disconnect reason as a single byte
            "0100", // 0x00 case
            "0180", // second 0x00 case
            "0101", "0102", "0103", "0104", "0105", "0106", "0107", "0108", "0109", "010a", "010b",
            "0110",
            // encoding the disconnect reason in a list
            "01c100", // 0x00 case
            "01c180", // second 0x00 case
            "01c101", "01c102", "01c103", "01c104", "01c105", "01c106", "01c107", "01c108",
            "01c109", "01c10a", "01c10b", "01c110",
        ];

        for encoded in encoded_messages {
            let bytes = hex::decode(encoded).unwrap();
            let mut cursor = bytes.as_slice();
            let message = P2PMessage::decode(&mut cursor).unwrap();
            assert!(matches!(message, P2PMessage::Disconnect(_)), "expected a disconnect message");
        }
    }

    /// The bytes `01 00` decode to a disconnect message carrying `DisconnectRequested`.
    #[test]
    fn test_decode_disconnect_requested() {
        let bytes = hex::decode("0100").unwrap();
        let mut cursor = bytes.as_slice();
        let message = P2PMessage::decode(&mut cursor).unwrap();

        if !matches!(message, P2PMessage::Disconnect(DisconnectReason::DisconnectRequested)) {
            unreachable!()
        }
    }
}
153}