1use core::sync::atomic::Ordering;
4use std::{
5 collections::VecDeque,
6 future::Future,
7 net::SocketAddr,
8 pin::Pin,
9 sync::{atomic::AtomicU64, Arc},
10 task::{ready, Context, Poll},
11 time::{Duration, Instant},
12};
13
14use crate::{
15 message::{NewBlockMessage, PeerMessage, PeerResponse, PeerResponseResult},
16 session::{
17 conn::EthRlpxConnection,
18 handle::{ActiveSessionMessage, SessionCommand},
19 SessionId,
20 },
21};
22use alloy_primitives::Sealable;
23use futures::{stream::Fuse, SinkExt, StreamExt};
24use metrics::Gauge;
25use reth_eth_wire::{
26 errors::{EthHandshakeError, EthStreamError},
27 message::{EthBroadcastMessage, RequestPair},
28 Capabilities, DisconnectP2P, DisconnectReason, EthMessage, NetworkPrimitives,
29};
30use reth_eth_wire_types::RawCapabilityMessage;
31use reth_metrics::common::mpsc::MeteredPollSender;
32use reth_network_api::PeerRequest;
33use reth_network_p2p::error::RequestError;
34use reth_network_peers::PeerId;
35use reth_network_types::session::config::INITIAL_REQUEST_TIMEOUT;
36use reth_primitives_traits::Block;
37use rustc_hash::FxHashMap;
38use tokio::{
39 sync::{mpsc::error::TrySendError, oneshot},
40 time::Interval,
41};
42use tokio_stream::wrappers::ReceiverStream;
43use tokio_util::sync::PollSender;
44use tracing::{debug, trace};
45
46const MINIMUM_TIMEOUT: Duration = Duration::from_secs(2);
50const MAXIMUM_TIMEOUT: Duration = INITIAL_REQUEST_TIMEOUT;
52const SAMPLE_IMPACT: f64 = 0.1;
54const TIMEOUT_SCALING: u32 = 3;
56
57const MAX_QUEUED_OUTGOING_RESPONSES: usize = 4;
69
/// An established session with a remote peer.
///
/// The session is driven by its `Future` implementation, which relays messages between the
/// connection (`conn`), the session manager (`commands_rx` / `to_session_manager`) and internal
/// requesters (`internal_request_rx`) until the session terminates.
// NOTE(review): `dead_code` is expected because some fields are never read inside this
// module — confirm which ones before removing the attribute.
#[expect(dead_code)]
pub(crate) struct ActiveSession<N: NetworkPrimitives> {
    /// Counter used by `next_id` to allocate ids for outgoing requests.
    pub(crate) next_id: u64,
    /// The connection to the remote peer.
    pub(crate) conn: EthRlpxConnection<N>,
    /// Identifier of the remote peer.
    pub(crate) remote_peer_id: PeerId,
    /// Socket address of the remote peer.
    pub(crate) remote_addr: SocketAddr,
    /// Capabilities associated with the remote peer.
    pub(crate) remote_capabilities: Arc<Capabilities>,
    /// Identifier of this session.
    pub(crate) session_id: SessionId,
    /// Commands (disconnect or outgoing message) sent by the session manager.
    pub(crate) commands_rx: ReceiverStream<SessionCommand<N>>,
    /// Channel used to deliver messages to the session manager.
    pub(crate) to_session_manager: MeteredPollSender<ActiveSessionMessage<N>>,
    /// A message for the session manager that could not be sent for lack of channel capacity;
    /// delivery is retried on the next iteration of the receive loop.
    pub(crate) pending_message_to_session: Option<ActiveSessionMessage<N>>,
    /// Internal requests that should be sent to the remote peer.
    pub(crate) internal_request_rx: Fuse<ReceiverStream<PeerRequest<N>>>,
    /// Requests sent to the remote peer that still await a response, keyed by request id.
    pub(crate) inflight_requests: FxHashMap<u64, InflightRequest<PeerRequest<N>>>,
    /// Requests received from the remote peer whose responses are still being produced.
    pub(crate) received_requests_from_remote: Vec<ReceivedRequest<N>>,
    /// Messages waiting to be written to the connection.
    pub(crate) queued_outgoing: QueuedOutgoingMessages<N>,
    /// Current adaptive request timeout, in milliseconds.
    // NOTE(review): kept in an `Arc<AtomicU64>`, presumably so it can be shared outside this
    // session — confirm against the session manager.
    pub(crate) internal_request_timeout: Arc<AtomicU64>,
    /// Ticks whenever inflight requests should be checked for timeouts.
    pub(crate) internal_request_timeout_interval: Interval,
    /// If a request that already timed out is still unanswered this long after it was sent,
    /// the peer is reported for a protocol breach.
    pub(crate) protocol_breach_request_timeout: Duration,
    /// The final message for the session manager, together with the sender used to deliver
    /// it, once the session is terminating.
    pub(crate) terminate_message:
        Option<(PollSender<ActiveSessionMessage<N>>, ActiveSessionMessage<N>)>,
}
118
impl<N: NetworkPrimitives> ActiveSession<N> {
    /// Returns `true` if the underlying connection has started disconnecting.
    fn is_disconnecting(&self) -> bool {
        self.conn.inner().is_disconnecting()
    }

    /// Returns the next request id and advances the counter.
    const fn next_id(&mut self) -> u64 {
        let id = self.next_id;
        self.next_id += 1;
        id
    }

    /// Shrinks the buffers for received requests and queued outgoing messages to their
    /// current contents.
    pub fn shrink_to_fit(&mut self) {
        self.received_requests_from_remote.shrink_to_fit();
        self.queued_outgoing.shrink_to_fit();
    }

    /// Returns how many of the queued outgoing messages are responses.
    fn queued_response_count(&self) -> usize {
        self.queued_outgoing.messages.iter().filter(|m| m.is_response()).count()
    }

