1use crate::{
4 codec::ECIESCodec, error::ECIESErrorImpl, ECIESError, EgressECIESValue, IngressECIESValue,
5};
6use alloy_primitives::{
7 bytes::{Bytes, BytesMut},
8 B512 as PeerId,
9};
10use futures::{ready, Sink, SinkExt};
11use secp256k1::SecretKey;
12use std::{
13 fmt::Debug,
14 io,
15 pin::Pin,
16 task::{Context, Poll},
17 time::Duration,
18};
19use tokio::{
20 io::{AsyncRead, AsyncWrite},
21 time::timeout,
22};
23use tokio_stream::{Stream, StreamExt};
24use tokio_util::codec::{Decoder, Framed};
25use tracing::{instrument, trace};
26
/// Default time limit for the outgoing handshake; this is the value that
/// `ECIESStream::connect` passes to `ECIESStream::connect_with_timeout`.
const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
28
/// A transport framed with [`ECIESCodec`], paired with the id of the remote peer.
///
/// Values of this type are produced by the handshake constructors (`connect*` for the
/// initiating side, `incoming` for the accepting side).
#[derive(Debug)]
#[pin_project::pin_project]
pub struct ECIESStream<Io> {
    /// The underlying IO wrapped in a [`Framed`] using [`ECIESCodec`]. Marked `#[pin]` so the
    /// `Stream`/`Sink` impls can poll it through a pinned projection.
    #[pin]
    stream: Framed<Io, ECIESCodec>,
    /// Id of the remote peer: supplied by the caller when connecting, or taken from the
    /// received auth message when accepting.
    remote_id: PeerId,
}
37
38impl<Io> ECIESStream<Io>
39where
40 Io: AsyncRead + AsyncWrite + Unpin,
41{
42 #[instrument(skip(transport, secret_key))]
44 pub async fn connect(
45 transport: Io,
46 secret_key: SecretKey,
47 remote_id: PeerId,
48 ) -> Result<Self, ECIESError> {
49 Self::connect_with_timeout(transport, secret_key, remote_id, HANDSHAKE_TIMEOUT).await
50 }
51
52 pub async fn connect_with_timeout(
54 transport: Io,
55 secret_key: SecretKey,
56 remote_id: PeerId,
57 timeout_limit: Duration,
58 ) -> Result<Self, ECIESError> {
59 timeout(timeout_limit, Self::connect_without_timeout(transport, secret_key, remote_id))
60 .await
61 .map_err(|_| ECIESError::from(ECIESErrorImpl::StreamTimeout))?
62 }
63
64 pub async fn connect_without_timeout(
66 transport: Io,
67 secret_key: SecretKey,
68 remote_id: PeerId,
69 ) -> Result<Self, ECIESError> {
70 let ecies = ECIESCodec::new_client(secret_key, remote_id)
71 .map_err(|_| io::Error::new(io::ErrorKind::Other, "invalid handshake"))?;
72
73 let mut transport = ecies.framed(transport);
74
75 trace!("sending ecies auth ...");
76 transport.send(EgressECIESValue::Auth).await?;
77
78 trace!("waiting for ecies ack ...");
79
80 let msg = transport.try_next().await?;
81
82 let msg = msg.ok_or(ECIESErrorImpl::UnreadableStream)?;
86
87 trace!("parsing ecies ack ...");
88 if matches!(msg, IngressECIESValue::Ack) {
89 Ok(Self { stream: transport, remote_id })
90 } else {
91 Err(ECIESErrorImpl::InvalidHandshake {
92 expected: IngressECIESValue::Ack,
93 msg: Some(msg),
94 }
95 .into())
96 }
97 }
98
99 pub async fn incoming(transport: Io, secret_key: SecretKey) -> Result<Self, ECIESError> {
101 let ecies = ECIESCodec::new_server(secret_key)?;
102
103 trace!("incoming ecies stream");
104 let mut transport = ecies.framed(transport);
105 let msg = transport.try_next().await?;
106
107 trace!("receiving ecies auth");
108 let remote_id = match &msg {
109 Some(IngressECIESValue::AuthReceive(remote_id)) => *remote_id,
110 _ => {
111 return Err(ECIESErrorImpl::InvalidHandshake {
112 expected: IngressECIESValue::AuthReceive(Default::default()),
113 msg,
114 }
115 .into())
116 }
117 };
118
119 trace!("sending ecies ack");
120 transport.send(EgressECIESValue::Ack).await?;
121
122 Ok(Self { stream: transport, remote_id })
123 }
124
125 pub const fn remote_id(&self) -> PeerId {
127 self.remote_id
128 }
129}
130
131impl<Io> Stream for ECIESStream<Io>
132where
133 Io: AsyncRead + Unpin,
134{
135 type Item = Result<BytesMut, io::Error>;
136
137 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
138 match ready!(self.project().stream.poll_next(cx)) {
139 Some(Ok(IngressECIESValue::Message(body))) => Poll::Ready(Some(Ok(body))),
140 Some(other) => Poll::Ready(Some(Err(io::Error::new(
141 io::ErrorKind::Other,
142 format!("ECIES stream protocol error: expected message, received {other:?}"),
143 )))),
144 None => Poll::Ready(None),
145 }
146 }
147}
148
149impl<Io> Sink<Bytes> for ECIESStream<Io>
150where
151 Io: AsyncWrite + Unpin,
152{
153 type Error = io::Error;
154
155 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
156 self.project().stream.poll_ready(cx)
157 }
158
159 fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
160 self.project().stream.start_send(EgressECIESValue::Message(item))?;
161 Ok(())
162 }
163
164 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
165 self.project().stream.poll_flush(cx)
166 }
167
168 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
169 self.project().stream.poll_close(cx)
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use reth_network_peers::pk2id;
    use secp256k1::SECP256K1;
    use tokio::net::{TcpListener, TcpStream};

    /// A client connects to a local listener, completes the handshake and sends one message,
    /// which the server side must read back unchanged.
    #[tokio::test]
    async fn can_write_and_read() {
        let server_key = SecretKey::new(&mut rand::thread_rng());
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        // Server side: accept one connection, run the handshake, expect "hello".
        let server_task = tokio::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            let mut server_stream = ECIESStream::incoming(socket, server_key).await.unwrap();

            let received = server_stream.next().await.unwrap().unwrap();
            assert_eq!(received, Bytes::from("hello"));
        });

        // Client side: the peer id to dial is derived from the server's public key.
        let server_id = pk2id(&server_key.public_key(SECP256K1));
        let client_key = SecretKey::new(&mut rand::thread_rng());

        let socket = TcpStream::connect(addr).await.unwrap();
        let mut client_stream = ECIESStream::connect(socket, client_key, server_id).await.unwrap();
        client_stream.send(Bytes::from("hello")).await.unwrap();

        // Surface any panic (failed assertion) from the server task.
        server_task.await.unwrap();
    }

    /// The server task sleeps for 11 s before accepting, while the client allows only 1 s for
    /// the handshake, so `connect_with_timeout` must fail with the stream-timeout error.
    #[tokio::test]
    async fn connection_should_timeout() {
        let server_key = SecretKey::new(&mut rand::thread_rng());
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        // Deliberately slow server; the task handle is never awaited.
        let _server_task = tokio::spawn(async move {
            tokio::time::sleep(Duration::from_secs(11)).await;
            let (socket, _) = listener.accept().await.unwrap();
            let mut server_stream = ECIESStream::incoming(socket, server_key).await.unwrap();

            let received = server_stream.next().await.unwrap().unwrap();
            assert_eq!(received, Bytes::from("hello"));
        });

        let server_id = pk2id(&server_key.public_key(SECP256K1));
        let client_key = SecretKey::new(&mut rand::thread_rng());
        let socket = TcpStream::connect(addr).await.unwrap();

        let handshake_limit = Duration::from_secs(1);
        let connect_result =
            ECIESStream::connect_with_timeout(socket, client_key, server_id, handshake_limit)
                .await;

        // Errors are compared by their rendered message.
        let expected = ECIESErrorImpl::StreamTimeout.to_string();
        assert!(matches!(connect_result, Err(err) if err.to_string() == expected));
    }
}