reth_eth_wire/
disconnect.rs1use std::future::Future;
4
5use futures::{Sink, SinkExt};
6use reth_ecies::stream::ECIESStream;
7use reth_eth_wire_types::DisconnectReason;
8use tokio::io::AsyncWrite;
9use tokio_util::codec::{Encoder, Framed};
10
11pub trait CanDisconnect<T>: Sink<T> + Unpin {
15 fn disconnect(
19 &mut self,
20 reason: DisconnectReason,
21 ) -> impl Future<Output = Result<(), <Self as Sink<T>>::Error>> + Send;
22}
23
24impl<T, I, U> CanDisconnect<I> for Framed<T, U>
26where
27 T: AsyncWrite + Unpin + Send,
28 U: Encoder<I> + Send,
29{
30 async fn disconnect(
31 &mut self,
32 _reason: DisconnectReason,
33 ) -> Result<(), <Self as Sink<I>>::Error> {
34 self.close().await
35 }
36}
37
38impl<S> CanDisconnect<bytes::Bytes> for ECIESStream<S>
39where
40 S: AsyncWrite + Unpin + Send,
41{
42 async fn disconnect(&mut self, _reason: DisconnectReason) -> Result<(), std::io::Error> {
43 self.close().await
44 }
45}
46
#[cfg(test)]
mod tests {
    use crate::{p2pstream::P2PMessage, DisconnectReason};
    use alloy_primitives::hex;
    use alloy_rlp::{Decodable, Encodable};

    /// Every [`DisconnectReason`] variant, in wire-code order (`0x00..=0x0b`, then `0x10`).
    fn all_reasons() -> Vec<DisconnectReason> {
        vec![
            DisconnectReason::DisconnectRequested,
            DisconnectReason::TcpSubsystemError,
            DisconnectReason::ProtocolBreach,
            DisconnectReason::UselessPeer,
            DisconnectReason::TooManyPeers,
            DisconnectReason::AlreadyConnected,
            DisconnectReason::IncompatibleP2PProtocolVersion,
            DisconnectReason::NullNodeIdentity,
            DisconnectReason::ClientQuitting,
            DisconnectReason::UnexpectedHandshakeIdentity,
            DisconnectReason::ConnectedToSelf,
            DisconnectReason::PingTimeout,
            DisconnectReason::SubprotocolSpecific,
        ]
    }

    /// Hex-encoded reason codes, index-aligned with [`all_reasons`].
    const REASON_CODES: [&str; 13] =
        ["00", "01", "02", "03", "04", "05", "06", "07", "08", "09", "0a", "0b", "10"];

    #[test]
    fn disconnect_round_trip() {
        for reason in all_reasons() {
            let disconnect = P2PMessage::Disconnect(reason);

            let mut encoded = Vec::new();
            disconnect.encode(&mut encoded);

            let decoded = P2PMessage::decode(&mut &encoded[..]).unwrap();

            assert_eq!(disconnect, decoded);
        }
    }

    #[test]
    fn test_reason_too_short() {
        assert!(DisconnectReason::decode(&mut &[0u8; 0][..]).is_err())
    }

    #[test]
    fn test_reason_too_long() {
        assert!(DisconnectReason::decode(&mut &[0u8; 3][..]).is_err())
    }

    #[test]
    fn test_reason_zero_length_list() {
        let list_with_zero_length = hex::decode("c000").unwrap();
        let res = DisconnectReason::decode(&mut &list_with_zero_length[..]);
        assert!(res.is_err());
        assert_eq!(res.unwrap_err().to_string(), "unexpected list length (got 0, expected 1)")
    }

    #[test]
    fn disconnect_encoding_length() {
        for reason in all_reasons() {
            let disconnect = P2PMessage::Disconnect(reason);

            let mut encoded = Vec::new();
            disconnect.encode(&mut encoded);

            assert_eq!(encoded.len(), disconnect.length());
        }
    }

    /// Each known encoding must decode to exactly the expected reason, not merely to some
    /// `Disconnect` message.
    #[test]
    fn test_decode_known_reasons() {
        // `01` is the `Disconnect` message id. The reason byte follows either bare or wrapped in
        // a one-element RLP list (`c1`); both forms must decode to the same reason.
        for prefix in ["01", "01c1"] {
            let reasons = all_reasons();
            // Guard against the code table silently drifting out of sync (zip would truncate).
            assert_eq!(reasons.len(), REASON_CODES.len(), "REASON_CODES out of sync");

            for (code, expected) in REASON_CODES.iter().zip(reasons) {
                let input = format!("{prefix}{code}");
                let bytes = hex::decode(&input).unwrap();
                let message = P2PMessage::decode(&mut &bytes[..]).unwrap();
                assert_eq!(message, P2PMessage::Disconnect(expected), "input: {input}");
            }
        }

        // The RLP empty string `0x80` encodes the integer zero, i.e. `DisconnectRequested`.
        for input in ["0180", "01c180"] {
            let bytes = hex::decode(input).unwrap();
            let message = P2PMessage::decode(&mut &bytes[..]).unwrap();
            assert_eq!(
                message,
                P2PMessage::Disconnect(DisconnectReason::DisconnectRequested),
                "input: {input}"
            );
        }
    }

    #[test]
    fn test_decode_disconnect_requested() {
        let bytes = hex::decode("0100").unwrap();
        let message = P2PMessage::decode(&mut &bytes[..]).unwrap();
        assert_eq!(message, P2PMessage::Disconnect(DisconnectReason::DisconnectRequested));
    }
}