    /// Handles a single message received from the remote peer.
    ///
    /// - Requests are recorded in `received_requests_from_remote` and forwarded to the session
    ///   manager together with a response channel.
    /// - Responses are matched against `inflight_requests`.
    /// - Broadcasts are forwarded to the session manager.
    ///
    /// Returns [`OnIncomingMessageOutcome::BadMessage`] for a `Status` message or for an
    /// eth/68 announcement whose field lengths differ, and
    /// [`OnIncomingMessageOutcome::NoCapacity`] when the session manager channel is full.
    fn on_incoming_message(&mut self, msg: EthMessage<N>) -> OnIncomingMessageOutcome<N> {
        // Records a request from the remote peer and forwards it to the session manager.
        //
        // A oneshot channel is created: the receiver half is stored in
        // `received_requests_from_remote` so the response can be polled later, and the sender
        // half travels with the request to the session manager.
        macro_rules! on_request {
            ($req:ident, $resp_item:ident, $req_item:ident) => {{
                let RequestPair { request_id, message: request } = $req;
                let (tx, response) = oneshot::channel();
                let received = ReceivedRequest {
                    request_id,
                    rx: PeerResponse::$resp_item { response },
                    received: Instant::now(),
                };
                self.received_requests_from_remote.push(received);
                self.try_emit_request(PeerMessage::EthRequest(PeerRequest::$req_item {
                    request,
                    response: tx,
                }))
                .into()
            }};
        }

        // Matches a response from the remote peer with the inflight request of the same id.
        //
        // - Waiting request of the expected kind: the response is delivered and the adaptive
        //   timeout is updated from the observed round trip.
        // - Waiting request of another kind: `send_bad_response` is called on the request.
        // - Request already timed out: nothing is delivered, but the round trip still feeds the
        //   timeout estimate.
        // - Unknown request id: the peer is reported via `on_bad_message`.
        //
        // Every path evaluates to `OnIncomingMessageOutcome::Ok`.
        macro_rules! on_response {
            ($resp:ident, $item:ident) => {{
                let RequestPair { request_id, message } = $resp;
                if let Some(req) = self.inflight_requests.remove(&request_id) {
                    match req.request {
                        RequestState::Waiting(PeerRequest::$item { response, .. }) => {
                            let _ = response.send(Ok(message));
                            self.update_request_timeout(req.timestamp, Instant::now());
                        }
                        RequestState::Waiting(request) => {
                            request.send_bad_response();
                        }
                        RequestState::TimedOut => {
                            self.update_request_timeout(req.timestamp, Instant::now());
                        }
                    }
                } else {
                    self.on_bad_message();
                }

                OnIncomingMessageOutcome::Ok
            }};
        }

        match msg {
            // A status message after the handshake is rejected as a protocol violation.
            message @ EthMessage::Status(_) => OnIncomingMessageOutcome::BadMessage {
                error: EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake),
                message,
            },
            EthMessage::NewBlockHashes(msg) => {
                self.try_emit_broadcast(PeerMessage::NewBlockHashes(msg)).into()
            }
            EthMessage::NewBlock(msg) => {
                // The block hash is computed locally from the header.
                let block =
                    NewBlockMessage { hash: msg.block.header().hash_slow(), block: Arc::new(*msg) };
                self.try_emit_broadcast(PeerMessage::NewBlock(block)).into()
            }
            EthMessage::Transactions(msg) => {
                self.try_emit_broadcast(PeerMessage::ReceivedTransaction(msg)).into()
            }
            EthMessage::NewPooledTransactionHashes66(msg) => {
                self.try_emit_broadcast(PeerMessage::PooledTransactions(msg.into())).into()
            }
            EthMessage::NewPooledTransactionHashes68(msg) => {
                // The `hashes`, `types` and `sizes` lists must have equal lengths.
                if msg.hashes.len() != msg.types.len() || msg.hashes.len() != msg.sizes.len() {
                    return OnIncomingMessageOutcome::BadMessage {
                        error: EthStreamError::TransactionHashesInvalidLenOfFields {
                            hashes_len: msg.hashes.len(),
                            types_len: msg.types.len(),
                            sizes_len: msg.sizes.len(),
                        },
                        message: EthMessage::NewPooledTransactionHashes68(msg),
                    }
                }
                self.try_emit_broadcast(PeerMessage::PooledTransactions(msg.into())).into()
            }
            EthMessage::GetBlockHeaders(req) => {
                on_request!(req, BlockHeaders, GetBlockHeaders)
            }
            EthMessage::BlockHeaders(resp) => {
                on_response!(resp, GetBlockHeaders)
            }
            EthMessage::GetBlockBodies(req) => {
                on_request!(req, BlockBodies, GetBlockBodies)
            }
            EthMessage::BlockBodies(resp) => {
                on_response!(resp, GetBlockBodies)
            }
            EthMessage::GetPooledTransactions(req) => {
                on_request!(req, PooledTransactions, GetPooledTransactions)
            }
            EthMessage::PooledTransactions(resp) => {
                on_response!(resp, GetPooledTransactions)
            }
            EthMessage::GetNodeData(req) => {
                on_request!(req, NodeData, GetNodeData)
            }
            EthMessage::NodeData(resp) => {
                on_response!(resp, GetNodeData)
            }
            EthMessage::GetReceipts(req) => {
                on_request!(req, Receipts, GetReceipts)
            }
            EthMessage::Receipts(resp) => {
                on_response!(resp, GetReceipts)
            }
            EthMessage::Receipts69(resp) => {
                // Receipts69 payloads are converted with `into_with_bloom` so they can be
                // delivered to a `GetReceipts` requester.
                let resp = resp.map(|receipts| receipts.into_with_bloom());
                on_response!(resp, GetReceipts)
            }
            EthMessage::BlockRangeUpdate(msg) => {
                self.try_emit_broadcast(PeerMessage::BlockRangeUpdated(msg)).into()
            }
            EthMessage::Other(bytes) => self.try_emit_broadcast(PeerMessage::Other(bytes)).into(),
        }
    }

    /// Queues an internal request for the remote peer and tracks it as inflight under a newly
    /// allocated request id, timestamped now and expiring at `deadline`.
    fn on_internal_peer_request(&mut self, request: PeerRequest<N>, deadline: Instant) {
        let request_id = self.next_id();
        let msg = request.create_request_message(request_id);
        self.queued_outgoing.push_back(msg.into());
        let req = InflightRequest {
            request: RequestState::Waiting(request),
            timestamp: Instant::now(),
            deadline,
        };
        self.inflight_requests.insert(request_id, req);
    }

    /// Handles a message the session manager wants delivered to the remote peer.
    fn on_internal_peer_message(&mut self, msg: PeerMessage<N>) {
        match msg {
            PeerMessage::NewBlockHashes(msg) => {
                self.queued_outgoing.push_back(EthMessage::NewBlockHashes(msg).into());
            }
            PeerMessage::NewBlock(msg) => {
                self.queued_outgoing.push_back(EthBroadcastMessage::NewBlock(msg.block).into());
            }
            PeerMessage::PooledTransactions(msg) => {
                // Announcements that don't fit the connection's eth version are dropped.
                if msg.is_valid_for_version(self.conn.version()) {
                    self.queued_outgoing.push_back(EthMessage::from(msg).into());
                }
            }
            PeerMessage::EthRequest(req) => {
                let deadline = self.request_deadline();
                self.on_internal_peer_request(req, deadline);
            }
            PeerMessage::SendTransactions(msg) => {
                self.queued_outgoing.push_back(EthBroadcastMessage::Transactions(msg).into());
            }
            // Intentionally ignored here.
            PeerMessage::BlockRangeUpdated(_) => {}
            PeerMessage::ReceivedTransaction(_) => {
                unreachable!("Not emitted by network")
            }
            PeerMessage::Other(other) => {
                self.queued_outgoing.push_back(OutgoingMessage::Raw(other));
            }
        }
    }

    /// Returns the deadline for a request sent now: the current time plus the adaptive
    /// request timeout (stored in milliseconds).
    fn request_deadline(&self) -> Instant {
        Instant::now() +
            Duration::from_millis(self.internal_request_timeout.load(Ordering::Relaxed))
    }

    /// Converts the locally produced response to the remote request `id` into a message and
    /// queues it for sending; if the conversion fails the error is only logged.
    fn handle_outgoing_response(&mut self, id: u64, resp: PeerResponseResult<N>) {
        match resp.try_into_message(id) {
            Ok(msg) => {
                self.queued_outgoing.push_back(msg.into());
            }
            Err(err) => {
                debug!(target: "net", %err, "Failed to respond to received request");
            }
        }
    }

    /// Tries to forward a broadcast from the peer to the session manager without waiting.
    ///
    /// Returns `Err` with the undelivered message when the channel is full. A closed channel,
    /// or no available sender (`get_ref` returned `None`), counts as success and the message
    /// is dropped.
    #[expect(clippy::result_large_err)]
    fn try_emit_broadcast(&self, message: PeerMessage<N>) -> Result<(), ActiveSessionMessage<N>> {
        let Some(sender) = self.to_session_manager.inner().get_ref() else { return Ok(()) };

        match sender
            .try_send(ActiveSessionMessage::ValidMessage { peer_id: self.remote_peer_id, message })
        {
            Ok(_) => Ok(()),
            Err(err) => {
                trace!(
                    target: "net",
                    %err,
                    "no capacity for incoming broadcast",
                );
                match err {
                    TrySendError::Full(msg) => Err(msg),
                    TrySendError::Closed(_) => Ok(()),
                }
            }
        }
    }

    /// Tries to forward a request from the peer to the session manager without waiting.
    ///
    /// Same contract as `try_emit_broadcast`: `Err` carries the undelivered message when the
    /// channel is full; a closed channel or missing sender counts as success.
    #[expect(clippy::result_large_err)]
    fn try_emit_request(&self, message: PeerMessage<N>) -> Result<(), ActiveSessionMessage<N>> {
        let Some(sender) = self.to_session_manager.inner().get_ref() else { return Ok(()) };

        match sender
            .try_send(ActiveSessionMessage::ValidMessage { peer_id: self.remote_peer_id, message })
        {
            Ok(_) => Ok(()),
            Err(err) => {
                trace!(
                    target: "net",
                    %err,
                    "no capacity for incoming request",
                );
                match err {
                    TrySendError::Full(msg) => Err(msg),
                    TrySendError::Closed(_) => {
                        Ok(())
                    }
                }
            }
        }
    }

    /// Best-effort report of a bad message from the peer to the session manager; send
    /// failures are ignored.
    fn on_bad_message(&self) {
        let Some(sender) = self.to_session_manager.inner().get_ref() else { return };
        let _ = sender.try_send(ActiveSessionMessage::BadMessage { peer_id: self.remote_peer_id });
    }

    /// Sets `Disconnected` as the terminate message and starts delivering it.
    fn emit_disconnect(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        trace!(target: "net::session", remote_peer_id=?self.remote_peer_id, "emitting disconnect");
        let msg = ActiveSessionMessage::Disconnected {
            peer_id: self.remote_peer_id,
            remote_addr: self.remote_addr,
        };

        self.terminate_message = Some((self.to_session_manager.inner().clone(), msg));
        self.poll_terminate_message(cx).expect("message is set")
    }

    /// Sets `ClosedOnConnectionError` with `error` as the terminate message and starts
    /// delivering it.
    fn close_on_error(&mut self, error: EthStreamError, cx: &mut Context<'_>) -> Poll<()> {
        let msg = ActiveSessionMessage::ClosedOnConnectionError {
            peer_id: self.remote_peer_id,
            remote_addr: self.remote_addr,
            error,
        };
        self.terminate_message = Some((self.to_session_manager.inner().clone(), msg));
        self.poll_terminate_message(cx).expect("message is set")
    }

    /// Starts disconnecting the underlying connection with `reason`.
    fn start_disconnect(&mut self, reason: DisconnectReason) -> Result<(), EthStreamError> {
        Ok(self.conn.inner_mut().start_disconnect(reason)?)
    }

    /// Waits for the connection to close, then emits the disconnect message.
    ///
    /// Must only be called while disconnecting (checked in debug builds).
    fn poll_disconnect(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        debug_assert!(self.is_disconnecting(), "not disconnecting");

        // The close result is ignored: the session terminates either way.
        let _ = ready!(self.conn.poll_close_unpin(cx));
        self.emit_disconnect(cx)
    }

    /// Starts a disconnect with `reason` and drives it. If it cannot be started, the session
    /// is closed with the resulting error instead.
    fn try_disconnect(&mut self, reason: DisconnectReason, cx: &mut Context<'_>) -> Poll<()> {
        match self.start_disconnect(reason) {
            Ok(()) => {
                self.poll_disconnect(cx)
            }
            Err(err) => {
                debug!(target: "net::session", %err, remote_peer_id=?self.remote_peer_id, "could not send disconnect");
                self.close_on_error(err, cx)
            }
        }
    }

    /// Checks inflight requests against their deadlines at `now`.
    ///
    /// Waiting requests that are past their deadline are timed out, which notifies the
    /// requester. Returns `true` as soon as a request that had already timed out is found
    /// still unanswered more than `protocol_breach_request_timeout` after it was sent; the
    /// caller treats this as a protocol breach. Because of that early return, remaining
    /// entries are checked on a later call.
    #[must_use]
    fn check_timed_out_requests(&mut self, now: Instant) -> bool {
        for (id, req) in &mut self.inflight_requests {
            if req.is_timed_out(now) {
                if req.is_waiting() {
                    debug!(target: "net::session", ?id, remote_peer_id=?self.remote_peer_id, "timed out outgoing request");
                    req.timeout();
                } else if now - req.timestamp > self.protocol_breach_request_timeout {
                    return true
                }
            }
        }

        false
    }

    /// Updates the adaptive request timeout using the round trip `received - sent`.
    ///
    /// The new timeout is stored in milliseconds, and the timeout-check interval is
    /// recreated with that period.
    fn update_request_timeout(&mut self, sent: Instant, received: Instant) {
        let elapsed = received.saturating_duration_since(sent);

        let current = Duration::from_millis(self.internal_request_timeout.load(Ordering::Relaxed));
        let request_timeout = calculate_new_timeout(current, elapsed);
        self.internal_request_timeout.store(request_timeout.as_millis() as u64, Ordering::Relaxed);
        self.internal_request_timeout_interval = tokio::time::interval(request_timeout);
    }

    /// Drives delivery of the terminate message, if one is set.
    ///
    /// Returns `None` when no terminate message is set. Otherwise:
    /// - returns `Some(Poll::Pending)` and keeps the message while no slot can be reserved;
    /// - returns `Some(Poll::Ready(()))` once the message was sent, or once the channel
    ///   reported an error (in which case the message is dropped).
    fn poll_terminate_message(&mut self, cx: &mut Context<'_>) -> Option<Poll<()>> {
        let (mut tx, msg) = self.terminate_message.take()?;
        match tx.poll_reserve(cx) {
            Poll::Pending => {
                self.terminate_message = Some((tx, msg));
                return Some(Poll::Pending)
            }
            Poll::Ready(Ok(())) => {
                let _ = tx.send_item(msg);
            }
            Poll::Ready(Err(_)) => {
                // The receiving side is gone; there is nobody left to notify.
            }
        }
        Some(Poll::Ready(()))
    }
}
498
499impl<N: NetworkPrimitives> Future for ActiveSession<N> {
500 type Output = ();
501
502 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
503 let this = self.get_mut();
504
505 if let Some(terminate) = this.poll_terminate_message(cx) {
507 return terminate
508 }
509
510 if this.is_disconnecting() {
511 return this.poll_disconnect(cx)
512 }
513
514 let mut budget = 4;
520
521 'main: loop {
523 let mut progress = false;
524
525 loop {
527 match this.commands_rx.poll_next_unpin(cx) {
528 Poll::Pending => break,
529 Poll::Ready(None) => {
530 return Poll::Ready(())
533 }
534 Poll::Ready(Some(cmd)) => {
535 progress = true;
536 match cmd {
537 SessionCommand::Disconnect { reason } => {
538 debug!(
539 target: "net::session",
540 ?reason,
541 remote_peer_id=?this.remote_peer_id,
542 "Received disconnect command for session"
543 );
544 let reason =
545 reason.unwrap_or(DisconnectReason::DisconnectRequested);
546
547 return this.try_disconnect(reason, cx)
548 }
549 SessionCommand::Message(msg) => {
550 this.on_internal_peer_message(msg);
551 }
552 }
553 }
554 }
555 }
556
557 let deadline = this.request_deadline();
558
559 while let Poll::Ready(Some(req)) = this.internal_request_rx.poll_next_unpin(cx) {
560 progress = true;
561 this.on_internal_peer_request(req, deadline);
562 }
563
564 for idx in (0..this.received_requests_from_remote.len()).rev() {
567 let mut req = this.received_requests_from_remote.swap_remove(idx);
568 match req.rx.poll(cx) {
569 Poll::Pending => {
570 this.received_requests_from_remote.push(req);
572 }
573 Poll::Ready(resp) => {
574 this.handle_outgoing_response(req.request_id, resp);
575 }
576 }
577 }
578
579 while this.conn.poll_ready_unpin(cx).is_ready() {
581 if let Some(msg) = this.queued_outgoing.pop_front() {
582 progress = true;
583 let res = match msg {
584 OutgoingMessage::Eth(msg) => this.conn.start_send_unpin(msg),
585 OutgoingMessage::Broadcast(msg) => this.conn.start_send_broadcast(msg),
586 OutgoingMessage::Raw(msg) => this.conn.start_send_raw(msg),
587 };
588 if let Err(err) = res {
589 debug!(target: "net::session", %err, remote_peer_id=?this.remote_peer_id, "failed to send message");
590 return this.close_on_error(err, cx)
592 }
593 } else {
594 break
596 }
597 }
598
599 'receive: loop {
601 budget -= 1;
603 if budget == 0 {
604 cx.waker().wake_by_ref();
606 break 'main
607 }
608
609 if let Some(msg) = this.pending_message_to_session.take() {
613 match this.to_session_manager.poll_reserve(cx) {
614 Poll::Ready(Ok(_)) => {
615 let _ = this.to_session_manager.send_item(msg);
616 }
617 Poll::Ready(Err(_)) => return Poll::Ready(()),
618 Poll::Pending => {
619 this.pending_message_to_session = Some(msg);
620 break 'receive
621 }
622 };
623 }
624
625 if this.received_requests_from_remote.len() > MAX_QUEUED_OUTGOING_RESPONSES {
627 break 'receive
633 }
634
635 if this.queued_outgoing.messages.len() > MAX_QUEUED_OUTGOING_RESPONSES &&
637 this.queued_response_count() > MAX_QUEUED_OUTGOING_RESPONSES
638 {
639 break 'receive
646 }
647
648 match this.conn.poll_next_unpin(cx) {
649 Poll::Pending => break,
650 Poll::Ready(None) => {
651 if this.is_disconnecting() {
652 break
653 }
654 debug!(target: "net::session", remote_peer_id=?this.remote_peer_id, "eth stream completed");
655 return this.emit_disconnect(cx)
656 }
657 Poll::Ready(Some(res)) => {
658 match res {
659 Ok(msg) => {
660 trace!(target: "net::session", msg_id=?msg.message_id(), remote_peer_id=?this.remote_peer_id, "received eth message");
661 match this.on_incoming_message(msg) {
663 OnIncomingMessageOutcome::Ok => {
664 progress = true;
666 }
667 OnIncomingMessageOutcome::BadMessage { error, message } => {
668 debug!(target: "net::session", %error, msg=?message, remote_peer_id=?this.remote_peer_id, "received invalid protocol message");
669 return this.close_on_error(error, cx)
670 }
671 OnIncomingMessageOutcome::NoCapacity(msg) => {
672 this.pending_message_to_session = Some(msg);
674 }
675 }
676 }
677 Err(err) => {
678 debug!(target: "net::session", %err, remote_peer_id=?this.remote_peer_id, "failed to receive message");
679 return this.close_on_error(err, cx)
680 }
681 }
682 }
683 }
684 }
685
686 if !progress {
687 break 'main
688 }
689 }
690
691 while this.internal_request_timeout_interval.poll_tick(cx).is_ready() {
692 if this.check_timed_out_requests(Instant::now()) {
694 if let Poll::Ready(Ok(_)) = this.to_session_manager.poll_reserve(cx) {
695 let msg = ActiveSessionMessage::ProtocolBreach { peer_id: this.remote_peer_id };
696 this.pending_message_to_session = Some(msg);
697 }
698 }
699 }
700
701 this.shrink_to_fit();
702
703 Poll::Pending
704 }
705}
706
/// A request received from the remote peer whose response is still being produced locally.
pub(crate) struct ReceivedRequest<N: NetworkPrimitives> {
    /// Id the remote peer assigned to the request; echoed back in the response.
    request_id: u64,
    /// Receiver for the locally produced response.
    rx: PeerResponse<N>,
    /// When the request was received. Not read anywhere yet, hence the expected `dead_code`.
    #[expect(dead_code)]
    received: Instant,
}
717
/// A request sent to the remote peer that has not been answered yet.
pub(crate) struct InflightRequest<R> {
    /// Whether a requester is still waiting or the request already timed out.
    request: RequestState<R>,
    /// When the request was created; used to measure the round trip and to detect protocol
    /// breaches.
    timestamp: Instant,
    /// Past this instant the request counts as timed out.
    deadline: Instant,
}
727
728impl<N: NetworkPrimitives> InflightRequest<PeerRequest<N>> {
731 #[inline]
733 fn is_timed_out(&self, now: Instant) -> bool {
734 now > self.deadline
735 }
736
737 #[inline]
739 const fn is_waiting(&self) -> bool {
740 matches!(self.request, RequestState::Waiting(_))
741 }
742
743 fn timeout(&mut self) {
745 let mut req = RequestState::TimedOut;
746 std::mem::swap(&mut self.request, &mut req);
747
748 if let RequestState::Waiting(req) = req {
749 req.send_err_response(RequestError::Timeout);
750 }
751 }
752}
753
/// Result of handling a single message from the remote peer.
enum OnIncomingMessageOutcome<N: NetworkPrimitives> {
    /// The message was handled.
    Ok,
    /// The message violated the protocol; the session is closed with `error`.
    BadMessage { error: EthStreamError, message: EthMessage<N> },
    /// The session manager channel was full; the message must be delivered later.
    NoCapacity(ActiveSessionMessage<N>),
}
763
764impl<N: NetworkPrimitives> From<Result<(), ActiveSessionMessage<N>>>
765 for OnIncomingMessageOutcome<N>
766{
767 fn from(res: Result<(), ActiveSessionMessage<N>>) -> Self {
768 match res {
769 Ok(_) => Self::Ok,
770 Err(msg) => Self::NoCapacity(msg),
771 }
772 }
773}
774
/// State of an outgoing request.
enum RequestState<R> {
    /// A requester is waiting for the response.
    Waiting(R),
    /// The request timed out and the requester was notified; a late response is only used
    /// to update the timeout estimate.
    TimedOut,
}
781
/// A message queued for sending to the remote peer.
pub(crate) enum OutgoingMessage<N: NetworkPrimitives> {
    /// An eth protocol message, sent with `start_send_unpin`.
    Eth(EthMessage<N>),
    /// A broadcast message, sent with `start_send_broadcast`.
    Broadcast(EthBroadcastMessage<N>),
    /// A raw capability message, sent with `start_send_raw`.
    Raw(RawCapabilityMessage),
}
791
792impl<N: NetworkPrimitives> OutgoingMessage<N> {
793 const fn is_response(&self) -> bool {
795 match self {
796 Self::Eth(msg) => msg.is_response(),
797 _ => false,
798 }
799 }
800}
801
802impl<N: NetworkPrimitives> From<EthMessage<N>> for OutgoingMessage<N> {
803 fn from(value: EthMessage<N>) -> Self {
804 Self::Eth(value)
805 }
806}
807
808impl<N: NetworkPrimitives> From<EthBroadcastMessage<N>> for OutgoingMessage<N> {
809 fn from(value: EthBroadcastMessage<N>) -> Self {
810 Self::Broadcast(value)
811 }
812}
813
814#[inline]
816fn calculate_new_timeout(current_timeout: Duration, estimated_rtt: Duration) -> Duration {
817 let new_timeout = estimated_rtt.mul_f64(SAMPLE_IMPACT) * TIMEOUT_SCALING;
818
819 let smoothened_timeout = current_timeout.mul_f64(1.0 - SAMPLE_IMPACT) + new_timeout;
821
822 smoothened_timeout.clamp(MINIMUM_TIMEOUT, MAXIMUM_TIMEOUT)
823}
824
/// FIFO queue of messages waiting to be written to the connection, with a gauge that
/// tracks the queue length.
pub(crate) struct QueuedOutgoingMessages<N: NetworkPrimitives> {
    /// The queued messages, oldest first.
    messages: VecDeque<OutgoingMessage<N>>,
    /// Incremented on push and decremented on successful pop.
    count: Gauge,
}
830
831impl<N: NetworkPrimitives> QueuedOutgoingMessages<N> {
832 pub(crate) const fn new(metric: Gauge) -> Self {
833 Self { messages: VecDeque::new(), count: metric }
834 }
835
836 pub(crate) fn push_back(&mut self, message: OutgoingMessage<N>) {
837 self.messages.push_back(message);
838 self.count.increment(1);
839 }
840
841 pub(crate) fn pop_front(&mut self) -> Option<OutgoingMessage<N>> {
842 self.messages.pop_front().inspect(|_| self.count.decrement(1))
843 }
844
845 pub(crate) fn shrink_to_fit(&mut self) {
846 self.messages.shrink_to_fit();
847 }
848}
849
#[cfg(test)]
mod tests {
    use super::*;
    use crate::session::{handle::PendingSessionEvent, start_pending_incoming_session};
    use alloy_eips::eip2124::ForkFilter;
    use reth_chainspec::MAINNET;
    use reth_ecies::stream::ECIESStream;
    use reth_eth_wire::{
        handshake::EthHandshake, EthNetworkPrimitives, EthStream, GetBlockBodies,
        HelloMessageWithProtocols, P2PStream, StatusBuilder, UnauthedEthStream, UnauthedP2PStream,
        UnifiedStatus,
    };
    use reth_ethereum_forks::EthereumHardfork;
    use reth_network_peers::pk2id;
    use reth_network_types::session::config::PROTOCOL_BREACH_REQUEST_TIMEOUT;
    use secp256k1::{SecretKey, SECP256K1};
    use tokio::{
        net::{TcpListener, TcpStream},
        sync::mpsc,
    };

