1use crate::errors::PingerError;
2use std::{
3 pin::Pin,
4 task::{Context, Poll},
5 time::Duration,
6};
7use tokio::time::{Instant, Interval, Sleep};
8use tokio_stream::Stream;
9
10#[derive(Debug)]
13pub(crate) struct Pinger {
14 ping_interval: Interval,
16 timeout_timer: Pin<Box<Sleep>>,
18 timeout: Duration,
20 state: PingState,
22}
23
24impl Pinger {
27 pub(crate) fn new(ping_interval: Duration, timeout_duration: Duration) -> Self {
30 let now = Instant::now();
31 let timeout_timer = tokio::time::sleep(timeout_duration);
32 Self {
33 state: PingState::Ready,
34 ping_interval: tokio::time::interval_at(now + ping_interval, ping_interval),
35 timeout_timer: Box::pin(timeout_timer),
36 timeout: timeout_duration,
37 }
38 }
39
40 pub(crate) fn on_pong(&mut self) -> Result<(), PingerError> {
43 match self.state {
44 PingState::Ready => Err(PingerError::UnexpectedPong),
45 PingState::WaitingForPong => {
46 self.state = PingState::Ready;
47 self.ping_interval.reset();
48 Ok(())
49 }
50 PingState::TimedOut => {
51 self.state = PingState::Ready;
54 self.ping_interval.reset();
55 Ok(())
56 }
57 }
58 }
59
60 pub(crate) const fn state(&self) -> PingState {
62 self.state
63 }
64
65 pub(crate) fn poll_ping(
68 &mut self,
69 cx: &mut Context<'_>,
70 ) -> Poll<Result<PingerEvent, PingerError>> {
71 match self.state() {
72 PingState::Ready => {
73 if self.ping_interval.poll_tick(cx).is_ready() {
74 self.timeout_timer.as_mut().reset(Instant::now() + self.timeout);
75 self.state = PingState::WaitingForPong;
76 return Poll::Ready(Ok(PingerEvent::Ping))
77 }
78 }
79 PingState::WaitingForPong => {
80 if self.timeout_timer.is_elapsed() {
81 self.state = PingState::TimedOut;
82 return Poll::Ready(Ok(PingerEvent::Timeout))
83 }
84 }
85 PingState::TimedOut => {
86 return Poll::Pending
89 }
90 };
91 Poll::Pending
92 }
93}
94
95impl Stream for Pinger {
96 type Item = Result<PingerEvent, PingerError>;
97
98 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
99 self.get_mut().poll_ping(cx).map(Some)
100 }
101}
102
/// Position of the pinger within one ping/pong cycle.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum PingState {
    /// No ping is outstanding; waiting for the next interval tick.
    Ready,
    /// A ping was emitted and the matching pong has not been recorded yet.
    WaitingForPong,
    /// The pong did not arrive before the timeout expired.
    TimedOut,
}
114
/// Events produced when polling the pinger.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum PingerEvent {
    /// The ping interval ticked: a new ping is due.
    Ping,

    /// No pong was recorded before the timeout expired.
    Timeout,
}
126
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    /// Walks the pinger through a full cycle: ping -> pong -> ping -> timeout
    /// -> late pong -> ping.
    #[tokio::test]
    async fn test_ping_timeout() {
        let ping_period = Duration::from_millis(300);
        let mut pinger = Pinger::new(ping_period, Duration::from_millis(20));

        // First ping is answered right away.
        let first = pinger.next().await.unwrap().unwrap();
        assert_eq!(first, PingerEvent::Ping);
        pinger.on_pong().unwrap();

        // Second ping is left unanswered.
        let second = pinger.next().await.unwrap().unwrap();
        assert_eq!(second, PingerEvent::Ping);

        // Waiting a full ping period is far longer than the 20ms pong timeout.
        tokio::time::sleep(ping_period).await;
        let third = pinger.next().await.unwrap().unwrap();
        assert_eq!(third, PingerEvent::Timeout);

        // A late pong is still accepted and restarts the cycle.
        pinger.on_pong().unwrap();

        let fourth = pinger.next().await.unwrap().unwrap();
        assert_eq!(fourth, PingerEvent::Ping);
    }
}