reth_network/
listener.rs

//! Contains the TCP connection listener that accepts incoming connections and the events it emits.
2
3use futures::{ready, Stream};
4use std::{
5    io,
6    net::SocketAddr,
7    pin::Pin,
8    task::{Context, Poll},
9};
10use tokio::net::{TcpListener, TcpStream};
11
12/// A tcp connection listener.
13///
14/// Listens for incoming connections.
15#[must_use = "Transport does nothing unless polled."]
16#[pin_project::pin_project]
17#[derive(Debug)]
18pub struct ConnectionListener {
19    /// Local address of the listener stream.
20    local_address: SocketAddr,
21    /// The active tcp listener for incoming connections.
22    #[pin]
23    incoming: TcpListenerStream,
24}
25
26impl ConnectionListener {
27    /// Creates a new [`TcpListener`] that listens for incoming connections.
28    pub async fn bind(addr: SocketAddr) -> io::Result<Self> {
29        let listener = TcpListener::bind(addr).await?;
30        let local_addr = listener.local_addr()?;
31        Ok(Self::new(listener, local_addr))
32    }
33
34    /// Creates a new connection listener stream.
35    pub(crate) const fn new(listener: TcpListener, local_address: SocketAddr) -> Self {
36        Self { local_address, incoming: TcpListenerStream { inner: listener } }
37    }
38
39    /// Polls the type to make progress.
40    pub fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<ListenerEvent> {
41        let this = self.project();
42        match ready!(this.incoming.poll_next(cx)) {
43            Some(Ok((stream, remote_addr))) => {
44                if let Err(err) = stream.set_nodelay(true) {
45                    tracing::warn!(target: "net", "set nodelay failed: {:?}", err);
46                }
47                Poll::Ready(ListenerEvent::Incoming { stream, remote_addr })
48            }
49            Some(Err(err)) => Poll::Ready(ListenerEvent::Error(err)),
50            None => {
51                Poll::Ready(ListenerEvent::ListenerClosed { local_address: *this.local_address })
52            }
53        }
54    }
55
56    /// Returns the socket address this listener listens on.
57    pub const fn local_address(&self) -> SocketAddr {
58        self.local_address
59    }
60}
61
62/// Event type produced by the [`TcpListenerStream`].
63pub enum ListenerEvent {
64    /// Received a new incoming.
65    Incoming {
66        /// Accepted connection
67        stream: TcpStream,
68        /// Address of the remote peer.
69        remote_addr: SocketAddr,
70    },
71    /// Returned when the underlying connection listener has been closed.
72    ///
73    /// This is the case if the [`TcpListenerStream`] should ever return `None`
74    ListenerClosed {
75        /// Address of the closed listener.
76        local_address: SocketAddr,
77    },
78    /// Encountered an error when accepting a connection.
79    ///
80    /// This is non-fatal error as the listener continues to listen for new connections to accept.
81    Error(io::Error),
82}
83
84/// A stream of incoming [`TcpStream`]s.
85#[derive(Debug)]
86struct TcpListenerStream {
87    /// listener for incoming connections.
88    inner: TcpListener,
89}
90
91impl Stream for TcpListenerStream {
92    type Item = io::Result<(TcpStream, SocketAddr)>;
93
94    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
95        match self.inner.poll_accept(cx) {
96            Poll::Ready(Ok(conn)) => Poll::Ready(Some(Ok(conn))),
97            Poll::Ready(Err(err)) => Poll::Ready(Some(Err(err))),
98            Poll::Pending => Poll::Pending,
99        }
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        future::poll_fn,
        net::{Ipv4Addr, SocketAddrV4},
        pin::pin,
    };

    #[tokio::test(flavor = "multi_thread")]
    async fn test_incoming_listener() {
        // Bind to localhost on an OS-assigned port so the test can connect back to the exact
        // bound address (connecting to `0.0.0.0` is not portable across platforms).
        let listener =
            ConnectionListener::bind(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)))
                .await
                .unwrap();
        let local_addr = listener.local_address();

        let accept = tokio::task::spawn(async move {
            let mut listener = pin!(listener);
            match poll_fn(|cx| listener.as_mut().poll(cx)).await {
                ListenerEvent::Incoming { .. } => {}
                ListenerEvent::ListenerClosed { local_address } => {
                    panic!("listener on {local_address} closed unexpectedly")
                }
                ListenerEvent::Error(err) => panic!("failed to accept connection: {err}"),
            }
        });

        // Keep the client side open until the listener has accepted the connection.
        let _stream = TcpStream::connect(local_addr).await.unwrap();

        // A panic inside a spawned task does not fail the test on its own; awaiting the
        // handle propagates it.
        accept.await.unwrap();
    }
}