    /// Builds a hello message whose peer id is derived from `server_key`.
    fn eth_hello(server_key: &SecretKey) -> HelloMessageWithProtocols {
        HelloMessageWithProtocols::builder(pk2id(&server_key.public_key(SECP256K1))).build()
    }

    /// Test helper that creates client connections and turns incoming TCP streams into
    /// [`ActiveSession`]s wired to a local session-manager channel.
    struct SessionBuilder<N: NetworkPrimitives = EthNetworkPrimitives> {
        _remote_capabilities: Arc<Capabilities>,
        /// Sender half handed to every session as its session-manager channel.
        active_session_tx: mpsc::Sender<ActiveSessionMessage<N>>,
        /// Receives what sessions send to the "session manager".
        active_session_rx: ReceiverStream<ActiveSessionMessage<N>>,
        /// Command senders of the created sessions, kept so their command channels stay open.
        to_sessions: Vec<mpsc::Sender<SessionCommand<N>>>,
        secret_key: SecretKey,
        local_peer_id: PeerId,
        hello: HelloMessageWithProtocols,
        status: UnifiedStatus,
        fork_filter: ForkFilter,
        next_id: usize,
    }

    impl<N: NetworkPrimitives> SessionBuilder<N> {
        /// Returns the next session id and advances the counter.
        fn next_id(&mut self) -> SessionId {
            let id = self.next_id;
            self.next_id += 1;
            SessionId(id)
        }

