1use futures::{ready, Stream};
4use std::{
5 io,
6 net::SocketAddr,
7 pin::Pin,
8 task::{Context, Poll},
9};
10use tokio::net::{TcpListener, TcpStream};
11
12#[must_use = "Transport does nothing unless polled."]
16#[pin_project::pin_project]
17#[derive(Debug)]
18pub struct ConnectionListener {
19 local_address: SocketAddr,
21 #[pin]
23 incoming: TcpListenerStream,
24}
25
26impl ConnectionListener {
27 pub async fn bind(addr: SocketAddr) -> io::Result<Self> {
29 let listener = TcpListener::bind(addr).await?;
30 let local_addr = listener.local_addr()?;
31 Ok(Self::new(listener, local_addr))
32 }
33
34 pub(crate) const fn new(listener: TcpListener, local_address: SocketAddr) -> Self {
36 Self { local_address, incoming: TcpListenerStream { inner: listener } }
37 }
38
39 pub fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<ListenerEvent> {
41 let this = self.project();
42 match ready!(this.incoming.poll_next(cx)) {
43 Some(Ok((stream, remote_addr))) => {
44 if let Err(err) = stream.set_nodelay(true) {
45 tracing::warn!(target: "net", "set nodelay failed: {:?}", err);
46 }
47 Poll::Ready(ListenerEvent::Incoming { stream, remote_addr })
48 }
49 Some(Err(err)) => Poll::Ready(ListenerEvent::Error(err)),
50 None => {
51 Poll::Ready(ListenerEvent::ListenerClosed { local_address: *this.local_address })
52 }
53 }
54 }
55
56 pub const fn local_address(&self) -> SocketAddr {
58 self.local_address
59 }
60}
61
62pub enum ListenerEvent {
64 Incoming {
66 stream: TcpStream,
68 remote_addr: SocketAddr,
70 },
71 ListenerClosed {
75 local_address: SocketAddr,
77 },
78 Error(io::Error),
82}
83
/// Adapts a tokio [`TcpListener`] into a [`Stream`] of accepted connections.
#[derive(Debug)]
struct TcpListenerStream {
    /// The listener that connections are accepted from.
    inner: TcpListener,
}
90
91impl Stream for TcpListenerStream {
92 type Item = io::Result<(TcpStream, SocketAddr)>;
93
94 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
95 match self.inner.poll_accept(cx) {
96 Poll::Ready(Ok(conn)) => Poll::Ready(Some(Ok(conn))),
97 Poll::Ready(Err(err)) => Poll::Ready(Some(Err(err))),
98 Poll::Pending => Poll::Pending,
99 }
100 }
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        future::poll_fn,
        net::{Ipv4Addr, SocketAddrV4},
        pin::pin,
    };

    #[tokio::test(flavor = "multi_thread")]
    async fn test_incoming_listener() {
        // Bind to loopback rather than the unspecified address: connecting to `0.0.0.0` is
        // not portable (it fails on Windows).
        let listener =
            ConnectionListener::bind(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)))
                .await
                .unwrap();
        let local_addr = listener.local_address();

        let accept = tokio::task::spawn(async move {
            let mut listener = pin!(listener);
            match poll_fn(|cx| listener.as_mut().poll(cx)).await {
                ListenerEvent::Incoming { .. } => {}
                _ => panic!("unexpected event"),
            }
        });

        // Keep the client side alive until the accept task has observed the connection.
        let _stream = TcpStream::connect(local_addr).await.unwrap();

        // Await the task so that a panic inside it (or the task never completing) fails the
        // test, instead of being silently discarded when the handle is dropped.
        accept.await.unwrap();
    }
}