reth_ecies/
stream.rs

1//! The ECIES Stream implementation which wraps over [`AsyncRead`] and [`AsyncWrite`].
2
3use crate::{
4    codec::ECIESCodec, error::ECIESErrorImpl, ECIESError, EgressECIESValue, IngressECIESValue,
5};
6use alloy_primitives::{
7    bytes::{Bytes, BytesMut},
8    B512 as PeerId,
9};
10use futures::{ready, Sink, SinkExt};
11use secp256k1::SecretKey;
12use std::{
13    fmt::Debug,
14    io,
15    pin::Pin,
16    task::{Context, Poll},
17    time::Duration,
18};
19use tokio::{
20    io::{AsyncRead, AsyncWrite},
21    time::timeout,
22};
23use tokio_stream::{Stream, StreamExt};
24use tokio_util::codec::{Decoder, Framed};
25use tracing::{instrument, trace};
26
/// Time limit that [`ECIESStream::connect`] applies to the client-side handshake.
const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
28
/// `ECIES` stream over TCP exchanging raw bytes
///
/// Instances are created through one of the handshake constructors ([`ECIESStream::connect`] and
/// its variants for the dialing side, [`ECIESStream::incoming`] for the listening side).
#[derive(Debug)]
#[pin_project::pin_project]
pub struct ECIESStream<Io> {
    /// The underlying `Io` wrapped in a [`Framed`] that uses [`ECIESCodec`] to encode and decode
    /// frames.
    #[pin]
    stream: Framed<Io, ECIESCodec>,
    /// Id of the remote peer: supplied by the caller when dialing, or taken from the received
    /// auth message when accepting.
    remote_id: PeerId,
}
37
38impl<Io> ECIESStream<Io>
39where
40    Io: AsyncRead + AsyncWrite + Unpin,
41{
42    /// Connect to an `ECIES` server
43    #[instrument(skip(transport, secret_key))]
44    pub async fn connect(
45        transport: Io,
46        secret_key: SecretKey,
47        remote_id: PeerId,
48    ) -> Result<Self, ECIESError> {
49        Self::connect_with_timeout(transport, secret_key, remote_id, HANDSHAKE_TIMEOUT).await
50    }
51
52    /// Wrapper around `connect_no_timeout` which enforces a timeout.
53    pub async fn connect_with_timeout(
54        transport: Io,
55        secret_key: SecretKey,
56        remote_id: PeerId,
57        timeout_limit: Duration,
58    ) -> Result<Self, ECIESError> {
59        timeout(timeout_limit, Self::connect_without_timeout(transport, secret_key, remote_id))
60            .await
61            .map_err(|_| ECIESError::from(ECIESErrorImpl::StreamTimeout))?
62    }
63
64    /// Connect to an `ECIES` server with no timeout.
65    pub async fn connect_without_timeout(
66        transport: Io,
67        secret_key: SecretKey,
68        remote_id: PeerId,
69    ) -> Result<Self, ECIESError> {
70        let ecies = ECIESCodec::new_client(secret_key, remote_id)
71            .map_err(|_| io::Error::new(io::ErrorKind::Other, "invalid handshake"))?;
72
73        let mut transport = ecies.framed(transport);
74
75        trace!("sending ecies auth ...");
76        transport.send(EgressECIESValue::Auth).await?;
77
78        trace!("waiting for ecies ack ...");
79
80        let msg = transport.try_next().await?;
81
82        // `Framed` returns `None` if the underlying stream is no longer readable, and the codec is
83        // unable to decode another message from the (partially filled) buffer. This usually happens
84        // if the remote drops the TcpStream.
85        let msg = msg.ok_or(ECIESErrorImpl::UnreadableStream)?;
86
87        trace!("parsing ecies ack ...");
88        if matches!(msg, IngressECIESValue::Ack) {
89            Ok(Self { stream: transport, remote_id })
90        } else {
91            Err(ECIESErrorImpl::InvalidHandshake {
92                expected: IngressECIESValue::Ack,
93                msg: Some(msg),
94            }
95            .into())
96        }
97    }
98
99    /// Listen on a just connected ECIES client
100    pub async fn incoming(transport: Io, secret_key: SecretKey) -> Result<Self, ECIESError> {
101        let ecies = ECIESCodec::new_server(secret_key)?;
102
103        trace!("incoming ecies stream");
104        let mut transport = ecies.framed(transport);
105        let msg = transport.try_next().await?;
106
107        trace!("receiving ecies auth");
108        let remote_id = match &msg {
109            Some(IngressECIESValue::AuthReceive(remote_id)) => *remote_id,
110            _ => {
111                return Err(ECIESErrorImpl::InvalidHandshake {
112                    expected: IngressECIESValue::AuthReceive(Default::default()),
113                    msg,
114                }
115                .into())
116            }
117        };
118
119        trace!("sending ecies ack");
120        transport.send(EgressECIESValue::Ack).await?;
121
122        Ok(Self { stream: transport, remote_id })
123    }
124
125    /// Get the remote id
126    pub const fn remote_id(&self) -> PeerId {
127        self.remote_id
128    }
129}
130
131impl<Io> Stream for ECIESStream<Io>
132where
133    Io: AsyncRead + Unpin,
134{
135    type Item = Result<BytesMut, io::Error>;
136
137    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
138        match ready!(self.project().stream.poll_next(cx)) {
139            Some(Ok(IngressECIESValue::Message(body))) => Poll::Ready(Some(Ok(body))),
140            Some(other) => Poll::Ready(Some(Err(io::Error::new(
141                io::ErrorKind::Other,
142                format!("ECIES stream protocol error: expected message, received {other:?}"),
143            )))),
144            None => Poll::Ready(None),
145        }
146    }
147}
148
149impl<Io> Sink<Bytes> for ECIESStream<Io>
150where
151    Io: AsyncWrite + Unpin,
152{
153    type Error = io::Error;
154
155    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
156        self.project().stream.poll_ready(cx)
157    }
158
159    fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
160        self.project().stream.start_send(EgressECIESValue::Message(item))?;
161        Ok(())
162    }
163
164    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
165        self.project().stream.poll_flush(cx)
166    }
167
168    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
169        self.project().stream.poll_close(cx)
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use reth_network_peers::pk2id;
    use secp256k1::SECP256K1;
    use tokio::net::{TcpListener, TcpStream};

    /// Server half shared by the tests: accepts one connection, completes the handshake and
    /// checks that the first message received is `hello`.
    async fn accept_and_expect_hello(listener: TcpListener, server_key: SecretKey) {
        // roughly based off of the design of tokio::net::TcpListener
        let (socket, _) = listener.accept().await.unwrap();
        let mut server_stream = ECIESStream::incoming(socket, server_key).await.unwrap();

        // use the stream to get the next message
        let received = server_stream.next().await.unwrap().unwrap();
        assert_eq!(received, Bytes::from("hello"));
    }

    #[tokio::test]
    async fn can_write_and_read() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let server_addr = listener.local_addr().unwrap();
        let server_key = SecretKey::new(&mut rand::thread_rng());

        let server_task = tokio::spawn(accept_and_expect_hello(listener, server_key));

        // the client dials the server by its public-key-derived id
        let server_id = pk2id(&server_key.public_key(SECP256K1));
        let client_key = SecretKey::new(&mut rand::thread_rng());

        let socket = TcpStream::connect(server_addr).await.unwrap();
        let mut client_stream = ECIESStream::connect(socket, client_key, server_id).await.unwrap();
        client_stream.send(Bytes::from("hello")).await.unwrap();

        // make sure the server receives the message and asserts before ending the test
        server_task.await.unwrap();
    }

    #[tokio::test]
    async fn connection_should_timeout() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let server_addr = listener.local_addr().unwrap();
        let server_key = SecretKey::new(&mut rand::thread_rng());

        let _server_task = tokio::spawn(async move {
            // Delay accepting the connection for longer than the client's timeout period
            tokio::time::sleep(Duration::from_secs(11)).await;
            accept_and_expect_hello(listener, server_key).await;
        });

        // the client dials the server by its public-key-derived id
        let server_id = pk2id(&server_key.public_key(SECP256K1));
        let client_key = SecretKey::new(&mut rand::thread_rng());
        let socket = TcpStream::connect(server_addr).await.unwrap();

        // Attempt to connect, expecting a timeout due to the server's delayed response
        let client_timeout = Duration::from_secs(1);
        let outcome =
            ECIESStream::connect_with_timeout(socket, client_key, server_id, client_timeout).await;

        // Assert that a timeout error occurred
        let expected = ECIESErrorImpl::StreamTimeout.to_string();
        assert!(matches!(outcome, Err(err) if err.to_string() == expected));
    }
}