        /// Returns a future that connects a client with a fresh random key to `local_addr`,
        /// performs the ECIES, p2p and eth handshakes (unwrapping every step), and then runs
        /// `f` with the established client stream.
        fn with_client_stream<F, O>(
            &self,
            local_addr: SocketAddr,
            f: F,
        ) -> Pin<Box<dyn Future<Output = ()> + Send>>
        where
            F: FnOnce(EthStream<P2PStream<ECIESStream<TcpStream>>, N>) -> O + Send + 'static,
            O: Future<Output = ()> + Send + Sync,
        {
            let status = self.status;
            let fork_filter = self.fork_filter.clone();
            let local_peer_id = self.local_peer_id;
            let mut hello = self.hello.clone();
            let key = SecretKey::new(&mut rand_08::thread_rng());
            hello.id = pk2id(&key.public_key(SECP256K1));
            Box::pin(async move {
                let outgoing = TcpStream::connect(local_addr).await.unwrap();
                let sink = ECIESStream::connect(outgoing, key, local_peer_id).await.unwrap();

                let (p2p_stream, _) = UnauthedP2PStream::new(sink).handshake(hello).await.unwrap();

                let (client_stream, _) = UnauthedEthStream::new(p2p_stream)
                    .handshake(status, fork_filter)
                    .await
                    .unwrap();
                f(client_stream).await
            })
        }

