reth_network/session/
active.rs

//! Represents an established session.

use core::sync::atomic::Ordering;
use std::{
    collections::VecDeque,
    future::Future,
    net::SocketAddr,
    pin::Pin,
    sync::{atomic::AtomicU64, Arc},
    task::{ready, Context, Poll},
    time::{Duration, Instant},
};

use crate::{
    message::{NewBlockMessage, PeerMessage, PeerResponse, PeerResponseResult},
    session::{
        conn::EthRlpxConnection,
        handle::{ActiveSessionMessage, SessionCommand},
        SessionId,
    },
};
use alloy_primitives::Sealable;
use futures::{stream::Fuse, SinkExt, StreamExt};
use metrics::Gauge;
use reth_eth_wire::{
    errors::{EthHandshakeError, EthStreamError},
    message::{EthBroadcastMessage, RequestPair},
    Capabilities, DisconnectP2P, DisconnectReason, EthMessage, NetworkPrimitives,
};
use reth_eth_wire_types::RawCapabilityMessage;
use reth_metrics::common::mpsc::MeteredPollSender;
use reth_network_api::PeerRequest;
use reth_network_p2p::error::RequestError;
use reth_network_peers::PeerId;
use reth_network_types::session::config::INITIAL_REQUEST_TIMEOUT;
use reth_primitives_traits::Block;
use rustc_hash::FxHashMap;
use tokio::{
    sync::{mpsc::error::TrySendError, oneshot},
    time::Interval,
};
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::sync::PollSender;
use tracing::{debug, trace};

// Constants for timeout updating.

/// Minimum timeout value
const MINIMUM_TIMEOUT: Duration = Duration::from_secs(2);
/// Maximum timeout value
const MAXIMUM_TIMEOUT: Duration = INITIAL_REQUEST_TIMEOUT;
/// How much a new measurement affects the current timeout (its weight, as a fraction)
const SAMPLE_IMPACT: f64 = 0.1;
/// Amount of RTTs before timeout
const TIMEOUT_SCALING: u32 = 3;
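
// A worked example of how these constants combine in `calculate_new_timeout` below
// (illustrative numbers, not measured values): with a current timeout of 10s and a
// newly measured RTT of 2s, the sample contributes 2s * SAMPLE_IMPACT * TIMEOUT_SCALING
// = 0.6s, which is blended with the old value as 10s * (1.0 - SAMPLE_IMPACT) + 0.6s
// = 9.6s, and the result is finally clamped to [MINIMUM_TIMEOUT, MAXIMUM_TIMEOUT].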

/// Restricts the number of queued outgoing messages for larger responses:
///  - Block Bodies
///  - Receipts
///  - Headers
///  - `PooledTransactions`
///
/// With proper soft limits in place (2MB) this targets 10MB ((4 + 1) * 2MB) of outgoing response
/// data.
///
/// This parameter serves as backpressure for reading additional requests from the remote.
/// Once we've queued up more responses than this, the session should prioritize message flushing
/// before reading any more messages from the remote peer, throttling the peer.
const MAX_QUEUED_OUTGOING_RESPONSES: usize = 4;

/// The type that advances an established session by listening for incoming messages (from local
/// node or read from connection) and emitting events back to the
/// [`SessionManager`](super::SessionManager).
///
/// It listens for
///    - incoming commands from the [`SessionManager`](super::SessionManager)
///    - incoming _internal_ requests/broadcasts via the request/command channel
///    - incoming requests/broadcasts _from remote_ via the connection
///    - responses for handled ETH requests received from the remote peer.
#[expect(dead_code)]
pub(crate) struct ActiveSession<N: NetworkPrimitives> {
    /// Keeps track of request ids.
    pub(crate) next_id: u64,
    /// The underlying connection.
    pub(crate) conn: EthRlpxConnection<N>,
    /// Identifier of the node we're connected to.
    pub(crate) remote_peer_id: PeerId,
    /// The address we're connected to.
    pub(crate) remote_addr: SocketAddr,
    /// All capabilities the peer announced
    pub(crate) remote_capabilities: Arc<Capabilities>,
    /// Internal identifier of this session
    pub(crate) session_id: SessionId,
    /// Incoming commands from the manager
    pub(crate) commands_rx: ReceiverStream<SessionCommand<N>>,
    /// Sink to send messages to the [`SessionManager`](super::SessionManager).
    pub(crate) to_session_manager: MeteredPollSender<ActiveSessionMessage<N>>,
    /// A message that needs to be delivered to the session manager
    pub(crate) pending_message_to_session: Option<ActiveSessionMessage<N>>,
    /// Incoming internal requests which are delegated to the remote peer.
    pub(crate) internal_request_rx: Fuse<ReceiverStream<PeerRequest<N>>>,
    /// All requests sent to the remote peer for which we're awaiting a response
    pub(crate) inflight_requests: FxHashMap<u64, InflightRequest<PeerRequest<N>>>,
    /// All requests received from the remote peer for which we're awaiting an internal response
    pub(crate) received_requests_from_remote: Vec<ReceivedRequest<N>>,
    /// Buffered messages that should be handled and sent to the peer.
    pub(crate) queued_outgoing: QueuedOutgoingMessages<N>,
    /// The maximum time we wait for a response from a peer.
    pub(crate) internal_request_timeout: Arc<AtomicU64>,
    /// Interval when to check for timed out requests.
    pub(crate) internal_request_timeout_interval: Interval,
    /// If an [`ActiveSession`] does not receive a response at all within this duration then it is
    /// considered a protocol violation and the session will initiate a drop.
    pub(crate) protocol_breach_request_timeout: Duration,
    /// Used to reserve a slot to guarantee that the termination message is delivered
    pub(crate) terminate_message:
        Option<(PollSender<ActiveSessionMessage<N>>, ActiveSessionMessage<N>)>,
}

impl<N: NetworkPrimitives> ActiveSession<N> {
    /// Returns `true` if the session is currently in the process of disconnecting
    fn is_disconnecting(&self) -> bool {
        self.conn.inner().is_disconnecting()
    }

    /// Returns the next request id
    const fn next_id(&mut self) -> u64 {
        let id = self.next_id;
        self.next_id += 1;
        id
    }

    /// Shrinks the capacity of the internal buffers.
    pub fn shrink_to_fit(&mut self) {
        self.received_requests_from_remote.shrink_to_fit();
        self.queued_outgoing.shrink_to_fit();
    }

    /// Returns how many responses we've currently queued up.
    fn queued_response_count(&self) -> usize {
        self.queued_outgoing.messages.iter().filter(|m| m.is_response()).count()
    }

    /// Handle a message read from the connection.
    ///
    /// Returns an error if the message is considered to be in violation of the protocol.
    fn on_incoming_message(&mut self, msg: EthMessage<N>) -> OnIncomingMessageOutcome<N> {
        /// A macro that handles an incoming request
        /// This creates a new channel and tries to send the sender half to the session while
        /// storing the receiver half internally so the pending response can be polled.
        macro_rules! on_request {
            ($req:ident, $resp_item:ident, $req_item:ident) => {{
                let RequestPair { request_id, message: request } = $req;
                let (tx, response) = oneshot::channel();
                let received = ReceivedRequest {
                    request_id,
                    rx: PeerResponse::$resp_item { response },
                    received: Instant::now(),
                };
                self.received_requests_from_remote.push(received);
                self.try_emit_request(PeerMessage::EthRequest(PeerRequest::$req_item {
                    request,
                    response: tx,
                }))
                .into()
            }};
        }

        /// Processes a response received from the peer
        macro_rules! on_response {
            ($resp:ident, $item:ident) => {{
                let RequestPair { request_id, message } = $resp;
                if let Some(req) = self.inflight_requests.remove(&request_id) {
                    match req.request {
                        RequestState::Waiting(PeerRequest::$item { response, .. }) => {
                            let _ = response.send(Ok(message));
                            self.update_request_timeout(req.timestamp, Instant::now());
                        }
                        RequestState::Waiting(request) => {
                            request.send_bad_response();
                        }
                        RequestState::TimedOut => {
                            // request was already timed out internally
                            self.update_request_timeout(req.timestamp, Instant::now());
                        }
                    }
                } else {
                    // we received a response to a request we never sent
                    self.on_bad_message();
                }

                OnIncomingMessageOutcome::Ok
            }};
        }

        match msg {
            message @ EthMessage::Status(_) => OnIncomingMessageOutcome::BadMessage {
                error: EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake),
                message,
            },
            EthMessage::NewBlockHashes(msg) => {
                self.try_emit_broadcast(PeerMessage::NewBlockHashes(msg)).into()
            }
            EthMessage::NewBlock(msg) => {
                let block =
                    NewBlockMessage { hash: msg.block.header().hash_slow(), block: Arc::new(*msg) };
                self.try_emit_broadcast(PeerMessage::NewBlock(block)).into()
            }
            EthMessage::Transactions(msg) => {
                self.try_emit_broadcast(PeerMessage::ReceivedTransaction(msg)).into()
            }
            EthMessage::NewPooledTransactionHashes66(msg) => {
                self.try_emit_broadcast(PeerMessage::PooledTransactions(msg.into())).into()
            }
            EthMessage::NewPooledTransactionHashes68(msg) => {
                if msg.hashes.len() != msg.types.len() || msg.hashes.len() != msg.sizes.len() {
                    return OnIncomingMessageOutcome::BadMessage {
                        error: EthStreamError::TransactionHashesInvalidLenOfFields {
                            hashes_len: msg.hashes.len(),
                            types_len: msg.types.len(),
                            sizes_len: msg.sizes.len(),
                        },
                        message: EthMessage::NewPooledTransactionHashes68(msg),
                    }
                }
                self.try_emit_broadcast(PeerMessage::PooledTransactions(msg.into())).into()
            }
            EthMessage::GetBlockHeaders(req) => {
                on_request!(req, BlockHeaders, GetBlockHeaders)
            }
            EthMessage::BlockHeaders(resp) => {
                on_response!(resp, GetBlockHeaders)
            }
            EthMessage::GetBlockBodies(req) => {
                on_request!(req, BlockBodies, GetBlockBodies)
            }
            EthMessage::BlockBodies(resp) => {
                on_response!(resp, GetBlockBodies)
            }
            EthMessage::GetPooledTransactions(req) => {
                on_request!(req, PooledTransactions, GetPooledTransactions)
            }
            EthMessage::PooledTransactions(resp) => {
                on_response!(resp, GetPooledTransactions)
            }
            EthMessage::GetNodeData(req) => {
                on_request!(req, NodeData, GetNodeData)
            }
            EthMessage::NodeData(resp) => {
                on_response!(resp, GetNodeData)
            }
            EthMessage::GetReceipts(req) => {
                on_request!(req, Receipts, GetReceipts)
            }
            EthMessage::Receipts(resp) => {
                on_response!(resp, GetReceipts)
            }
            EthMessage::Receipts69(resp) => {
                // TODO: remove mandatory blooms
                let resp = resp.map(|receipts| receipts.into_with_bloom());
                on_response!(resp, GetReceipts)
            }
            EthMessage::BlockRangeUpdate(msg) => {
                self.try_emit_broadcast(PeerMessage::BlockRangeUpdated(msg)).into()
            }
            EthMessage::Other(bytes) => self.try_emit_broadcast(PeerMessage::Other(bytes)).into(),
        }
    }

    /// Handle an internal peer request that will be sent to the remote.
    fn on_internal_peer_request(&mut self, request: PeerRequest<N>, deadline: Instant) {
        let request_id = self.next_id();
        let msg = request.create_request_message(request_id);
        self.queued_outgoing.push_back(msg.into());
        let req = InflightRequest {
            request: RequestState::Waiting(request),
            timestamp: Instant::now(),
            deadline,
        };
        self.inflight_requests.insert(request_id, req);
    }

    /// Handle a message received from the internal network
    fn on_internal_peer_message(&mut self, msg: PeerMessage<N>) {
        match msg {
            PeerMessage::NewBlockHashes(msg) => {
                self.queued_outgoing.push_back(EthMessage::NewBlockHashes(msg).into());
            }
            PeerMessage::NewBlock(msg) => {
                self.queued_outgoing.push_back(EthBroadcastMessage::NewBlock(msg.block).into());
            }
            PeerMessage::PooledTransactions(msg) => {
                if msg.is_valid_for_version(self.conn.version()) {
                    self.queued_outgoing.push_back(EthMessage::from(msg).into());
                }
            }
            PeerMessage::EthRequest(req) => {
                let deadline = self.request_deadline();
                self.on_internal_peer_request(req, deadline);
            }
            PeerMessage::SendTransactions(msg) => {
                self.queued_outgoing.push_back(EthBroadcastMessage::Transactions(msg).into());
            }
            PeerMessage::BlockRangeUpdated(_) => {}
            PeerMessage::ReceivedTransaction(_) => {
                unreachable!("Not emitted by network")
            }
            PeerMessage::Other(other) => {
                self.queued_outgoing.push_back(OutgoingMessage::Raw(other));
            }
        }
    }

    /// Returns the deadline timestamp at which the request times out
    fn request_deadline(&self) -> Instant {
        Instant::now() +
            Duration::from_millis(self.internal_request_timeout.load(Ordering::Relaxed))
    }

    /// Handle a Response to the peer
    ///
    /// This will queue the response to be sent to the peer
    fn handle_outgoing_response(&mut self, id: u64, resp: PeerResponseResult<N>) {
        match resp.try_into_message(id) {
            Ok(msg) => {
                self.queued_outgoing.push_back(msg.into());
            }
            Err(err) => {
                debug!(target: "net", %err, "Failed to respond to received request");
            }
        }
    }

    /// Send a message back to the [`SessionManager`](super::SessionManager).
    ///
    /// Returns the message if the bounded channel is currently unable to handle this message.
    #[expect(clippy::result_large_err)]
    fn try_emit_broadcast(&self, message: PeerMessage<N>) -> Result<(), ActiveSessionMessage<N>> {
        let Some(sender) = self.to_session_manager.inner().get_ref() else { return Ok(()) };

        match sender
            .try_send(ActiveSessionMessage::ValidMessage { peer_id: self.remote_peer_id, message })
        {
            Ok(_) => Ok(()),
            Err(err) => {
                trace!(
                    target: "net",
                    %err,
                    "no capacity for incoming broadcast",
                );
                match err {
                    TrySendError::Full(msg) => Err(msg),
                    TrySendError::Closed(_) => Ok(()),
                }
            }
        }
    }

    /// Send a message back to the [`SessionManager`](super::SessionManager)
    /// covering both broadcasts and incoming requests.
    ///
    /// Returns the message if the bounded channel is currently unable to handle this message.
    #[expect(clippy::result_large_err)]
    fn try_emit_request(&self, message: PeerMessage<N>) -> Result<(), ActiveSessionMessage<N>> {
        let Some(sender) = self.to_session_manager.inner().get_ref() else { return Ok(()) };

        match sender
            .try_send(ActiveSessionMessage::ValidMessage { peer_id: self.remote_peer_id, message })
        {
            Ok(_) => Ok(()),
            Err(err) => {
                trace!(
                    target: "net",
                    %err,
                    "no capacity for incoming request",
                );
                match err {
                    TrySendError::Full(msg) => Err(msg),
                    TrySendError::Closed(_) => {
                        // Note: this would mean the `SessionManager` was dropped, which is already
                        // handled by checking if the command receiver channel has been closed.
                        Ok(())
                    }
                }
            }
        }
    }

    /// Notify the manager that the peer sent a bad message
    fn on_bad_message(&self) {
        let Some(sender) = self.to_session_manager.inner().get_ref() else { return };
        let _ = sender.try_send(ActiveSessionMessage::BadMessage { peer_id: self.remote_peer_id });
    }

    /// Report back that this session has been closed.
    fn emit_disconnect(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        trace!(target: "net::session", remote_peer_id=?self.remote_peer_id, "emitting disconnect");
        let msg = ActiveSessionMessage::Disconnected {
            peer_id: self.remote_peer_id,
            remote_addr: self.remote_addr,
        };

        self.terminate_message = Some((self.to_session_manager.inner().clone(), msg));
        self.poll_terminate_message(cx).expect("message is set")
    }

    /// Report back that this session has been closed due to an error
    fn close_on_error(&mut self, error: EthStreamError, cx: &mut Context<'_>) -> Poll<()> {
        let msg = ActiveSessionMessage::ClosedOnConnectionError {
            peer_id: self.remote_peer_id,
            remote_addr: self.remote_addr,
            error,
        };
        self.terminate_message = Some((self.to_session_manager.inner().clone(), msg));
        self.poll_terminate_message(cx).expect("message is set")
    }

    /// Starts the disconnect process
    fn start_disconnect(&mut self, reason: DisconnectReason) -> Result<(), EthStreamError> {
        Ok(self.conn.inner_mut().start_disconnect(reason)?)
    }

    /// Flushes the disconnect message and emits the corresponding message
    fn poll_disconnect(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        debug_assert!(self.is_disconnecting(), "not disconnecting");

        // try to close the connection and flush out the remaining Disconnect message
        let _ = ready!(self.conn.poll_close_unpin(cx));
        self.emit_disconnect(cx)
    }

    /// Attempts to disconnect by sending the given disconnect reason
    fn try_disconnect(&mut self, reason: DisconnectReason, cx: &mut Context<'_>) -> Poll<()> {
        match self.start_disconnect(reason) {
            Ok(()) => {
                // we're done
                self.poll_disconnect(cx)
            }
            Err(err) => {
                debug!(target: "net::session", %err, remote_peer_id=?self.remote_peer_id, "could not send disconnect");
                self.close_on_error(err, cx)
            }
        }
    }

    /// Checks for _internally_ timed out requests.
    ///
    /// If a request misses its deadline, then it is timed out internally.
    /// If a request misses the `protocol_breach_request_timeout` then this session is considered in
    /// protocol violation and will close.
    ///
    /// Returns `true` if a peer missed the `protocol_breach_request_timeout`, in which case the
    /// session should be terminated.
    #[must_use]
    fn check_timed_out_requests(&mut self, now: Instant) -> bool {
        for (id, req) in &mut self.inflight_requests {
            if req.is_timed_out(now) {
                if req.is_waiting() {
                    debug!(target: "net::session", ?id, remote_peer_id=?self.remote_peer_id, "timed out outgoing request");
                    req.timeout();
                } else if now - req.timestamp > self.protocol_breach_request_timeout {
                    return true
                }
            }
        }

        false
    }

    /// Updates the request timeout with a request's timestamps
    fn update_request_timeout(&mut self, sent: Instant, received: Instant) {
        let elapsed = received.saturating_duration_since(sent);

        let current = Duration::from_millis(self.internal_request_timeout.load(Ordering::Relaxed));
        let request_timeout = calculate_new_timeout(current, elapsed);
        self.internal_request_timeout.store(request_timeout.as_millis() as u64, Ordering::Relaxed);
        self.internal_request_timeout_interval = tokio::time::interval(request_timeout);
    }

    /// If a termination message is queued this will try to send it
    fn poll_terminate_message(&mut self, cx: &mut Context<'_>) -> Option<Poll<()>> {
        let (mut tx, msg) = self.terminate_message.take()?;
        match tx.poll_reserve(cx) {
            Poll::Pending => {
                self.terminate_message = Some((tx, msg));
                return Some(Poll::Pending)
            }
            Poll::Ready(Ok(())) => {
                let _ = tx.send_item(msg);
            }
            Poll::Ready(Err(_)) => {
                // channel closed
            }
        }
        // terminate the task
        Some(Poll::Ready(()))
    }
}

impl<N: NetworkPrimitives> Future for ActiveSession<N> {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();

        // if the session is terminating we have to send the termination message before we can close
        if let Some(terminate) = this.poll_terminate_message(cx) {
            return terminate
        }

        if this.is_disconnecting() {
            return this.poll_disconnect(cx)
        }

        // The receive loop can be CPU intensive since it involves message decoding which could take
        // up a lot of resources and increase latencies for other sessions if not yielded manually.
        // If the budget is exhausted we manually yield back control to the (coop) scheduler. This
        // manual yield point should prevent situations where polling appears to be frozen. See also <https://tokio.rs/blog/2020-04-preemption>
        // And tokio's docs on cooperative scheduling <https://docs.rs/tokio/latest/tokio/task/#cooperative-scheduling>
        let mut budget = 4;

        // The main poll loop that drives the session
        'main: loop {
            let mut progress = false;

            // we prioritize incoming commands sent from the session manager
            loop {
                match this.commands_rx.poll_next_unpin(cx) {
                    Poll::Pending => break,
                    Poll::Ready(None) => {
                        // this is only possible when the manager was dropped, in which case we also
                        // terminate this session
                        return Poll::Ready(())
                    }
                    Poll::Ready(Some(cmd)) => {
                        progress = true;
                        match cmd {
                            SessionCommand::Disconnect { reason } => {
                                debug!(
                                    target: "net::session",
                                    ?reason,
                                    remote_peer_id=?this.remote_peer_id,
                                    "Received disconnect command for session"
                                );
                                let reason =
                                    reason.unwrap_or(DisconnectReason::DisconnectRequested);

                                return this.try_disconnect(reason, cx)
                            }
                            SessionCommand::Message(msg) => {
                                this.on_internal_peer_message(msg);
                            }
                        }
                    }
                }
            }

            let deadline = this.request_deadline();

            while let Poll::Ready(Some(req)) = this.internal_request_rx.poll_next_unpin(cx) {
                progress = true;
                this.on_internal_peer_request(req, deadline);
            }

            // Advance all active requests.
            // We remove each request one by one and add them back.
            for idx in (0..this.received_requests_from_remote.len()).rev() {
                let mut req = this.received_requests_from_remote.swap_remove(idx);
                match req.rx.poll(cx) {
                    Poll::Pending => {
                        // not ready yet
                        this.received_requests_from_remote.push(req);
                    }
                    Poll::Ready(resp) => {
                        this.handle_outgoing_response(req.request_id, resp);
                    }
                }
            }

            // Send messages by advancing the sink and queuing in buffered messages
            while this.conn.poll_ready_unpin(cx).is_ready() {
                if let Some(msg) = this.queued_outgoing.pop_front() {
                    progress = true;
                    let res = match msg {
                        OutgoingMessage::Eth(msg) => this.conn.start_send_unpin(msg),
                        OutgoingMessage::Broadcast(msg) => this.conn.start_send_broadcast(msg),
                        OutgoingMessage::Raw(msg) => this.conn.start_send_raw(msg),
                    };
                    if let Err(err) = res {
                        debug!(target: "net::session", %err, remote_peer_id=?this.remote_peer_id, "failed to send message");
                        // notify the manager
                        return this.close_on_error(err, cx)
                    }
                } else {
                    // no more messages to send over the wire
                    break
                }
            }

            // read incoming messages from the wire
            'receive: loop {
                // ensure we still have enough budget for another iteration
                budget -= 1;
                if budget == 0 {
                    // make sure we're woken up again
                    cx.waker().wake_by_ref();
                    break 'main
                }

                // try to resend the pending message that we could not send because the channel was
                // full. [`PollSender`] will ensure that we're woken up again when the channel is
                // ready to receive the message, and will only error if the channel is closed.
                if let Some(msg) = this.pending_message_to_session.take() {
                    match this.to_session_manager.poll_reserve(cx) {
                        Poll::Ready(Ok(_)) => {
                            let _ = this.to_session_manager.send_item(msg);
                        }
                        Poll::Ready(Err(_)) => return Poll::Ready(()),
                        Poll::Pending => {
                            this.pending_message_to_session = Some(msg);
                            break 'receive
                        }
                    };
                }

                // check whether we should throttle incoming messages
                if this.received_requests_from_remote.len() > MAX_QUEUED_OUTGOING_RESPONSES {
                    // we're currently waiting for the responses to the peer's requests which aren't
                    // queued as outgoing yet
                    //
                    // Note: we don't need to register the waker here because we polled the requests
                    // above
                    break 'receive
                }

                // we also need to check if we have multiple responses queued up
                if this.queued_outgoing.messages.len() > MAX_QUEUED_OUTGOING_RESPONSES &&
                    this.queued_response_count() > MAX_QUEUED_OUTGOING_RESPONSES
                {
                    // if we've queued up more responses than allowed, we don't poll for new
                    // messages and break the receive loop early
                    //
                    // Note: we don't need to register the waker here because we still have
                    // queued messages and the sink impl registered the waker because we've
                    // already advanced it to `Pending` earlier
                    break 'receive
                }

                match this.conn.poll_next_unpin(cx) {
                    Poll::Pending => break,
                    Poll::Ready(None) => {
                        if this.is_disconnecting() {
                            break
                        }
                        debug!(target: "net::session", remote_peer_id=?this.remote_peer_id, "eth stream completed");
                        return this.emit_disconnect(cx)
                    }
                    Poll::Ready(Some(res)) => {
                        match res {
                            Ok(msg) => {
                                trace!(target: "net::session", msg_id=?msg.message_id(), remote_peer_id=?this.remote_peer_id, "received eth message");
                                // decode and handle message
                                match this.on_incoming_message(msg) {
                                    OnIncomingMessageOutcome::Ok => {
                                        // handled successfully
                                        progress = true;
                                    }
                                    OnIncomingMessageOutcome::BadMessage { error, message } => {
                                        debug!(target: "net::session", %error, msg=?message, remote_peer_id=?this.remote_peer_id, "received invalid protocol message");
                                        return this.close_on_error(error, cx)
                                    }
                                    OnIncomingMessageOutcome::NoCapacity(msg) => {
                                        // failed to send due to lack of capacity
                                        this.pending_message_to_session = Some(msg);
                                    }
                                }
                            }
                            Err(err) => {
                                debug!(target: "net::session", %err, remote_peer_id=?this.remote_peer_id, "failed to receive message");
                                return this.close_on_error(err, cx)
                            }
                        }
                    }
                }
            }

            if !progress {
                break 'main
            }
        }

        while this.internal_request_timeout_interval.poll_tick(cx).is_ready() {
            // check for timed out requests
            if this.check_timed_out_requests(Instant::now()) {
                if let Poll::Ready(Ok(_)) = this.to_session_manager.poll_reserve(cx) {
                    let msg = ActiveSessionMessage::ProtocolBreach { peer_id: this.remote_peer_id };
                    this.pending_message_to_session = Some(msg);
                }
            }
        }

        this.shrink_to_fit();

        Poll::Pending
    }
}

/// Tracks a request received from the peer
pub(crate) struct ReceivedRequest<N: NetworkPrimitives> {
    /// Protocol Identifier
    request_id: u64,
    /// Receiver half of the channel that's supposed to receive the proper response.
    rx: PeerResponse<N>,
    /// Timestamp when we read this msg from the wire.
    #[expect(dead_code)]
    received: Instant,
}

/// A request that waits for a response from the peer
pub(crate) struct InflightRequest<R> {
    /// Request we sent to peer and the internal response channel
    request: RequestState<R>,
    /// Instant when the request was sent
    timestamp: Instant,
    /// Time limit for the response
    deadline: Instant,
}

// === impl InflightRequest ===

impl<N: NetworkPrimitives> InflightRequest<PeerRequest<N>> {
    /// Returns true if the request is timed out
    #[inline]
    fn is_timed_out(&self, now: Instant) -> bool {
        now > self.deadline
    }

    /// Returns true if we're still waiting for a response
    #[inline]
    const fn is_waiting(&self) -> bool {
        matches!(self.request, RequestState::Waiting(_))
    }

    /// This will timeout the request by sending an error response to the internal channel
    fn timeout(&mut self) {
        let mut req = RequestState::TimedOut;
        std::mem::swap(&mut self.request, &mut req);

        if let RequestState::Waiting(req) = req {
            req.send_err_response(RequestError::Timeout);
        }
    }
}

/// All outcome variants when handling an incoming message
enum OnIncomingMessageOutcome<N: NetworkPrimitives> {
    /// Message successfully handled.
    Ok,
    /// Message is considered to be in violation of the protocol
    BadMessage { error: EthStreamError, message: EthMessage<N> },
    /// Currently no capacity to handle the message
    NoCapacity(ActiveSessionMessage<N>),
}

impl<N: NetworkPrimitives> From<Result<(), ActiveSessionMessage<N>>>
    for OnIncomingMessageOutcome<N>
{
    fn from(res: Result<(), ActiveSessionMessage<N>>) -> Self {
        match res {
            Ok(_) => Self::Ok,
            Err(msg) => Self::NoCapacity(msg),
        }
    }
}

enum RequestState<R> {
    /// Waiting for the response
    Waiting(R),
    /// Request already timed out
    TimedOut,
}

/// Outgoing messages that can be sent over the wire.
pub(crate) enum OutgoingMessage<N: NetworkPrimitives> {
    /// A message that is owned.
    Eth(EthMessage<N>),
    /// A message that may be shared by multiple sessions.
    Broadcast(EthBroadcastMessage<N>),
    /// A raw capability message
    Raw(RawCapabilityMessage),
}

impl<N: NetworkPrimitives> OutgoingMessage<N> {
    /// Returns true if this is a response.
    const fn is_response(&self) -> bool {
        match self {
            Self::Eth(msg) => msg.is_response(),
            _ => false,
        }
    }
}

impl<N: NetworkPrimitives> From<EthMessage<N>> for OutgoingMessage<N> {
    fn from(value: EthMessage<N>) -> Self {
        Self::Eth(value)
    }
}

impl<N: NetworkPrimitives> From<EthBroadcastMessage<N>> for OutgoingMessage<N> {
    fn from(value: EthBroadcastMessage<N>) -> Self {
        Self::Broadcast(value)
    }
}

/// Calculates a new timeout using an updated estimation of the RTT
#[inline]
fn calculate_new_timeout(current_timeout: Duration, estimated_rtt: Duration) -> Duration {
    let new_timeout = estimated_rtt.mul_f64(SAMPLE_IMPACT) * TIMEOUT_SCALING;

    // this dampens sudden changes by taking a weighted mean of the old and new values
    let smoothened_timeout = current_timeout.mul_f64(1.0 - SAMPLE_IMPACT) + new_timeout;

    smoothened_timeout.clamp(MINIMUM_TIMEOUT, MAXIMUM_TIMEOUT)
}

/// A helper struct that wraps the queue of outgoing messages and a metric to track their count
pub(crate) struct QueuedOutgoingMessages<N: NetworkPrimitives> {
    messages: VecDeque<OutgoingMessage<N>>,
    count: Gauge,
}

impl<N: NetworkPrimitives> QueuedOutgoingMessages<N> {
    pub(crate) const fn new(metric: Gauge) -> Self {
        Self { messages: VecDeque::new(), count: metric }
    }

    pub(crate) fn push_back(&mut self, message: OutgoingMessage<N>) {
        self.messages.push_back(message);
        self.count.increment(1);
    }

    pub(crate) fn pop_front(&mut self) -> Option<OutgoingMessage<N>> {
        self.messages.pop_front().inspect(|_| self.count.decrement(1))
    }

    pub(crate) fn shrink_to_fit(&mut self) {
        self.messages.shrink_to_fit();
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::session::{handle::PendingSessionEvent, start_pending_incoming_session};
    use alloy_eips::eip2124::ForkFilter;
    use reth_chainspec::MAINNET;
    use reth_ecies::stream::ECIESStream;
    use reth_eth_wire::{
        handshake::EthHandshake, EthNetworkPrimitives, EthStream, GetBlockBodies,
        HelloMessageWithProtocols, P2PStream, StatusBuilder, UnauthedEthStream, UnauthedP2PStream,
        UnifiedStatus,
    };
    use reth_ethereum_forks::EthereumHardfork;
    use reth_network_peers::pk2id;
    use reth_network_types::session::config::PROTOCOL_BREACH_REQUEST_TIMEOUT;
    use secp256k1::{SecretKey, SECP256K1};
    use tokio::{
        net::{TcpListener, TcpStream},
        sync::mpsc,
    };

    /// Returns a testing `HelloMessage` built from the given secret key
    fn eth_hello(server_key: &SecretKey) -> HelloMessageWithProtocols {
        HelloMessageWithProtocols::builder(pk2id(&server_key.public_key(SECP256K1))).build()
    }

    struct SessionBuilder<N: NetworkPrimitives = EthNetworkPrimitives> {
        _remote_capabilities: Arc<Capabilities>,
        active_session_tx: mpsc::Sender<ActiveSessionMessage<N>>,
        active_session_rx: ReceiverStream<ActiveSessionMessage<N>>,
        to_sessions: Vec<mpsc::Sender<SessionCommand<N>>>,
        secret_key: SecretKey,
        local_peer_id: PeerId,
        hello: HelloMessageWithProtocols,
        status: UnifiedStatus,
        fork_filter: ForkFilter,
        next_id: usize,
    }

    impl<N: NetworkPrimitives> SessionBuilder<N> {
        fn next_id(&mut self) -> SessionId {
            let id = self.next_id;
            self.next_id += 1;
            SessionId(id)
        }

        /// Connects a new Eth stream and executes the given closure with that established stream
        fn with_client_stream<F, O>(
            &self,
            local_addr: SocketAddr,
            f: F,
        ) -> Pin<Box<dyn Future<Output = ()> + Send>>
        where
            F: FnOnce(EthStream<P2PStream<ECIESStream<TcpStream>>, N>) -> O + Send + 'static,
            O: Future<Output = ()> + Send + Sync,
        {
            let status = self.status;
            let fork_filter = self.fork_filter.clone();
            let local_peer_id = self.local_peer_id;
            let mut hello = self.hello.clone();
            let key = SecretKey::new(&mut rand_08::thread_rng());
            hello.id = pk2id(&key.public_key(SECP256K1));
            Box::pin(async move {
                let outgoing = TcpStream::connect(local_addr).await.unwrap();
                let sink = ECIESStream::connect(outgoing, key, local_peer_id).await.unwrap();

                let (p2p_stream, _) = UnauthedP2PStream::new(sink).handshake(hello).await.unwrap();

                let (client_stream, _) = UnauthedEthStream::new(p2p_stream)
                    .handshake(status, fork_filter)
                    .await
                    .unwrap();
                f(client_stream).await
            })
        }

        async fn connect_incoming(&mut self, stream: TcpStream) -> ActiveSession<N> {
            let remote_addr = stream.local_addr().unwrap();
            let session_id = self.next_id();
            let (_disconnect_tx, disconnect_rx) = oneshot::channel();
            let (pending_sessions_tx, pending_sessions_rx) = mpsc::channel(1);

            tokio::task::spawn(start_pending_incoming_session(
                Arc::new(EthHandshake::default()),
                disconnect_rx,
                session_id,
                stream,
                pending_sessions_tx,
                remote_addr,
                self.secret_key,
                self.hello.clone(),
                self.status,
                self.fork_filter.clone(),
                Default::default(),
            ));

            let mut stream = ReceiverStream::new(pending_sessions_rx);

            match stream.next().await.unwrap() {
                PendingSessionEvent::Established {
                    session_id,
                    remote_addr,
                    peer_id,
                    capabilities,
                    conn,
                    ..
                } => {
                    let (_to_session_tx, messages_rx) = mpsc::channel(10);
                    let (commands_to_session, commands_rx) = mpsc::channel(10);
                    let poll_sender = PollSender::new(self.active_session_tx.clone());

                    self.to_sessions.push(commands_to_session);

                    ActiveSession {
                        next_id: 0,
                        remote_peer_id: peer_id,
                        remote_addr,
                        remote_capabilities: Arc::clone(&capabilities),
                        session_id,
                        commands_rx: ReceiverStream::new(commands_rx),
                        to_session_manager: MeteredPollSender::new(
                            poll_sender,
                            "network_active_session",
                        ),
                        pending_message_to_session: None,
                        internal_request_rx: ReceiverStream::new(messages_rx).fuse(),
                        inflight_requests: Default::default(),
                        conn,
                        queued_outgoing: QueuedOutgoingMessages::new(Gauge::noop()),
                        received_requests_from_remote: Default::default(),
                        internal_request_timeout_interval: tokio::time::interval(
                            INITIAL_REQUEST_TIMEOUT,
                        ),
                        internal_request_timeout: Arc::new(AtomicU64::new(
                            INITIAL_REQUEST_TIMEOUT.as_millis() as u64,
                        )),
                        protocol_breach_request_timeout: PROTOCOL_BREACH_REQUEST_TIMEOUT,
                        terminate_message: None,
                    }
                }
                ev => {
                    panic!("unexpected message {ev:?}")
                }
            }
        }
    }

    impl Default for SessionBuilder {
        fn default() -> Self {
            let (active_session_tx, active_session_rx) = mpsc::channel(100);

            let (secret_key, pk) = SECP256K1.generate_keypair(&mut rand_08::thread_rng());
            let local_peer_id = pk2id(&pk);

            Self {
                next_id: 0,
                _remote_capabilities: Arc::new(Capabilities::from(vec![])),
                active_session_tx,
                active_session_rx: ReceiverStream::new(active_session_rx),
                to_sessions: vec![],
                hello: eth_hello(&secret_key),
                secret_key,
                local_peer_id,
                status: StatusBuilder::default().build(),
                fork_filter: MAINNET
                    .hardfork_fork_filter(EthereumHardfork::Frontier)
                    .expect("The Frontier fork filter should exist on mainnet"),
            }
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_disconnect() {
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let expected_disconnect = DisconnectReason::UselessPeer;

        let fut = builder.with_client_stream(local_addr, move |mut client_stream| async move {
            let msg = client_stream.next().await.unwrap().unwrap_err();
            assert_eq!(msg.as_disconnected().unwrap(), expected_disconnect);
        });

        tokio::task::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let mut session = builder.connect_incoming(incoming).await;

            session.start_disconnect(expected_disconnect).unwrap();
            session.await
        });

        fut.await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn handle_dropped_stream() {
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let fut = builder.with_client_stream(local_addr, move |client_stream| async move {
            drop(client_stream);
            tokio::time::sleep(Duration::from_secs(1)).await
        });

        let (tx, rx) = oneshot::channel();

        tokio::task::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let session = builder.connect_incoming(incoming).await;
            session.await;

            tx.send(()).unwrap();
        });

        tokio::task::spawn(fut);

        rx.await.unwrap();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_send_many_messages() {
        reth_tracing::init_test_tracing();
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let num_messages = 100;

        let fut = builder.with_client_stream(local_addr, move |mut client_stream| async move {
            for _ in 0..num_messages {
                client_stream
                    .send(EthMessage::NewPooledTransactionHashes66(Vec::new().into()))
                    .await
                    .unwrap();
            }
        });

        let (tx, rx) = oneshot::channel();

        tokio::task::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let session = builder.connect_incoming(incoming).await;
            session.await;

            tx.send(()).unwrap();
        });

        tokio::task::spawn(fut);

        rx.await.unwrap();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_request_timeout() {
        reth_tracing::init_test_tracing();

        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let request_timeout = Duration::from_millis(100);
        let drop_timeout = Duration::from_millis(1500);

        let fut = builder.with_client_stream(local_addr, move |client_stream| async move {
            let _client_stream = client_stream;
            tokio::time::sleep(drop_timeout * 60).await;
        });
        tokio::task::spawn(fut);

        let (incoming, _) = listener.accept().await.unwrap();
        let mut session = builder.connect_incoming(incoming).await;
        session
            .internal_request_timeout
            .store(request_timeout.as_millis() as u64, Ordering::Relaxed);
        session.protocol_breach_request_timeout = drop_timeout;
        session.internal_request_timeout_interval =
            tokio::time::interval_at(tokio::time::Instant::now(), request_timeout);
        let (tx, rx) = oneshot::channel();
        let req = PeerRequest::GetBlockBodies { request: GetBlockBodies(vec![]), response: tx };
        session.on_internal_peer_request(req, Instant::now());
        tokio::spawn(session);

        let err = rx.await.unwrap().unwrap_err();
        assert_eq!(err, RequestError::Timeout);

        // wait for protocol breach error
        let msg = builder.active_session_rx.next().await.unwrap();
        match msg {
            ActiveSessionMessage::ProtocolBreach { .. } => {}
            ev => unreachable!("{ev:?}"),
        }
    }
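
    // A minimal sketch (added for illustration, using only items defined in this file)
    // of the `InflightRequest` state machine exercised by the integration test above:
    // timing out a waiting request delivers `RequestError::Timeout` on the internal
    // response channel and leaves the request in the `TimedOut` state.
    #[test]
    fn inflight_request_timeout_transition() {
        let (tx, mut rx) = oneshot::channel();
        let req = PeerRequest::<EthNetworkPrimitives>::GetBlockBodies {
            request: GetBlockBodies(vec![]),
            response: tx,
        };
        let now = Instant::now();
        let mut inflight =
            InflightRequest { request: RequestState::Waiting(req), timestamp: now, deadline: now };

        // any instant past the deadline counts as timed out
        assert!(inflight.is_timed_out(now + Duration::from_millis(1)));
        assert!(inflight.is_waiting());

        inflight.timeout();
        assert!(!inflight.is_waiting());
        assert_eq!(rx.try_recv().unwrap().unwrap_err(), RequestError::Timeout);
    }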

    #[tokio::test(flavor = "multi_thread")]
    async fn test_keep_alive() {
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let fut = builder.with_client_stream(local_addr, move |mut client_stream| async move {
            let _ = tokio::time::timeout(Duration::from_secs(5), client_stream.next()).await;
            client_stream.into_inner().disconnect(DisconnectReason::UselessPeer).await.unwrap();
        });

        let (tx, rx) = oneshot::channel();

        tokio::task::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let session = builder.connect_incoming(incoming).await;
            session.await;

            tx.send(()).unwrap();
        });

        tokio::task::spawn(fut);

        rx.await.unwrap();
    }

    // This tests that incoming messages are delivered when there's capacity.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_send_at_capacity() {
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let fut = builder.with_client_stream(local_addr, move |mut client_stream| async move {
            client_stream
                .send(EthMessage::NewPooledTransactionHashes68(Default::default()))
                .await
                .unwrap();
            let _ = tokio::time::timeout(Duration::from_secs(100), client_stream.next()).await;
        });
        tokio::task::spawn(fut);

        let (incoming, _) = listener.accept().await.unwrap();
        let session = builder.connect_incoming(incoming).await;

        // fill the entire message buffer with an unrelated message
        let mut num_fill_messages = 0;
        loop {
            if builder
                .active_session_tx
                .try_send(ActiveSessionMessage::ProtocolBreach { peer_id: PeerId::random() })
                .is_err()
            {
                break
            }
            num_fill_messages += 1;
        }

        tokio::task::spawn(async move {
            session.await;
        });

        tokio::time::sleep(Duration::from_millis(100)).await;

        for _ in 0..num_fill_messages {
            let message = builder.active_session_rx.next().await.unwrap();
            match message {
                ActiveSessionMessage::ProtocolBreach { .. } => {}
                ev => unreachable!("{ev:?}"),
            }
        }

        let message = builder.active_session_rx.next().await.unwrap();
        match message {
            ActiveSessionMessage::ValidMessage {
                message: PeerMessage::PooledTransactions(_),
                ..
            } => {}
            _ => unreachable!(),
        }
    }
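
    // An illustrative unit test (not part of the original suite) for the
    // `QueuedOutgoingMessages` bookkeeping: messages round-trip through push/pop,
    // and broadcast messages are not counted as responses. The gauge is a no-op
    // here since no metrics recorder is installed in tests.
    #[test]
    fn queued_outgoing_push_pop() {
        let mut queue: QueuedOutgoingMessages<EthNetworkPrimitives> =
            QueuedOutgoingMessages::new(Gauge::noop());
        assert!(queue.pop_front().is_none());

        queue.push_back(EthMessage::NewPooledTransactionHashes66(Vec::new().into()).into());
        let msg = queue.pop_front().expect("one queued message");
        // hash announcements are broadcasts, not responses to a request
        assert!(!msg.is_response());
        assert!(queue.pop_front().is_none());
    }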

    #[test]
    fn timeout_calculation_sanity_tests() {
        let rtt = Duration::from_secs(5);
        // timeout for an RTT of `rtt`
        let timeout = rtt * TIMEOUT_SCALING;

        // if rtt hasn't changed, timeout shouldn't change
        assert_eq!(calculate_new_timeout(timeout, rtt), timeout);

        // if rtt changed, the new timeout should change less than it
        assert!(calculate_new_timeout(timeout, rtt / 2) < timeout);
        assert!(calculate_new_timeout(timeout, rtt / 2) > timeout / 2);
        assert!(calculate_new_timeout(timeout, rtt * 2) > timeout);
        assert!(calculate_new_timeout(timeout, rtt * 2) < timeout * 2);
    }
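
    // An added sanity check (illustrative, relying only on this file's constants):
    // the smoothed timeout never leaves the [MINIMUM_TIMEOUT, MAXIMUM_TIMEOUT]
    // window, no matter how extreme the RTT sample is.
    #[test]
    fn timeout_calculation_clamps_to_bounds() {
        // a near-zero RTT cannot drag the timeout below the minimum
        assert_eq!(
            calculate_new_timeout(MINIMUM_TIMEOUT, Duration::from_millis(1)),
            MINIMUM_TIMEOUT
        );
        // a huge RTT spike cannot push the timeout above the maximum
        assert_eq!(
            calculate_new_timeout(MAXIMUM_TIMEOUT, MAXIMUM_TIMEOUT * 100),
            MAXIMUM_TIMEOUT
        );
    }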
}