reth_network/session/active.rs

1//! Represents an established session.
2
3use core::sync::atomic::Ordering;
4use std::{
5    collections::VecDeque,
6    future::Future,
7    net::SocketAddr,
8    pin::Pin,
9    sync::{atomic::AtomicU64, Arc},
10    task::{ready, Context, Poll},
11    time::{Duration, Instant},
12};
13
14use crate::{
15    message::{NewBlockMessage, PeerMessage, PeerResponse, PeerResponseResult},
16    session::{
17        conn::EthRlpxConnection,
18        handle::{ActiveSessionMessage, SessionCommand},
19        SessionId,
20    },
21};
22use alloy_primitives::Sealable;
23use futures::{stream::Fuse, SinkExt, StreamExt};
24use metrics::Gauge;
25use reth_eth_wire::{
26    capability::RawCapabilityMessage,
27    errors::{EthHandshakeError, EthStreamError, P2PStreamError},
28    message::{EthBroadcastMessage, RequestPair},
29    Capabilities, DisconnectP2P, DisconnectReason, EthMessage, NetworkPrimitives,
30};
31use reth_metrics::common::mpsc::MeteredPollSender;
32use reth_network_api::PeerRequest;
33use reth_network_p2p::error::RequestError;
34use reth_network_peers::PeerId;
35use reth_network_types::session::config::INITIAL_REQUEST_TIMEOUT;
36use reth_primitives_traits::Block;
37use rustc_hash::FxHashMap;
38use tokio::{
39    sync::{mpsc::error::TrySendError, oneshot},
40    time::Interval,
41};
42use tokio_stream::wrappers::ReceiverStream;
43use tokio_util::sync::PollSender;
44use tracing::{debug, trace};
45
46// Constants for timeout updating.
47
48/// Minimum timeout value
49const MINIMUM_TIMEOUT: Duration = Duration::from_secs(2);
50/// Maximum timeout value
51const MAXIMUM_TIMEOUT: Duration = INITIAL_REQUEST_TIMEOUT;
52/// How much a new RTT measurement affects the current timeout, as a fraction (0.1 = 10%)
53const SAMPLE_IMPACT: f64 = 0.1;
54/// Number of round-trip times (RTTs) a request timeout should cover
55const TIMEOUT_SCALING: u32 = 3;
56
57/// The type that advances an established session by listening for incoming messages (from the
58/// local node or read from the connection) and emitting events back to the
59/// [`SessionManager`](super::SessionManager).
60///
61/// It listens for
62///    - incoming commands from the [`SessionManager`](super::SessionManager)
63///    - incoming _internal_ requests/broadcasts via the request/command channel
64///    - incoming requests/broadcasts _from remote_ via the connection
65///    - responses for handled ETH requests received from the remote peer.
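///
/// An [`ActiveSession`] is itself a [`Future`] that only resolves once the session has
/// terminated; it is typically spawned onto the runtime and driven to completion there
/// (e.g. via `tokio::task::spawn`, as done in the tests at the bottom of this file).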
66#[allow(dead_code)]
67pub(crate) struct ActiveSession<N: NetworkPrimitives> {
68    /// Keeps track of request ids.
69    pub(crate) next_id: u64,
70    /// The underlying connection.
71    pub(crate) conn: EthRlpxConnection<N>,
72    /// Identifier of the node we're connected to.
73    pub(crate) remote_peer_id: PeerId,
74    /// The address we're connected to.
75    pub(crate) remote_addr: SocketAddr,
76    /// All capabilities the peer announced
77    pub(crate) remote_capabilities: Arc<Capabilities>,
78    /// Internal identifier of this session
79    pub(crate) session_id: SessionId,
80    /// Incoming commands from the manager
81    pub(crate) commands_rx: ReceiverStream<SessionCommand<N>>,
82    /// Sink to send messages to the [`SessionManager`](super::SessionManager).
83    pub(crate) to_session_manager: MeteredPollSender<ActiveSessionMessage<N>>,
84    /// A message that needs to be delivered to the session manager
85    pub(crate) pending_message_to_session: Option<ActiveSessionMessage<N>>,
86    /// Incoming internal requests which are delegated to the remote peer.
87    pub(crate) internal_request_tx: Fuse<ReceiverStream<PeerRequest<N>>>,
88    /// All requests sent to the remote peer for which we're still awaiting a response
89    pub(crate) inflight_requests: FxHashMap<u64, InflightRequest<PeerRequest<N>>>,
90    /// All requests received from the remote peer for which we're awaiting an internal response
91    pub(crate) received_requests_from_remote: Vec<ReceivedRequest<N>>,
92    /// Buffered messages that should be handled and sent to the peer.
93    pub(crate) queued_outgoing: QueuedOutgoingMessages<N>,
94    /// The maximum time we wait for a response from a peer, stored as milliseconds.
95    pub(crate) internal_request_timeout: Arc<AtomicU64>,
96    /// Interval when to check for timed out requests.
97    pub(crate) internal_request_timeout_interval: Interval,
98    /// If an [`ActiveSession`] does not receive a response at all within this duration then it is
99    /// considered a protocol violation and the session will initiate a drop.
100    pub(crate) protocol_breach_request_timeout: Duration,
101    /// Used to reserve a slot to guarantee that the termination message is delivered
102    pub(crate) terminate_message:
103        Option<(PollSender<ActiveSessionMessage<N>>, ActiveSessionMessage<N>)>,
104}
105
106impl<N: NetworkPrimitives> ActiveSession<N> {
107    /// Returns `true` if the session is currently in the process of disconnecting
108    fn is_disconnecting(&self) -> bool {
109        self.conn.inner().is_disconnecting()
110    }
111
112    /// Returns the next request id
113    fn next_id(&mut self) -> u64 {
114        let id = self.next_id;
115        self.next_id += 1;
116        id
117    }
118
119    /// Shrinks the capacity of the internal buffers.
120    pub fn shrink_to_fit(&mut self) {
121        self.received_requests_from_remote.shrink_to_fit();
122        self.queued_outgoing.shrink_to_fit();
123    }
124
125    /// Handle a message read from the connection.
126    ///
127    /// Returns an error if the message is considered to be in violation of the protocol.
128    fn on_incoming_message(&mut self, msg: EthMessage<N>) -> OnIncomingMessageOutcome<N> {
129        /// A macro that handles an incoming request
130        /// This creates a new channel and tries to send the sender half to the session manager
131        /// while storing the receiver half internally so the pending response can be polled.
132        macro_rules! on_request {
133            ($req:ident, $resp_item:ident, $req_item:ident) => {{
134                let RequestPair { request_id, message: request } = $req;
135                let (tx, response) = oneshot::channel();
136                let received = ReceivedRequest {
137                    request_id,
138                    rx: PeerResponse::$resp_item { response },
139                    received: Instant::now(),
140                };
141                self.received_requests_from_remote.push(received);
142                self.try_emit_request(PeerMessage::EthRequest(PeerRequest::$req_item {
143                    request,
144                    response: tx,
145                }))
146                .into()
147            }};
148        }
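
        // As a sketch of what the macro above expands to: for `EthMessage::GetBlockHeaders`,
        // `on_request!(req, BlockHeaders, GetBlockHeaders)` becomes roughly:
        //
        //     let RequestPair { request_id, message: request } = req;
        //     let (tx, response) = oneshot::channel();
        //     self.received_requests_from_remote.push(ReceivedRequest {
        //         request_id,
        //         rx: PeerResponse::BlockHeaders { response },
        //         received: Instant::now(),
        //     });
        //     self.try_emit_request(PeerMessage::EthRequest(PeerRequest::GetBlockHeaders {
        //         request,
        //         response: tx,
        //     }))
        //     .into()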
149
150        /// Processes a response received from the peer
151        macro_rules! on_response {
152            ($resp:ident, $item:ident) => {{
153                let RequestPair { request_id, message } = $resp;
154                #[allow(clippy::collapsible_match)]
155                if let Some(req) = self.inflight_requests.remove(&request_id) {
156                    match req.request {
157                        RequestState::Waiting(PeerRequest::$item { response, .. }) => {
158                            let _ = response.send(Ok(message));
159                            self.update_request_timeout(req.timestamp, Instant::now());
160                        }
161                        RequestState::Waiting(request) => {
162                            request.send_bad_response();
163                        }
164                        RequestState::TimedOut => {
165                            // request was already timed out internally
166                            self.update_request_timeout(req.timestamp, Instant::now());
167                        }
168                    }
169                } else {
170                    // we received a response to a request we never sent
171                    self.on_bad_message();
172                }
173
174                OnIncomingMessageOutcome::Ok
175            }};
176        }
177
178        match msg {
179            message @ EthMessage::Status(_) => OnIncomingMessageOutcome::BadMessage {
180                error: EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake),
181                message,
182            },
183            EthMessage::NewBlockHashes(msg) => {
184                self.try_emit_broadcast(PeerMessage::NewBlockHashes(msg)).into()
185            }
186            EthMessage::NewBlock(msg) => {
187                let block =
188                    NewBlockMessage { hash: msg.block.header().hash_slow(), block: Arc::new(*msg) };
189                self.try_emit_broadcast(PeerMessage::NewBlock(block)).into()
190            }
191            EthMessage::Transactions(msg) => {
192                self.try_emit_broadcast(PeerMessage::ReceivedTransaction(msg)).into()
193            }
194            EthMessage::NewPooledTransactionHashes66(msg) => {
195                self.try_emit_broadcast(PeerMessage::PooledTransactions(msg.into())).into()
196            }
197            EthMessage::NewPooledTransactionHashes68(msg) => {
198                if msg.hashes.len() != msg.types.len() || msg.hashes.len() != msg.sizes.len() {
199                    return OnIncomingMessageOutcome::BadMessage {
200                        error: EthStreamError::TransactionHashesInvalidLenOfFields {
201                            hashes_len: msg.hashes.len(),
202                            types_len: msg.types.len(),
203                            sizes_len: msg.sizes.len(),
204                        },
205                        message: EthMessage::NewPooledTransactionHashes68(msg),
206                    }
207                }
208                self.try_emit_broadcast(PeerMessage::PooledTransactions(msg.into())).into()
209            }
210            EthMessage::GetBlockHeaders(req) => {
211                on_request!(req, BlockHeaders, GetBlockHeaders)
212            }
213            EthMessage::BlockHeaders(resp) => {
214                on_response!(resp, GetBlockHeaders)
215            }
216            EthMessage::GetBlockBodies(req) => {
217                on_request!(req, BlockBodies, GetBlockBodies)
218            }
219            EthMessage::BlockBodies(resp) => {
220                on_response!(resp, GetBlockBodies)
221            }
222            EthMessage::GetPooledTransactions(req) => {
223                on_request!(req, PooledTransactions, GetPooledTransactions)
224            }
225            EthMessage::PooledTransactions(resp) => {
226                on_response!(resp, GetPooledTransactions)
227            }
228            EthMessage::GetNodeData(req) => {
229                on_request!(req, NodeData, GetNodeData)
230            }
231            EthMessage::NodeData(resp) => {
232                on_response!(resp, GetNodeData)
233            }
234            EthMessage::GetReceipts(req) => {
235                on_request!(req, Receipts, GetReceipts)
236            }
237            EthMessage::Receipts(resp) => {
238                on_response!(resp, GetReceipts)
239            }
240        }
241    }
242
243    /// Handle an internal peer request that will be sent to the remote.
244    fn on_internal_peer_request(&mut self, request: PeerRequest<N>, deadline: Instant) {
245        let request_id = self.next_id();
246        let msg = request.create_request_message(request_id);
247        self.queued_outgoing.push_back(msg.into());
248        let req = InflightRequest {
249            request: RequestState::Waiting(request),
250            timestamp: Instant::now(),
251            deadline,
252        };
253        self.inflight_requests.insert(request_id, req);
254    }
255
256    /// Handle a message received from the internal network
257    fn on_internal_peer_message(&mut self, msg: PeerMessage<N>) {
258        match msg {
259            PeerMessage::NewBlockHashes(msg) => {
260                self.queued_outgoing.push_back(EthMessage::NewBlockHashes(msg).into());
261            }
262            PeerMessage::NewBlock(msg) => {
263                self.queued_outgoing.push_back(EthBroadcastMessage::NewBlock(msg.block).into());
264            }
265            PeerMessage::PooledTransactions(msg) => {
266                if msg.is_valid_for_version(self.conn.version()) {
267                    self.queued_outgoing.push_back(EthMessage::from(msg).into());
268                }
269            }
270            PeerMessage::EthRequest(req) => {
271                let deadline = self.request_deadline();
272                self.on_internal_peer_request(req, deadline);
273            }
274            PeerMessage::SendTransactions(msg) => {
275                self.queued_outgoing.push_back(EthBroadcastMessage::Transactions(msg).into());
276            }
277            PeerMessage::ReceivedTransaction(_) => {
278                unreachable!("Not emitted by network")
279            }
280            PeerMessage::Other(other) => {
281                debug!(target: "net::session", message_id=%other.id, "Ignoring unsupported message");
282                self.queued_outgoing.push_back(OutgoingMessage::Raw(other));
283            }
284        }
285    }
286
287    /// Returns the deadline timestamp at which the request times out
288    fn request_deadline(&self) -> Instant {
289        Instant::now() +
290            Duration::from_millis(self.internal_request_timeout.load(Ordering::Relaxed))
291    }
292
293    /// Handle a response destined for the peer
294    ///
295    /// This will queue the response to be sent to the peer
296    fn handle_outgoing_response(&mut self, id: u64, resp: PeerResponseResult<N>) {
297        match resp.try_into_message(id) {
298            Ok(msg) => {
299                self.queued_outgoing.push_back(msg.into());
300            }
301            Err(err) => {
302                debug!(target: "net", %err, "Failed to respond to received request");
303            }
304        }
305    }
306
307    /// Send a message back to the [`SessionManager`](super::SessionManager).
308    ///
309    /// Returns the message if the bounded channel is currently unable to handle this message.
310    #[allow(clippy::result_large_err)]
311    fn try_emit_broadcast(&self, message: PeerMessage<N>) -> Result<(), ActiveSessionMessage<N>> {
312        let Some(sender) = self.to_session_manager.inner().get_ref() else { return Ok(()) };
313
314        match sender
315            .try_send(ActiveSessionMessage::ValidMessage { peer_id: self.remote_peer_id, message })
316        {
317            Ok(_) => Ok(()),
318            Err(err) => {
319                trace!(
320                    target: "net",
321                    %err,
322                    "no capacity for incoming broadcast",
323                );
324                match err {
325                    TrySendError::Full(msg) => Err(msg),
326                    TrySendError::Closed(_) => Ok(()),
327                }
328            }
329        }
330    }
331
332    /// Send a message back to the [`SessionManager`](super::SessionManager)
333    /// covering both broadcasts and incoming requests.
334    ///
335    /// Returns the message if the bounded channel is currently unable to handle this message.
336    #[allow(clippy::result_large_err)]
337    fn try_emit_request(&self, message: PeerMessage<N>) -> Result<(), ActiveSessionMessage<N>> {
338        let Some(sender) = self.to_session_manager.inner().get_ref() else { return Ok(()) };
339
340        match sender
341            .try_send(ActiveSessionMessage::ValidMessage { peer_id: self.remote_peer_id, message })
342        {
343            Ok(_) => Ok(()),
344            Err(err) => {
345                trace!(
346                    target: "net",
347                    %err,
348                    "no capacity for incoming request",
349                );
350                match err {
351                    TrySendError::Full(msg) => Err(msg),
352                    TrySendError::Closed(_) => {
353                        // Note: this would mean the `SessionManager` was dropped, which is already
354                        // handled by checking if the command receiver channel has been closed.
355                        Ok(())
356                    }
357                }
358            }
359        }
360    }
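
    // Both `try_emit_broadcast` and `try_emit_request` above follow the same backpressure
    // pattern. A minimal, self-contained sketch of it (illustrative only, not used by this
    // module) could look like this:
    //
    //     use tokio::sync::mpsc::{self, error::TrySendError};
    //
    //     fn try_emit<T>(tx: &mpsc::Sender<T>, msg: T) -> Result<(), T> {
    //         match tx.try_send(msg) {
    //             Ok(()) => Ok(()),
    //             // hand the message back so the caller can buffer it and retry later
    //             Err(TrySendError::Full(msg)) => Err(msg),
    //             // the receiver (the manager) is gone; dropping the message is fine here
    //             Err(TrySendError::Closed(_)) => Ok(()),
    //         }
    //     }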
361
362    /// Notify the manager that the peer sent a bad message
363    fn on_bad_message(&self) {
364        let Some(sender) = self.to_session_manager.inner().get_ref() else { return };
365        let _ = sender.try_send(ActiveSessionMessage::BadMessage { peer_id: self.remote_peer_id });
366    }
367
368    /// Report back that this session has been closed.
369    fn emit_disconnect(&mut self, cx: &mut Context<'_>) -> Poll<()> {
370        trace!(target: "net::session", remote_peer_id=?self.remote_peer_id, "emitting disconnect");
371        let msg = ActiveSessionMessage::Disconnected {
372            peer_id: self.remote_peer_id,
373            remote_addr: self.remote_addr,
374        };
375
376        self.terminate_message = Some((self.to_session_manager.inner().clone(), msg));
377        self.poll_terminate_message(cx).expect("message is set")
378    }
379
380    /// Report back that this session has been closed due to an error
381    fn close_on_error(&mut self, error: EthStreamError, cx: &mut Context<'_>) -> Poll<()> {
382        let msg = ActiveSessionMessage::ClosedOnConnectionError {
383            peer_id: self.remote_peer_id,
384            remote_addr: self.remote_addr,
385            error,
386        };
387        self.terminate_message = Some((self.to_session_manager.inner().clone(), msg));
388        self.poll_terminate_message(cx).expect("message is set")
389    }
390
391    /// Starts the disconnect process
392    fn start_disconnect(&mut self, reason: DisconnectReason) -> Result<(), EthStreamError> {
393        self.conn
394            .inner_mut()
395            .start_disconnect(reason)
396            .map_err(P2PStreamError::from)
397            .map_err(Into::into)
398    }
399
400    /// Flushes the disconnect message and emits the corresponding message
401    fn poll_disconnect(&mut self, cx: &mut Context<'_>) -> Poll<()> {
402        debug_assert!(self.is_disconnecting(), "not disconnecting");
403
404        // try to flush out the remaining Disconnect message and close the connection
405        let _ = ready!(self.conn.poll_close_unpin(cx));
406        self.emit_disconnect(cx)
407    }
408
409    /// Attempts to disconnect by sending the given disconnect reason
410    fn try_disconnect(&mut self, reason: DisconnectReason, cx: &mut Context<'_>) -> Poll<()> {
411        match self.start_disconnect(reason) {
412            Ok(()) => {
413                // we're done
414                self.poll_disconnect(cx)
415            }
416            Err(err) => {
417                debug!(target: "net::session", %err, remote_peer_id=?self.remote_peer_id, "could not send disconnect");
418                self.close_on_error(err, cx)
419            }
420        }
421    }
422
423    /// Checks for _internally_ timed out requests.
424    ///
425    /// If a request misses its deadline, it is timed out internally.
426    /// If a request also exceeds the `protocol_breach_request_timeout`, this session is considered
427    /// to be in violation of the protocol and will be closed.
428    ///
429    /// Returns `true` if a peer missed the `protocol_breach_request_timeout`, in which case the
430    /// session should be terminated.
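    ///
    /// For example (illustrative durations only): with an internal request timeout of 5s and a
    /// `protocol_breach_request_timeout` of 2 minutes, a request that is still unanswered after 5s
    /// is failed internally with a timeout error, and the session is only terminated if the peer
    /// still has not answered after 2 minutes.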
431    #[must_use]
432    fn check_timed_out_requests(&mut self, now: Instant) -> bool {
433        for (id, req) in &mut self.inflight_requests {
434            if req.is_timed_out(now) {
435                if req.is_waiting() {
436                    debug!(target: "net::session", ?id, remote_peer_id=?self.remote_peer_id, "timed out outgoing request");
437                    req.timeout();
438                } else if now - req.timestamp > self.protocol_breach_request_timeout {
439                    return true
440                }
441            }
442        }
443
444        false
445    }
446
447    /// Updates the request timeout with a request's timestamps
448    fn update_request_timeout(&mut self, sent: Instant, received: Instant) {
449        let elapsed = received.saturating_duration_since(sent);
450
451        let current = Duration::from_millis(self.internal_request_timeout.load(Ordering::Relaxed));
452        let request_timeout = calculate_new_timeout(current, elapsed);
453        self.internal_request_timeout.store(request_timeout.as_millis() as u64, Ordering::Relaxed);
454        self.internal_request_timeout_interval = tokio::time::interval(request_timeout);
455    }
456
457    /// If a termination message is queued this will try to send it
458    fn poll_terminate_message(&mut self, cx: &mut Context<'_>) -> Option<Poll<()>> {
459        let (mut tx, msg) = self.terminate_message.take()?;
460        match tx.poll_reserve(cx) {
461            Poll::Pending => {
462                self.terminate_message = Some((tx, msg));
463                return Some(Poll::Pending)
464            }
465            Poll::Ready(Ok(())) => {
466                let _ = tx.send_item(msg);
467            }
468            Poll::Ready(Err(_)) => {
469                // channel closed
470            }
471        }
472        // terminate the task
473        Some(Poll::Ready(()))
474    }
475}
476
477impl<N: NetworkPrimitives> Future for ActiveSession<N> {
478    type Output = ();
479
480    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
481        let this = self.get_mut();
482
483        // if the session is being terminated we have to send the termination message before we can close
484        if let Some(terminate) = this.poll_terminate_message(cx) {
485            return terminate
486        }
487
488        if this.is_disconnecting() {
489            return this.poll_disconnect(cx)
490        }
491
492        // The receive loop can be CPU intensive since it involves message decoding, which could take
493        // up a lot of resources and increase latencies for other sessions if we never yield manually.
494        // If the budget is exhausted we manually yield back control to the (coop) scheduler. This
495        // manual yield point should prevent situations where polling appears to be frozen. See also <https://tokio.rs/blog/2020-04-preemption>
496        // and tokio's docs on cooperative scheduling <https://docs.rs/tokio/latest/tokio/task/#cooperative-scheduling>
497        let mut budget = 4;
498
499        // The main poll loop that drives the session
500        'main: loop {
501            let mut progress = false;
502
503            // we prioritize incoming commands sent from the session manager
504            loop {
505                match this.commands_rx.poll_next_unpin(cx) {
506                    Poll::Pending => break,
507                    Poll::Ready(None) => {
508                        // this is only possible when the manager was dropped, in which case we also
509                        // terminate this session
510                        return Poll::Ready(())
511                    }
512                    Poll::Ready(Some(cmd)) => {
513                        progress = true;
514                        match cmd {
515                            SessionCommand::Disconnect { reason } => {
516                                debug!(
517                                    target: "net::session",
518                                    ?reason,
519                                    remote_peer_id=?this.remote_peer_id,
520                                    "Received disconnect command for session"
521                                );
522                                let reason =
523                                    reason.unwrap_or(DisconnectReason::DisconnectRequested);
524
525                                return this.try_disconnect(reason, cx)
526                            }
527                            SessionCommand::Message(msg) => {
528                                this.on_internal_peer_message(msg);
529                            }
530                        }
531                    }
532                }
533            }
534
535            let deadline = this.request_deadline();
536
537            while let Poll::Ready(Some(req)) = this.internal_request_tx.poll_next_unpin(cx) {
538                progress = true;
539                this.on_internal_peer_request(req, deadline);
540            }
541
542            // Advance all active requests.
543            // We remove each request one by one and push it back only if it is still pending.
544            for idx in (0..this.received_requests_from_remote.len()).rev() {
545                let mut req = this.received_requests_from_remote.swap_remove(idx);
546                match req.rx.poll(cx) {
547                    Poll::Pending => {
548                        // not ready yet
549                        this.received_requests_from_remote.push(req);
550                    }
551                    Poll::Ready(resp) => {
552                        this.handle_outgoing_response(req.request_id, resp);
553                    }
554                }
555            }
556
557            // Send messages by advancing the sink and queuing in buffered messages
558            while this.conn.poll_ready_unpin(cx).is_ready() {
559                if let Some(msg) = this.queued_outgoing.pop_front() {
560                    progress = true;
561                    let res = match msg {
562                        OutgoingMessage::Eth(msg) => this.conn.start_send_unpin(msg),
563                        OutgoingMessage::Broadcast(msg) => this.conn.start_send_broadcast(msg),
564                        OutgoingMessage::Raw(msg) => this.conn.start_send_raw(msg),
565                    };
566                    if let Err(err) = res {
567                        debug!(target: "net::session", %err, remote_peer_id=?this.remote_peer_id, "failed to send message");
568                        // notify the manager
569                        return this.close_on_error(err, cx)
570                    }
571                } else {
572                    // no more messages to send over the wire
573                    break
574                }
575            }
576
577            // read incoming messages from the wire
578            'receive: loop {
579                // ensure we still have enough budget for another iteration
580                budget -= 1;
581                if budget == 0 {
582                    // make sure we're woken up again
583                    cx.waker().wake_by_ref();
584                    break 'main
585                }
586
587                // try to resend the pending message that we could not send because the channel was
588                // full. [`PollSender`] will ensure that we're woken up again when the channel is
589                // ready to receive the message, and will only error if the channel is closed.
590                if let Some(msg) = this.pending_message_to_session.take() {
591                    match this.to_session_manager.poll_reserve(cx) {
592                        Poll::Ready(Ok(_)) => {
593                            let _ = this.to_session_manager.send_item(msg);
594                        }
595                        Poll::Ready(Err(_)) => return Poll::Ready(()),
596                        Poll::Pending => {
597                            this.pending_message_to_session = Some(msg);
598                            break 'receive
599                        }
600                    };
601                }
602
603                match this.conn.poll_next_unpin(cx) {
604                    Poll::Pending => break,
605                    Poll::Ready(None) => {
606                        if this.is_disconnecting() {
607                            break
608                        }
609                        debug!(target: "net::session", remote_peer_id=?this.remote_peer_id, "eth stream completed");
610                        return this.emit_disconnect(cx)
611                    }
612                    Poll::Ready(Some(res)) => {
613                        match res {
614                            Ok(msg) => {
615                                trace!(target: "net::session", msg_id=?msg.message_id(), remote_peer_id=?this.remote_peer_id, "received eth message");
616                                // decode and handle message
617                                match this.on_incoming_message(msg) {
618                                    OnIncomingMessageOutcome::Ok => {
619                                        // handled successfully
620                                        progress = true;
621                                    }
622                                    OnIncomingMessageOutcome::BadMessage { error, message } => {
623                                        debug!(target: "net::session", %error, msg=?message, remote_peer_id=?this.remote_peer_id, "received invalid protocol message");
624                                        return this.close_on_error(error, cx)
625                                    }
626                                    OnIncomingMessageOutcome::NoCapacity(msg) => {
627                                        // failed to send due to lack of capacity
628                                        this.pending_message_to_session = Some(msg);
629                                        continue 'receive
630                                    }
631                                }
632                            }
633                            Err(err) => {
634                                debug!(target: "net::session", %err, remote_peer_id=?this.remote_peer_id, "failed to receive message");
635                                return this.close_on_error(err, cx)
636                            }
637                        }
638                    }
639                }
640            }
641
642            if !progress {
643                break 'main
644            }
645        }
646
647        while this.internal_request_timeout_interval.poll_tick(cx).is_ready() {
648            // check for timed out requests
649            if this.check_timed_out_requests(Instant::now()) {
650                if let Poll::Ready(Ok(_)) = this.to_session_manager.poll_reserve(cx) {
651                    let msg = ActiveSessionMessage::ProtocolBreach { peer_id: this.remote_peer_id };
652                    this.pending_message_to_session = Some(msg);
653                }
654            }
655        }
656
657        this.shrink_to_fit();
658
659        Poll::Pending
660    }
661}
662
663/// Tracks a request received from the peer
664pub(crate) struct ReceivedRequest<N: NetworkPrimitives> {
665    /// The request id the peer assigned to this request
666    request_id: u64,
667    /// Receiver half of the channel that's supposed to receive the proper response.
668    rx: PeerResponse<N>,
669    /// Timestamp when we read this msg from the wire.
670    #[allow(dead_code)]
671    received: Instant,
672}
673
674/// A request that waits for a response from the peer
675pub(crate) struct InflightRequest<R> {
676    /// The request we sent to the peer and the internal response channel
677    request: RequestState<R>,
678    /// Instant when the request was sent
679    timestamp: Instant,
680    /// Time limit for the response
681    deadline: Instant,
682}
683
684// === impl InflightRequest ===
685
686impl<N: NetworkPrimitives> InflightRequest<PeerRequest<N>> {
687    /// Returns true if the request has timed out
688    #[inline]
689    fn is_timed_out(&self, now: Instant) -> bool {
690        now > self.deadline
691    }
692
693    /// Returns true if we're still waiting for a response
694    #[inline]
695    const fn is_waiting(&self) -> bool {
696        matches!(self.request, RequestState::Waiting(_))
697    }
698
699    /// This will time out the request by sending an error response to the internal channel
700    fn timeout(&mut self) {
701        let mut req = RequestState::TimedOut;
702        std::mem::swap(&mut self.request, &mut req);
703
704        if let RequestState::Waiting(req) = req {
705            req.send_err_response(RequestError::Timeout);
706        }
707    }
708}
709
710/// All outcome variants when handling an incoming message
711enum OnIncomingMessageOutcome<N: NetworkPrimitives> {
712    /// Message successfully handled.
713    Ok,
714    /// Message is considered to be in violation of the protocol
715    BadMessage { error: EthStreamError, message: EthMessage<N> },
716    /// Currently no capacity to handle the message
717    NoCapacity(ActiveSessionMessage<N>),
718}
719
720impl<N: NetworkPrimitives> From<Result<(), ActiveSessionMessage<N>>>
721    for OnIncomingMessageOutcome<N>
722{
723    fn from(res: Result<(), ActiveSessionMessage<N>>) -> Self {
724        match res {
725            Ok(_) => Self::Ok,
726            Err(msg) => Self::NoCapacity(msg),
727        }
728    }
729}
730
731enum RequestState<R> {
732    /// Waiting for the response
733    Waiting(R),
734    /// Request already timed out
735    TimedOut,
736}
737
738/// Outgoing messages that can be sent over the wire.
739pub(crate) enum OutgoingMessage<N: NetworkPrimitives> {
740    /// A message that is owned.
741    Eth(EthMessage<N>),
742    /// A message that may be shared by multiple sessions.
743    Broadcast(EthBroadcastMessage<N>),
744    /// A raw capability message
745    Raw(RawCapabilityMessage),
746}
747
748impl<N: NetworkPrimitives> From<EthMessage<N>> for OutgoingMessage<N> {
749    fn from(value: EthMessage<N>) -> Self {
750        Self::Eth(value)
751    }
752}
753
754impl<N: NetworkPrimitives> From<EthBroadcastMessage<N>> for OutgoingMessage<N> {
755    fn from(value: EthBroadcastMessage<N>) -> Self {
756        Self::Broadcast(value)
757    }
758}
759
760/// Calculates a new timeout using an updated estimation of the RTT
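///
/// For example, with the constants above (`SAMPLE_IMPACT = 0.1`, `TIMEOUT_SCALING = 3`), a
/// current timeout of 6s and a measured RTT of 1s (illustrative values) give
/// `new = 1s * 0.1 * 3 = 0.3s` and `smoothened = 6s * 0.9 + 0.3s = 5.7s`, which is then
/// clamped to `[MINIMUM_TIMEOUT, MAXIMUM_TIMEOUT]`.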
761#[inline]
762fn calculate_new_timeout(current_timeout: Duration, estimated_rtt: Duration) -> Duration {
763    let new_timeout = estimated_rtt.mul_f64(SAMPLE_IMPACT) * TIMEOUT_SCALING;
764
765    // this dampens sudden changes by taking a weighted mean of the old and new values
766    let smoothened_timeout = current_timeout.mul_f64(1.0 - SAMPLE_IMPACT) + new_timeout;
767
768    smoothened_timeout.clamp(MINIMUM_TIMEOUT, MAXIMUM_TIMEOUT)
769}
770
771/// A helper struct that wraps the queue of outgoing messages and a metric to track their count
772pub(crate) struct QueuedOutgoingMessages<N: NetworkPrimitives> {
773    messages: VecDeque<OutgoingMessage<N>>,
774    count: Gauge,
775}
776
777impl<N: NetworkPrimitives> QueuedOutgoingMessages<N> {
778    pub(crate) const fn new(metric: Gauge) -> Self {
779        Self { messages: VecDeque::new(), count: metric }
780    }
781
782    pub(crate) fn push_back(&mut self, message: OutgoingMessage<N>) {
783        self.messages.push_back(message);
784        self.count.increment(1);
785    }
786
787    pub(crate) fn pop_front(&mut self) -> Option<OutgoingMessage<N>> {
788        self.messages.pop_front().inspect(|_| self.count.decrement(1))
789    }
790
791    pub(crate) fn shrink_to_fit(&mut self) {
792        self.messages.shrink_to_fit();
793    }
794}
795
796#[cfg(test)]
797mod tests {
798    use super::*;
799    use crate::session::{handle::PendingSessionEvent, start_pending_incoming_session};
800    use reth_chainspec::MAINNET;
801    use reth_ecies::stream::ECIESStream;
802    use reth_eth_wire::{
803        EthNetworkPrimitives, EthStream, GetBlockBodies, HelloMessageWithProtocols, P2PStream,
804        Status, StatusBuilder, UnauthedEthStream, UnauthedP2PStream,
805    };
806    use reth_network_peers::pk2id;
807    use reth_network_types::session::config::PROTOCOL_BREACH_REQUEST_TIMEOUT;
808    use reth_primitives::{EthereumHardfork, ForkFilter};
809    use secp256k1::{SecretKey, SECP256K1};
810    use tokio::{
811        net::{TcpListener, TcpStream},
812        sync::mpsc,
813    };
814
815    /// Returns a testing `HelloMessageWithProtocols` built for the given secret key
816    fn eth_hello(server_key: &SecretKey) -> HelloMessageWithProtocols {
817        HelloMessageWithProtocols::builder(pk2id(&server_key.public_key(SECP256K1))).build()
818    }
819
820    struct SessionBuilder<N: NetworkPrimitives = EthNetworkPrimitives> {
821        _remote_capabilities: Arc<Capabilities>,
822        active_session_tx: mpsc::Sender<ActiveSessionMessage<N>>,
823        active_session_rx: ReceiverStream<ActiveSessionMessage<N>>,
824        to_sessions: Vec<mpsc::Sender<SessionCommand<N>>>,
825        secret_key: SecretKey,
826        local_peer_id: PeerId,
827        hello: HelloMessageWithProtocols,
828        status: Status,
829        fork_filter: ForkFilter,
830        next_id: usize,
831    }
832
833    impl<N: NetworkPrimitives> SessionBuilder<N> {
834        fn next_id(&mut self) -> SessionId {
835            let id = self.next_id;
836            self.next_id += 1;
837            SessionId(id)
838        }
839
840        /// Connects a new Eth stream and executes the given closure with that established stream
841        fn with_client_stream<F, O>(
842            &self,
843            local_addr: SocketAddr,
844            f: F,
845        ) -> Pin<Box<dyn Future<Output = ()> + Send>>
846        where
847            F: FnOnce(EthStream<P2PStream<ECIESStream<TcpStream>>, N>) -> O + Send + 'static,
848            O: Future<Output = ()> + Send + Sync,
849        {
850            let status = self.status;
851            let fork_filter = self.fork_filter.clone();
852            let local_peer_id = self.local_peer_id;
853            let mut hello = self.hello.clone();
854            let key = SecretKey::new(&mut rand::thread_rng());
855            hello.id = pk2id(&key.public_key(SECP256K1));
856            Box::pin(async move {
857                let outgoing = TcpStream::connect(local_addr).await.unwrap();
858                let sink = ECIESStream::connect(outgoing, key, local_peer_id).await.unwrap();
859
860                let (p2p_stream, _) = UnauthedP2PStream::new(sink).handshake(hello).await.unwrap();
861
862                let (client_stream, _) = UnauthedEthStream::new(p2p_stream)
863                    .handshake(status, fork_filter)
864                    .await
865                    .unwrap();
866                f(client_stream).await
867            })
868        }
869
870        async fn connect_incoming(&mut self, stream: TcpStream) -> ActiveSession<N> {
871            let remote_addr = stream.local_addr().unwrap();
872            let session_id = self.next_id();
873            let (_disconnect_tx, disconnect_rx) = oneshot::channel();
874            let (pending_sessions_tx, pending_sessions_rx) = mpsc::channel(1);
875
876            tokio::task::spawn(start_pending_incoming_session(
877                disconnect_rx,
878                session_id,
879                stream,
880                pending_sessions_tx,
881                remote_addr,
882                self.secret_key,
883                self.hello.clone(),
884                self.status,
885                self.fork_filter.clone(),
886                Default::default(),
887            ));
888
889            let mut stream = ReceiverStream::new(pending_sessions_rx);
890
891            match stream.next().await.unwrap() {
892                PendingSessionEvent::Established {
893                    session_id,
894                    remote_addr,
895                    peer_id,
896                    capabilities,
897                    conn,
898                    ..
899                } => {
900                    let (_to_session_tx, messages_rx) = mpsc::channel(10);
901                    let (commands_to_session, commands_rx) = mpsc::channel(10);
902                    let poll_sender = PollSender::new(self.active_session_tx.clone());
903
904                    self.to_sessions.push(commands_to_session);
905
906                    ActiveSession {
907                        next_id: 0,
908                        remote_peer_id: peer_id,
909                        remote_addr,
910                        remote_capabilities: Arc::clone(&capabilities),
911                        session_id,
912                        commands_rx: ReceiverStream::new(commands_rx),
913                        to_session_manager: MeteredPollSender::new(
914                            poll_sender,
915                            "network_active_session",
916                        ),
917                        pending_message_to_session: None,
918                        internal_request_tx: ReceiverStream::new(messages_rx).fuse(),
919                        inflight_requests: Default::default(),
920                        conn,
921                        queued_outgoing: QueuedOutgoingMessages::new(Gauge::noop()),
922                        received_requests_from_remote: Default::default(),
923                        internal_request_timeout_interval: tokio::time::interval(
924                            INITIAL_REQUEST_TIMEOUT,
925                        ),
926                        internal_request_timeout: Arc::new(AtomicU64::new(
927                            INITIAL_REQUEST_TIMEOUT.as_millis() as u64,
928                        )),
929                        protocol_breach_request_timeout: PROTOCOL_BREACH_REQUEST_TIMEOUT,
930                        terminate_message: None,
931                    }
932                }
933                ev => {
934                    panic!("unexpected message {ev:?}")
935                }
936            }
937        }
938    }
939
940    impl Default for SessionBuilder {
941        fn default() -> Self {
942            let (active_session_tx, active_session_rx) = mpsc::channel(100);
943
944            let (secret_key, pk) = SECP256K1.generate_keypair(&mut rand::thread_rng());
945            let local_peer_id = pk2id(&pk);
946
947            Self {
948                next_id: 0,
949                _remote_capabilities: Arc::new(Capabilities::from(vec![])),
950                active_session_tx,
951                active_session_rx: ReceiverStream::new(active_session_rx),
952                to_sessions: vec![],
953                hello: eth_hello(&secret_key),
954                secret_key,
955                local_peer_id,
956                status: StatusBuilder::default().build(),
957                fork_filter: MAINNET
958                    .hardfork_fork_filter(EthereumHardfork::Frontier)
959                    .expect("The Frontier fork filter should exist on mainnet"),
960            }
961        }
962    }
963
964    #[tokio::test(flavor = "multi_thread")]
965    async fn test_disconnect() {
966        let mut builder = SessionBuilder::default();
967
968        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
969        let local_addr = listener.local_addr().unwrap();
970
971        let expected_disconnect = DisconnectReason::UselessPeer;
972
973        let fut = builder.with_client_stream(local_addr, move |mut client_stream| async move {
974            let msg = client_stream.next().await.unwrap().unwrap_err();
975            assert_eq!(msg.as_disconnected().unwrap(), expected_disconnect);
976        });
977
978        tokio::task::spawn(async move {
979            let (incoming, _) = listener.accept().await.unwrap();
980            let mut session = builder.connect_incoming(incoming).await;
981
982            session.start_disconnect(expected_disconnect).unwrap();
983            session.await
984        });
985
986        fut.await;
987    }
988
989    #[tokio::test(flavor = "multi_thread")]
990    async fn handle_dropped_stream() {
991        let mut builder = SessionBuilder::default();
992
993        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
994        let local_addr = listener.local_addr().unwrap();
995
996        let fut = builder.with_client_stream(local_addr, move |client_stream| async move {
997            drop(client_stream);
998            tokio::time::sleep(Duration::from_secs(1)).await
999        });
1000
1001        let (tx, rx) = oneshot::channel();
1002
1003        tokio::task::spawn(async move {
1004            let (incoming, _) = listener.accept().await.unwrap();
1005            let session = builder.connect_incoming(incoming).await;
1006            session.await;
1007
1008            tx.send(()).unwrap();
1009        });
1010
1011        tokio::task::spawn(fut);
1012
1013        rx.await.unwrap();
1014    }
1015
1016    #[tokio::test(flavor = "multi_thread")]
1017    async fn test_send_many_messages() {
1018        reth_tracing::init_test_tracing();
1019        let mut builder = SessionBuilder::default();
1020
1021        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1022        let local_addr = listener.local_addr().unwrap();
1023
1024        let num_messages = 100;
1025
1026        let fut = builder.with_client_stream(local_addr, move |mut client_stream| async move {
1027            for _ in 0..num_messages {
1028                client_stream
1029                    .send(EthMessage::NewPooledTransactionHashes66(Vec::new().into()))
1030                    .await
1031                    .unwrap();
1032            }
1033        });
1034
1035        let (tx, rx) = oneshot::channel();
1036
1037        tokio::task::spawn(async move {
1038            let (incoming, _) = listener.accept().await.unwrap();
1039            let session = builder.connect_incoming(incoming).await;
1040            session.await;
1041
1042            tx.send(()).unwrap();
1043        });
1044
1045        tokio::task::spawn(fut);
1046
1047        rx.await.unwrap();
1048    }
1049
1050    #[tokio::test(flavor = "multi_thread")]
1051    async fn test_request_timeout() {
1052        reth_tracing::init_test_tracing();
1053
1054        let mut builder = SessionBuilder::default();
1055
1056        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1057        let local_addr = listener.local_addr().unwrap();
1058
1059        let request_timeout = Duration::from_millis(100);
1060        let drop_timeout = Duration::from_millis(1500);
1061
1062        let fut = builder.with_client_stream(local_addr, move |client_stream| async move {
1063            let _client_stream = client_stream;
1064            tokio::time::sleep(drop_timeout * 60).await;
1065        });
1066        tokio::task::spawn(fut);
1067
1068        let (incoming, _) = listener.accept().await.unwrap();
1069        let mut session = builder.connect_incoming(incoming).await;
1070        session
1071            .internal_request_timeout
1072            .store(request_timeout.as_millis() as u64, Ordering::Relaxed);
1073        session.protocol_breach_request_timeout = drop_timeout;
1074        session.internal_request_timeout_interval =
1075            tokio::time::interval_at(tokio::time::Instant::now(), request_timeout);
1076        let (tx, rx) = oneshot::channel();
1077        let req = PeerRequest::GetBlockBodies { request: GetBlockBodies(vec![]), response: tx };
1078        session.on_internal_peer_request(req, Instant::now());
1079        tokio::spawn(session);
1080
1081        let err = rx.await.unwrap().unwrap_err();
1082        assert_eq!(err, RequestError::Timeout);
1083
1084        // wait for protocol breach error
1085        let msg = builder.active_session_rx.next().await.unwrap();
1086        match msg {
1087            ActiveSessionMessage::ProtocolBreach { .. } => {}
1088            ev => unreachable!("{ev:?}"),
1089        }
1090    }
1091
1092    #[tokio::test(flavor = "multi_thread")]
1093    async fn test_keep_alive() {
1094        let mut builder = SessionBuilder::default();
1095
1096        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1097        let local_addr = listener.local_addr().unwrap();
1098
1099        let fut = builder.with_client_stream(local_addr, move |mut client_stream| async move {
1100            let _ = tokio::time::timeout(Duration::from_secs(5), client_stream.next()).await;
1101            client_stream.into_inner().disconnect(DisconnectReason::UselessPeer).await.unwrap();
1102        });
1103
1104        let (tx, rx) = oneshot::channel();
1105
1106        tokio::task::spawn(async move {
1107            let (incoming, _) = listener.accept().await.unwrap();
1108            let session = builder.connect_incoming(incoming).await;
1109            session.await;
1110
1111            tx.send(()).unwrap();
1112        });
1113
1114        tokio::task::spawn(fut);
1115
1116        rx.await.unwrap();
1117    }
1118
1119    // Tests that a message received while the manager channel is full is delivered once there's capacity again.
1120    #[tokio::test(flavor = "multi_thread")]
1121    async fn test_send_at_capacity() {
1122        let mut builder = SessionBuilder::default();
1123
1124        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1125        let local_addr = listener.local_addr().unwrap();
1126
1127        let fut = builder.with_client_stream(local_addr, move |mut client_stream| async move {
1128            client_stream
1129                .send(EthMessage::NewPooledTransactionHashes68(Default::default()))
1130                .await
1131                .unwrap();
1132            let _ = tokio::time::timeout(Duration::from_secs(100), client_stream.next()).await;
1133        });
1134        tokio::task::spawn(fut);
1135
1136        let (incoming, _) = listener.accept().await.unwrap();
1137        let session = builder.connect_incoming(incoming).await;
1138
1139        // fill the entire message buffer with an unrelated message
1140        let mut num_fill_messages = 0;
1141        loop {
1142            if builder
1143                .active_session_tx
1144                .try_send(ActiveSessionMessage::ProtocolBreach { peer_id: PeerId::random() })
1145                .is_err()
1146            {
1147                break
1148            }
1149            num_fill_messages += 1;
1150        }
1151
1152        tokio::task::spawn(async move {
1153            session.await;
1154        });
1155
1156        tokio::time::sleep(Duration::from_millis(100)).await;
1157
1158        for _ in 0..num_fill_messages {
1159            let message = builder.active_session_rx.next().await.unwrap();
1160            match message {
1161                ActiveSessionMessage::ProtocolBreach { .. } => {}
1162                ev => unreachable!("{ev:?}"),
1163            }
1164        }
1165
1166        let message = builder.active_session_rx.next().await.unwrap();
1167        match message {
1168            ActiveSessionMessage::ValidMessage {
1169                message: PeerMessage::PooledTransactions(_),
1170                ..
1171            } => {}
1172            _ => unreachable!(),
1173        }
1174    }
1175
1176    #[test]
1177    fn timeout_calculation_sanity_tests() {
1178        let rtt = Duration::from_secs(5);
1179        // timeout for an RTT of `rtt`
1180        let timeout = rtt * TIMEOUT_SCALING;
1181
1182        // if rtt hasn't changed, timeout shouldn't change
1183        assert_eq!(calculate_new_timeout(timeout, rtt), timeout);
1184
1185        // if rtt changed, the new timeout should change less than it
1186        assert!(calculate_new_timeout(timeout, rtt / 2) < timeout);
1187        assert!(calculate_new_timeout(timeout, rtt / 2) > timeout / 2);
1188        assert!(calculate_new_timeout(timeout, rtt * 2) > timeout);
1189        assert!(calculate_new_timeout(timeout, rtt * 2) < timeout * 2);
1190    }
1191}