        /// Runs the pending incoming-session handshake on `stream` and builds an
        /// [`ActiveSession`] from the established connection, using `INITIAL_REQUEST_TIMEOUT`
        /// and `PROTOCOL_BREACH_REQUEST_TIMEOUT` for the timeouts.
        ///
        /// # Panics
        ///
        /// Panics if the pending session reports anything other than `Established`.
        async fn connect_incoming(&mut self, stream: TcpStream) -> ActiveSession<N> {
            // NOTE(review): this takes the stream's *local* address as `remote_addr`; for an
            // accepted stream `peer_addr()` would be the remote side — confirm this is intended.
            let remote_addr = stream.local_addr().unwrap();
            let session_id = self.next_id();
            let (_disconnect_tx, disconnect_rx) = oneshot::channel();
            let (pending_sessions_tx, pending_sessions_rx) = mpsc::channel(1);

            tokio::task::spawn(start_pending_incoming_session(
                Arc::new(EthHandshake::default()),
                disconnect_rx,
                session_id,
                stream,
                pending_sessions_tx,
                remote_addr,
                self.secret_key,
                self.hello.clone(),
                self.status,
                self.fork_filter.clone(),
                Default::default(),
            ));

            let mut stream = ReceiverStream::new(pending_sessions_rx);

            match stream.next().await.unwrap() {
                PendingSessionEvent::Established {
                    session_id,
                    remote_addr,
                    peer_id,
                    capabilities,
                    conn,
                    ..
                } => {
                    let (_to_session_tx, messages_rx) = mpsc::channel(10);
                    let (commands_to_session, commands_rx) = mpsc::channel(10);
                    let poll_sender = PollSender::new(self.active_session_tx.clone());

                    self.to_sessions.push(commands_to_session);

                    ActiveSession {
                        next_id: 0,
                        remote_peer_id: peer_id,
                        remote_addr,
                        remote_capabilities: Arc::clone(&capabilities),
                        session_id,
                        commands_rx: ReceiverStream::new(commands_rx),
                        to_session_manager: MeteredPollSender::new(
                            poll_sender,
                            "network_active_session",
                        ),
                        pending_message_to_session: None,
                        internal_request_rx: ReceiverStream::new(messages_rx).fuse(),
                        inflight_requests: Default::default(),
                        conn,
                        queued_outgoing: QueuedOutgoingMessages::new(Gauge::noop()),
                        received_requests_from_remote: Default::default(),
                        internal_request_timeout_interval: tokio::time::interval(
                            INITIAL_REQUEST_TIMEOUT,
                        ),
                        internal_request_timeout: Arc::new(AtomicU64::new(
                            INITIAL_REQUEST_TIMEOUT.as_millis() as u64,
                        )),
                        protocol_breach_request_timeout: PROTOCOL_BREACH_REQUEST_TIMEOUT,
                        terminate_message: None,
                    }
                }
                ev => {
                    panic!("unexpected message {ev:?}")
                }
            }
        }
    }

    impl Default for SessionBuilder {
        /// Random local key pair, a session-manager channel with capacity 100, a default
        /// status, and the mainnet Frontier fork filter.
        fn default() -> Self {
            let (active_session_tx, active_session_rx) = mpsc::channel(100);

            let (secret_key, pk) = SECP256K1.generate_keypair(&mut rand_08::thread_rng());
            let local_peer_id = pk2id(&pk);

            Self {
                next_id: 0,
                _remote_capabilities: Arc::new(Capabilities::from(vec![])),
                active_session_tx,
                active_session_rx: ReceiverStream::new(active_session_rx),
                to_sessions: vec![],
                hello: eth_hello(&secret_key),
                secret_key,
                local_peer_id,
                status: StatusBuilder::default().build(),
                fork_filter: MAINNET
                    .hardfork_fork_filter(EthereumHardfork::Frontier)
                    .expect("The Frontier fork filter should exist on mainnet"),
            }
        }
    }

    /// The session starts a disconnect with `UselessPeer`; the client must observe a
    /// disconnect error carrying that reason.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_disconnect() {
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let expected_disconnect = DisconnectReason::UselessPeer;

        let fut = builder.with_client_stream(local_addr, move |mut client_stream| async move {
            let msg = client_stream.next().await.unwrap().unwrap_err();
            assert_eq!(msg.as_disconnected().unwrap(), expected_disconnect);
        });

        tokio::task::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let mut session = builder.connect_incoming(incoming).await;

            session.start_disconnect(expected_disconnect).unwrap();
            session.await
        });

        fut.await;
    }

    /// The client drops its stream right after the handshake; the session future must
    /// resolve.
    #[tokio::test(flavor = "multi_thread")]
    async fn handle_dropped_stream() {
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let fut = builder.with_client_stream(local_addr, move |client_stream| async move {
            drop(client_stream);
            tokio::time::sleep(Duration::from_secs(1)).await
        });

        let (tx, rx) = oneshot::channel();

        tokio::task::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let session = builder.connect_incoming(incoming).await;
            session.await;

            tx.send(()).unwrap();
        });

        tokio::task::spawn(fut);

        rx.await.unwrap();
    }

    /// The client sends 100 empty eth/66 announcements and then ends; the test waits until
    /// the session future resolves.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_send_many_messages() {
        reth_tracing::init_test_tracing();
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let num_messages = 100;

        let fut = builder.with_client_stream(local_addr, move |mut client_stream| async move {
            for _ in 0..num_messages {
                client_stream
                    .send(EthMessage::NewPooledTransactionHashes66(Vec::new().into()))
                    .await
                    .unwrap();
            }
        });

        let (tx, rx) = oneshot::channel();

        tokio::task::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let session = builder.connect_incoming(incoming).await;
            session.await;

            tx.send(()).unwrap();
        });

        tokio::task::spawn(fut);

        rx.await.unwrap();
    }

    /// An internal request that is already past its deadline gets `RequestError::Timeout`;
    /// because the client never answers, the session later reports a `ProtocolBreach`.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_request_timeout() {
        reth_tracing::init_test_tracing();

        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let request_timeout = Duration::from_millis(100);
        let drop_timeout = Duration::from_millis(1500);

        // Keep the client connection open (but silent) far longer than the test needs.
        let fut = builder.with_client_stream(local_addr, move |client_stream| async move {
            let _client_stream = client_stream;
            tokio::time::sleep(drop_timeout * 60).await;
        });
        tokio::task::spawn(fut);

        let (incoming, _) = listener.accept().await.unwrap();
        let mut session = builder.connect_incoming(incoming).await;
        session
            .internal_request_timeout
            .store(request_timeout.as_millis() as u64, Ordering::Relaxed);
        session.protocol_breach_request_timeout = drop_timeout;
        session.internal_request_timeout_interval =
            tokio::time::interval_at(tokio::time::Instant::now(), request_timeout);
        let (tx, rx) = oneshot::channel();
        let req = PeerRequest::GetBlockBodies { request: GetBlockBodies(vec![]), response: tx };
        session.on_internal_peer_request(req, Instant::now());
        tokio::spawn(session);

        let err = rx.await.unwrap().unwrap_err();
        assert_eq!(err, RequestError::Timeout);

        let msg = builder.active_session_rx.next().await.unwrap();
        match msg {
            ActiveSessionMessage::ProtocolBreach { .. } => {}
            ev => unreachable!("{ev:?}"),
        }
    }

    /// The client waits up to 5 seconds for a message and then disconnects; the session
    /// future must resolve.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_keep_alive() {
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let fut = builder.with_client_stream(local_addr, move |mut client_stream| async move {
            let _ = tokio::time::timeout(Duration::from_secs(5), client_stream.next()).await;
            client_stream.into_inner().disconnect(DisconnectReason::UselessPeer).await.unwrap();
        });

        let (tx, rx) = oneshot::channel();

        tokio::task::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let session = builder.connect_incoming(incoming).await;
            session.await;

            tx.send(()).unwrap();
        });

        tokio::task::spawn(fut);

        rx.await.unwrap();
    }

    /// The session-manager channel is filled before the session runs. The announcement from
    /// the client must still be delivered, right after the filler messages.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_send_at_capacity() {
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let fut = builder.with_client_stream(local_addr, move |mut client_stream| async move {
            client_stream
                .send(EthMessage::NewPooledTransactionHashes68(Default::default()))
                .await
                .unwrap();
            let _ = tokio::time::timeout(Duration::from_secs(100), client_stream.next()).await;
        });
        tokio::task::spawn(fut);

        let (incoming, _) = listener.accept().await.unwrap();
        let session = builder.connect_incoming(incoming).await;

        // Fill the session-manager channel until `try_send` fails.
        let mut num_fill_messages = 0;
        loop {
            if builder
                .active_session_tx
                .try_send(ActiveSessionMessage::ProtocolBreach { peer_id: PeerId::random() })
                .is_err()
            {
                break
            }
            num_fill_messages += 1;
        }

        tokio::task::spawn(async move {
            session.await;
        });

        tokio::time::sleep(Duration::from_millis(100)).await;

        for _ in 0..num_fill_messages {
            let message = builder.active_session_rx.next().await.unwrap();
            match message {
                ActiveSessionMessage::ProtocolBreach { .. } => {}
                ev => unreachable!("{ev:?}"),
            }
        }

        let message = builder.active_session_rx.next().await.unwrap();
        match message {
            ActiveSessionMessage::ValidMessage {
                message: PeerMessage::PooledTransactions(_),
                ..
            } => {}
            _ => unreachable!(),
        }
    }

    /// Sanity checks for the timeout smoothing: a steady round trip keeps the timeout
    /// unchanged, and shorter or longer samples move it in the matching direction by less
    /// than a factor of two.
    #[test]
    fn timeout_calculation_sanity_tests() {
        let rtt = Duration::from_secs(5);
        let timeout = rtt * TIMEOUT_SCALING;

        assert_eq!(calculate_new_timeout(timeout, rtt), timeout);

        assert!(calculate_new_timeout(timeout, rtt / 2) < timeout);
        assert!(calculate_new_timeout(timeout, rtt / 2) > timeout / 2);
        assert!(calculate_new_timeout(timeout, rtt * 2) > timeout);
        assert!(calculate_new_timeout(timeout, rtt * 2) < timeout * 2);
    }
}