1use core::sync::atomic::Ordering;
4use std::{
5 collections::VecDeque,
6 future::Future,
7 net::SocketAddr,
8 pin::Pin,
9 sync::{atomic::AtomicU64, Arc},
10 task::{ready, Context, Poll},
11 time::{Duration, Instant},
12};
13
14use crate::{
15 message::{NewBlockMessage, PeerMessage, PeerResponse, PeerResponseResult},
16 session::{
17 conn::EthRlpxConnection,
18 handle::{ActiveSessionMessage, SessionCommand},
19 SessionId,
20 },
21};
22use alloy_primitives::Sealable;
23use futures::{stream::Fuse, SinkExt, StreamExt};
24use metrics::Gauge;
25use reth_eth_wire::{
26 capability::RawCapabilityMessage,
27 errors::{EthHandshakeError, EthStreamError, P2PStreamError},
28 message::{EthBroadcastMessage, RequestPair},
29 Capabilities, DisconnectP2P, DisconnectReason, EthMessage, NetworkPrimitives,
30};
31use reth_metrics::common::mpsc::MeteredPollSender;
32use reth_network_api::PeerRequest;
33use reth_network_p2p::error::RequestError;
34use reth_network_peers::PeerId;
35use reth_network_types::session::config::INITIAL_REQUEST_TIMEOUT;
36use reth_primitives_traits::Block;
37use rustc_hash::FxHashMap;
38use tokio::{
39 sync::{mpsc::error::TrySendError, oneshot},
40 time::Interval,
41};
42use tokio_stream::wrappers::ReceiverStream;
43use tokio_util::sync::PollSender;
44use tracing::{debug, trace};
45
/// Lower bound that `calculate_new_timeout` clamps the adaptive request timeout to.
const MINIMUM_TIMEOUT: Duration = Duration::from_secs(2);
/// Upper bound that `calculate_new_timeout` clamps the adaptive request timeout to.
const MAXIMUM_TIMEOUT: Duration = INITIAL_REQUEST_TIMEOUT;
/// Weight given to a fresh round-trip sample when smoothing the request timeout; the previous
/// timeout keeps the remaining `1.0 - SAMPLE_IMPACT`.
const SAMPLE_IMPACT: f64 = 0.1;
/// Factor a measured round-trip time is multiplied by to turn it into a timeout contribution.
const TIMEOUT_SCALING: u32 = 3;
56
/// State of a single established peer session, driven by its [`Future`] implementation.
///
/// It owns the connection to the remote peer, buffers messages in both directions and keeps
/// track of requests that are still waiting for a response.
#[allow(dead_code)]
pub(crate) struct ActiveSession<N: NetworkPrimitives> {
    /// Counter returned (and then incremented) by `next_id()` to label outgoing requests.
    pub(crate) next_id: u64,
    /// Connection to the remote peer; used as both the message sink and the message stream.
    pub(crate) conn: EthRlpxConnection<N>,
    /// Identifier of the remote peer, attached to every message sent to the session manager.
    pub(crate) remote_peer_id: PeerId,
    /// Address of the remote peer, attached to disconnect/close notifications.
    pub(crate) remote_addr: SocketAddr,
    /// Capabilities of the remote peer (not read by the code in this file).
    pub(crate) remote_capabilities: Arc<Capabilities>,
    /// Identifier of this session (not read by the code in this file).
    pub(crate) session_id: SessionId,
    /// Stream of [`SessionCommand`]s; when it ends the session future completes.
    pub(crate) commands_rx: ReceiverStream<SessionCommand<N>>,
    /// Sender used to deliver [`ActiveSessionMessage`]s to the session manager.
    pub(crate) to_session_manager: MeteredPollSender<ActiveSessionMessage<N>>,
    /// A message that could not be delivered yet and is retried on a later poll.
    pub(crate) pending_message_to_session: Option<ActiveSessionMessage<N>>,
    /// Stream of [`PeerRequest`]s to send to the remote peer (despite the `_tx` suffix this is
    /// the receiving side that the session polls).
    pub(crate) internal_request_tx: Fuse<ReceiverStream<PeerRequest<N>>>,
    /// Requests that were sent to the remote peer, keyed by request id, until a response with
    /// the same id arrives.
    pub(crate) inflight_requests: FxHashMap<u64, InflightRequest<PeerRequest<N>>>,
    /// Requests received from the remote peer whose responses are still being awaited.
    pub(crate) received_requests_from_remote: Vec<ReceivedRequest<N>>,
    /// Messages waiting to be written to `conn`.
    pub(crate) queued_outgoing: QueuedOutgoingMessages<N>,
    /// Current request timeout in milliseconds, stored in a shared atomic.
    pub(crate) internal_request_timeout: Arc<AtomicU64>,
    /// Interval on whose ticks the in-flight requests are checked for timeouts.
    pub(crate) internal_request_timeout_interval: Interval,
    /// If an already timed-out request was sent longer ago than this, `check_timed_out_requests`
    /// reports a protocol breach.
    pub(crate) protocol_breach_request_timeout: Duration,
    /// Final message (together with the sender to deliver it through) that must be handed off
    /// before the session future resolves.
    pub(crate) terminate_message:
        Option<(PollSender<ActiveSessionMessage<N>>, ActiveSessionMessage<N>)>,
}
105
106impl<N: NetworkPrimitives> ActiveSession<N> {
107 fn is_disconnecting(&self) -> bool {
109 self.conn.inner().is_disconnecting()
110 }
111
112 fn next_id(&mut self) -> u64 {
114 let id = self.next_id;
115 self.next_id += 1;
116 id
117 }
118
119 pub fn shrink_to_fit(&mut self) {
121 self.received_requests_from_remote.shrink_to_fit();
122 self.queued_outgoing.shrink_to_fit();
123 }
124
    /// Handles a message received from the remote peer.
    ///
    /// Broadcast-style messages are forwarded to the session manager, requests are registered
    /// and forwarded, and responses are matched against `inflight_requests`.
    ///
    /// Returns `Ok` when the message was handled, `BadMessage` when it violates the protocol
    /// (a `Status` message, or mismatching field lengths in `NewPooledTransactionHashes68`),
    /// and `NoCapacity` when forwarding to the session manager failed because the channel was
    /// full.
    fn on_incoming_message(&mut self, msg: EthMessage<N>) -> OnIncomingMessageOutcome<N> {
        // Handles a request from the remote peer: creates a oneshot channel, stores the
        // receiving half (so the response can be polled later) and forwards the sending half
        // together with the request to the session manager.
        macro_rules! on_request {
            ($req:ident, $resp_item:ident, $req_item:ident) => {{
                let RequestPair { request_id, message: request } = $req;
                let (tx, response) = oneshot::channel();
                let received = ReceivedRequest {
                    request_id,
                    rx: PeerResponse::$resp_item { response },
                    received: Instant::now(),
                };
                self.received_requests_from_remote.push(received);
                self.try_emit_request(PeerMessage::EthRequest(PeerRequest::$req_item {
                    request,
                    response: tx,
                }))
                .into()
            }};
        }

        // Handles a response from the remote peer by resolving the matching in-flight request.
        macro_rules! on_response {
            ($resp:ident, $item:ident) => {{
                let RequestPair { request_id, message } = $resp;
                #[allow(clippy::collapsible_match)]
                if let Some(req) = self.inflight_requests.remove(&request_id) {
                    match req.request {
                        // the response type matches the request type: deliver it
                        RequestState::Waiting(PeerRequest::$item { response, .. }) => {
                            let _ = response.send(Ok(message));
                            self.update_request_timeout(req.timestamp, Instant::now());
                        }
                        // a request with this id exists but is of a different kind
                        RequestState::Waiting(request) => {
                            request.send_bad_response();
                        }
                        // the request already timed out locally; the late response is still
                        // used as a round-trip sample
                        RequestState::TimedOut => {
                            self.update_request_timeout(req.timestamp, Instant::now());
                        }
                    }
                } else {
                    // no in-flight request with this id exists
                    self.on_bad_message();
                }

                OnIncomingMessageOutcome::Ok
            }};
        }

        match msg {
            // a `Status` message is rejected here with `StatusNotInHandshake`
            message @ EthMessage::Status(_) => OnIncomingMessageOutcome::BadMessage {
                error: EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake),
                message,
            },
            EthMessage::NewBlockHashes(msg) => {
                self.try_emit_broadcast(PeerMessage::NewBlockHashes(msg)).into()
            }
            EthMessage::NewBlock(msg) => {
                let block =
                    NewBlockMessage { hash: msg.block.header().hash_slow(), block: Arc::new(*msg) };
                self.try_emit_broadcast(PeerMessage::NewBlock(block)).into()
            }
            EthMessage::Transactions(msg) => {
                self.try_emit_broadcast(PeerMessage::ReceivedTransaction(msg)).into()
            }
            EthMessage::NewPooledTransactionHashes66(msg) => {
                self.try_emit_broadcast(PeerMessage::PooledTransactions(msg.into())).into()
            }
            EthMessage::NewPooledTransactionHashes68(msg) => {
                // all three parallel lists must have the same length
                if msg.hashes.len() != msg.types.len() || msg.hashes.len() != msg.sizes.len() {
                    return OnIncomingMessageOutcome::BadMessage {
                        error: EthStreamError::TransactionHashesInvalidLenOfFields {
                            hashes_len: msg.hashes.len(),
                            types_len: msg.types.len(),
                            sizes_len: msg.sizes.len(),
                        },
                        message: EthMessage::NewPooledTransactionHashes68(msg),
                    }
                }
                self.try_emit_broadcast(PeerMessage::PooledTransactions(msg.into())).into()
            }
            EthMessage::GetBlockHeaders(req) => {
                on_request!(req, BlockHeaders, GetBlockHeaders)
            }
            EthMessage::BlockHeaders(resp) => {
                on_response!(resp, GetBlockHeaders)
            }
            EthMessage::GetBlockBodies(req) => {
                on_request!(req, BlockBodies, GetBlockBodies)
            }
            EthMessage::BlockBodies(resp) => {
                on_response!(resp, GetBlockBodies)
            }
            EthMessage::GetPooledTransactions(req) => {
                on_request!(req, PooledTransactions, GetPooledTransactions)
            }
            EthMessage::PooledTransactions(resp) => {
                on_response!(resp, GetPooledTransactions)
            }
            EthMessage::GetNodeData(req) => {
                on_request!(req, NodeData, GetNodeData)
            }
            EthMessage::NodeData(resp) => {
                on_response!(resp, GetNodeData)
            }
            EthMessage::GetReceipts(req) => {
                on_request!(req, Receipts, GetReceipts)
            }
            EthMessage::Receipts(resp) => {
                on_response!(resp, GetReceipts)
            }
        }
    }
242
243 fn on_internal_peer_request(&mut self, request: PeerRequest<N>, deadline: Instant) {
245 let request_id = self.next_id();
246 let msg = request.create_request_message(request_id);
247 self.queued_outgoing.push_back(msg.into());
248 let req = InflightRequest {
249 request: RequestState::Waiting(request),
250 timestamp: Instant::now(),
251 deadline,
252 };
253 self.inflight_requests.insert(request_id, req);
254 }
255
256 fn on_internal_peer_message(&mut self, msg: PeerMessage<N>) {
258 match msg {
259 PeerMessage::NewBlockHashes(msg) => {
260 self.queued_outgoing.push_back(EthMessage::NewBlockHashes(msg).into());
261 }
262 PeerMessage::NewBlock(msg) => {
263 self.queued_outgoing.push_back(EthBroadcastMessage::NewBlock(msg.block).into());
264 }
265 PeerMessage::PooledTransactions(msg) => {
266 if msg.is_valid_for_version(self.conn.version()) {
267 self.queued_outgoing.push_back(EthMessage::from(msg).into());
268 }
269 }
270 PeerMessage::EthRequest(req) => {
271 let deadline = self.request_deadline();
272 self.on_internal_peer_request(req, deadline);
273 }
274 PeerMessage::SendTransactions(msg) => {
275 self.queued_outgoing.push_back(EthBroadcastMessage::Transactions(msg).into());
276 }
277 PeerMessage::ReceivedTransaction(_) => {
278 unreachable!("Not emitted by network")
279 }
280 PeerMessage::Other(other) => {
281 debug!(target: "net::session", message_id=%other.id, "Ignoring unsupported message");
282 self.queued_outgoing.push_back(OutgoingMessage::Raw(other));
283 }
284 }
285 }
286
287 fn request_deadline(&self) -> Instant {
289 Instant::now() +
290 Duration::from_millis(self.internal_request_timeout.load(Ordering::Relaxed))
291 }
292
293 fn handle_outgoing_response(&mut self, id: u64, resp: PeerResponseResult<N>) {
297 match resp.try_into_message(id) {
298 Ok(msg) => {
299 self.queued_outgoing.push_back(msg.into());
300 }
301 Err(err) => {
302 debug!(target: "net", %err, "Failed to respond to received request");
303 }
304 }
305 }
306
307 #[allow(clippy::result_large_err)]
311 fn try_emit_broadcast(&self, message: PeerMessage<N>) -> Result<(), ActiveSessionMessage<N>> {
312 let Some(sender) = self.to_session_manager.inner().get_ref() else { return Ok(()) };
313
314 match sender
315 .try_send(ActiveSessionMessage::ValidMessage { peer_id: self.remote_peer_id, message })
316 {
317 Ok(_) => Ok(()),
318 Err(err) => {
319 trace!(
320 target: "net",
321 %err,
322 "no capacity for incoming broadcast",
323 );
324 match err {
325 TrySendError::Full(msg) => Err(msg),
326 TrySendError::Closed(_) => Ok(()),
327 }
328 }
329 }
330 }
331
332 #[allow(clippy::result_large_err)]
337 fn try_emit_request(&self, message: PeerMessage<N>) -> Result<(), ActiveSessionMessage<N>> {
338 let Some(sender) = self.to_session_manager.inner().get_ref() else { return Ok(()) };
339
340 match sender
341 .try_send(ActiveSessionMessage::ValidMessage { peer_id: self.remote_peer_id, message })
342 {
343 Ok(_) => Ok(()),
344 Err(err) => {
345 trace!(
346 target: "net",
347 %err,
348 "no capacity for incoming request",
349 );
350 match err {
351 TrySendError::Full(msg) => Err(msg),
352 TrySendError::Closed(_) => {
353 Ok(())
356 }
357 }
358 }
359 }
360 }
361
362 fn on_bad_message(&self) {
364 let Some(sender) = self.to_session_manager.inner().get_ref() else { return };
365 let _ = sender.try_send(ActiveSessionMessage::BadMessage { peer_id: self.remote_peer_id });
366 }
367
368 fn emit_disconnect(&mut self, cx: &mut Context<'_>) -> Poll<()> {
370 trace!(target: "net::session", remote_peer_id=?self.remote_peer_id, "emitting disconnect");
371 let msg = ActiveSessionMessage::Disconnected {
372 peer_id: self.remote_peer_id,
373 remote_addr: self.remote_addr,
374 };
375
376 self.terminate_message = Some((self.to_session_manager.inner().clone(), msg));
377 self.poll_terminate_message(cx).expect("message is set")
378 }
379
380 fn close_on_error(&mut self, error: EthStreamError, cx: &mut Context<'_>) -> Poll<()> {
382 let msg = ActiveSessionMessage::ClosedOnConnectionError {
383 peer_id: self.remote_peer_id,
384 remote_addr: self.remote_addr,
385 error,
386 };
387 self.terminate_message = Some((self.to_session_manager.inner().clone(), msg));
388 self.poll_terminate_message(cx).expect("message is set")
389 }
390
391 fn start_disconnect(&mut self, reason: DisconnectReason) -> Result<(), EthStreamError> {
393 self.conn
394 .inner_mut()
395 .start_disconnect(reason)
396 .map_err(P2PStreamError::from)
397 .map_err(Into::into)
398 }
399
400 fn poll_disconnect(&mut self, cx: &mut Context<'_>) -> Poll<()> {
402 debug_assert!(self.is_disconnecting(), "not disconnecting");
403
404 let _ = ready!(self.conn.poll_close_unpin(cx));
406 self.emit_disconnect(cx)
407 }
408
409 fn try_disconnect(&mut self, reason: DisconnectReason, cx: &mut Context<'_>) -> Poll<()> {
411 match self.start_disconnect(reason) {
412 Ok(()) => {
413 self.poll_disconnect(cx)
415 }
416 Err(err) => {
417 debug!(target: "net::session", %err, remote_peer_id=?self.remote_peer_id, "could not send disconnect");
418 self.close_on_error(err, cx)
419 }
420 }
421 }
422
423 #[must_use]
432 fn check_timed_out_requests(&mut self, now: Instant) -> bool {
433 for (id, req) in &mut self.inflight_requests {
434 if req.is_timed_out(now) {
435 if req.is_waiting() {
436 debug!(target: "net::session", ?id, remote_peer_id=?self.remote_peer_id, "timed out outgoing request");
437 req.timeout();
438 } else if now - req.timestamp > self.protocol_breach_request_timeout {
439 return true
440 }
441 }
442 }
443
444 false
445 }
446
447 fn update_request_timeout(&mut self, sent: Instant, received: Instant) {
449 let elapsed = received.saturating_duration_since(sent);
450
451 let current = Duration::from_millis(self.internal_request_timeout.load(Ordering::Relaxed));
452 let request_timeout = calculate_new_timeout(current, elapsed);
453 self.internal_request_timeout.store(request_timeout.as_millis() as u64, Ordering::Relaxed);
454 self.internal_request_timeout_interval = tokio::time::interval(request_timeout);
455 }
456
457 fn poll_terminate_message(&mut self, cx: &mut Context<'_>) -> Option<Poll<()>> {
459 let (mut tx, msg) = self.terminate_message.take()?;
460 match tx.poll_reserve(cx) {
461 Poll::Pending => {
462 self.terminate_message = Some((tx, msg));
463 return Some(Poll::Pending)
464 }
465 Poll::Ready(Ok(())) => {
466 let _ = tx.send_item(msg);
467 }
468 Poll::Ready(Err(_)) => {
469 }
471 }
472 Some(Poll::Ready(()))
474 }
475}
476
/// Drives the session: processes commands, sends queued messages, reads messages from the
/// connection and checks for timed-out requests. The future resolves when the session ends.
impl<N: NetworkPrimitives> Future for ActiveSession<N> {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();

        // if a terminate message is set, delivering it is the only thing left to do
        if let Some(terminate) = this.poll_terminate_message(cx) {
            return terminate
        }

        // a disconnect is already in progress: keep driving it
        if this.is_disconnecting() {
            return this.poll_disconnect(cx)
        }

        // Budget for the receive loop below; when it is used up, the task wakes itself and
        // leaves the main loop instead of continuing to read.
        let mut budget = 4;

        // The main loop; it repeats as long as an iteration made progress.
        'main: loop {
            let mut progress = false;

            // commands are drained first
            loop {
                match this.commands_rx.poll_next_unpin(cx) {
                    Poll::Pending => break,
                    Poll::Ready(None) => {
                        // the command channel is closed: end the session
                        return Poll::Ready(())
                    }
                    Poll::Ready(Some(cmd)) => {
                        progress = true;
                        match cmd {
                            SessionCommand::Disconnect { reason } => {
                                debug!(
                                    target: "net::session",
                                    ?reason,
                                    remote_peer_id=?this.remote_peer_id,
                                    "Received disconnect command for session"
                                );
                                let reason =
                                    reason.unwrap_or(DisconnectReason::DisconnectRequested);

                                return this.try_disconnect(reason, cx)
                            }
                            SessionCommand::Message(msg) => {
                                this.on_internal_peer_message(msg);
                            }
                        }
                    }
                }
            }

            // all requests picked up in this iteration share the same deadline
            let deadline = this.request_deadline();

            while let Poll::Ready(Some(req)) = this.internal_request_tx.poll_next_unpin(cx) {
                progress = true;
                this.on_internal_peer_request(req, deadline);
            }

            // Poll the pending responses to requests of the remote peer. Each entry is removed
            // and pushed back if it is not ready; iterating in reverse keeps the not yet
            // visited indices valid despite `swap_remove`.
            for idx in (0..this.received_requests_from_remote.len()).rev() {
                let mut req = this.received_requests_from_remote.swap_remove(idx);
                match req.rx.poll(cx) {
                    Poll::Pending => {
                        // not ready yet
                        this.received_requests_from_remote.push(req);
                    }
                    Poll::Ready(resp) => {
                        this.handle_outgoing_response(req.request_id, resp);
                    }
                }
            }

            // write queued messages while the connection is ready to accept them
            while this.conn.poll_ready_unpin(cx).is_ready() {
                if let Some(msg) = this.queued_outgoing.pop_front() {
                    progress = true;
                    let res = match msg {
                        OutgoingMessage::Eth(msg) => this.conn.start_send_unpin(msg),
                        OutgoingMessage::Broadcast(msg) => this.conn.start_send_broadcast(msg),
                        OutgoingMessage::Raw(msg) => this.conn.start_send_raw(msg),
                    };
                    if let Err(err) = res {
                        debug!(target: "net::session", %err, remote_peer_id=?this.remote_peer_id, "failed to send message");
                        // close the session and report the error
                        return this.close_on_error(err, cx)
                    }
                } else {
                    // nothing left to send
                    break
                }
            }

            // read messages from the connection
            'receive: loop {
                // stop once the budget is used up
                budget -= 1;
                if budget == 0 {
                    // make sure the task is polled again
                    cx.waker().wake_by_ref();
                    break 'main
                }

                // First retry the message that could not be delivered earlier; nothing new is
                // read from the connection while it is still stuck.
                if let Some(msg) = this.pending_message_to_session.take() {
                    match this.to_session_manager.poll_reserve(cx) {
                        Poll::Ready(Ok(_)) => {
                            let _ = this.to_session_manager.send_item(msg);
                        }
                        Poll::Ready(Err(_)) => return Poll::Ready(()),
                        Poll::Pending => {
                            this.pending_message_to_session = Some(msg);
                            break 'receive
                        }
                    };
                }

                match this.conn.poll_next_unpin(cx) {
                    Poll::Pending => break,
                    Poll::Ready(None) => {
                        if this.is_disconnecting() {
                            break
                        }
                        debug!(target: "net::session", remote_peer_id=?this.remote_peer_id, "eth stream completed");
                        return this.emit_disconnect(cx)
                    }
                    Poll::Ready(Some(res)) => {
                        match res {
                            Ok(msg) => {
                                trace!(target: "net::session", msg_id=?msg.message_id(), remote_peer_id=?this.remote_peer_id, "received eth message");
                                match this.on_incoming_message(msg) {
                                    OnIncomingMessageOutcome::Ok => {
                                        // handled successfully
                                        progress = true;
                                    }
                                    OnIncomingMessageOutcome::BadMessage { error, message } => {
                                        debug!(target: "net::session", %error, msg=?message, remote_peer_id=?this.remote_peer_id, "received invalid protocol message");
                                        return this.close_on_error(error, cx)
                                    }
                                    OnIncomingMessageOutcome::NoCapacity(msg) => {
                                        // could not be delivered: park it and retry at the top
                                        // of the receive loop
                                        this.pending_message_to_session = Some(msg);
                                        continue 'receive
                                    }
                                }
                            }
                            Err(err) => {
                                debug!(target: "net::session", %err, remote_peer_id=?this.remote_peer_id, "failed to receive message");
                                return this.close_on_error(err, cx)
                            }
                        }
                    }
                }
            }

            if !progress {
                break 'main
            }
        }

        while this.internal_request_timeout_interval.poll_tick(cx).is_ready() {
            // check for timed out requests
            if this.check_timed_out_requests(Instant::now()) {
                // NOTE(review): the breach message is only parked here and gets sent by the
                // receive loop of a later poll; it also replaces whatever message might still
                // be parked in `pending_message_to_session` — confirm this cannot drop one.
                if let Poll::Ready(Ok(_)) = this.to_session_manager.poll_reserve(cx) {
                    let msg = ActiveSessionMessage::ProtocolBreach { peer_id: this.remote_peer_id };
                    this.pending_message_to_session = Some(msg);
                }
            }
        }

        this.shrink_to_fit();

        Poll::Pending
    }
}
662
/// A request received from the remote peer whose response is still pending.
pub(crate) struct ReceivedRequest<N: NetworkPrimitives> {
    /// Id of the request as sent by the remote peer; the response is sent with the same id.
    request_id: u64,
    /// Receiving side that is polled for the response to this request.
    rx: PeerResponse<N>,
    /// Instant at which the request was received (currently not read).
    #[allow(dead_code)]
    received: Instant,
}
673
/// A request that was sent to the remote peer and has not been answered yet.
pub(crate) struct InflightRequest<R> {
    /// The request itself while it is waiting, or a marker that it already timed out.
    request: RequestState<R>,
    /// Instant at which the request was queued for sending.
    timestamp: Instant,
    /// Instant after which the request counts as timed out.
    deadline: Instant,
}
683
684impl<N: NetworkPrimitives> InflightRequest<PeerRequest<N>> {
687 #[inline]
689 fn is_timed_out(&self, now: Instant) -> bool {
690 now > self.deadline
691 }
692
693 #[inline]
695 const fn is_waiting(&self) -> bool {
696 matches!(self.request, RequestState::Waiting(_))
697 }
698
699 fn timeout(&mut self) {
701 let mut req = RequestState::TimedOut;
702 std::mem::swap(&mut self.request, &mut req);
703
704 if let RequestState::Waiting(req) = req {
705 req.send_err_response(RequestError::Timeout);
706 }
707 }
708}
709
/// Outcome of handling a message received from the remote peer.
enum OnIncomingMessageOutcome<N: NetworkPrimitives> {
    /// The message was handled.
    Ok,
    /// The message violates the protocol; carries the error and the offending message.
    BadMessage { error: EthStreamError, message: EthMessage<N> },
    /// The resulting message could not be delivered for lack of capacity; carries the
    /// undelivered message so that it can be retried.
    NoCapacity(ActiveSessionMessage<N>),
}
719
720impl<N: NetworkPrimitives> From<Result<(), ActiveSessionMessage<N>>>
721 for OnIncomingMessageOutcome<N>
722{
723 fn from(res: Result<(), ActiveSessionMessage<N>>) -> Self {
724 match res {
725 Ok(_) => Self::Ok,
726 Err(msg) => Self::NoCapacity(msg),
727 }
728 }
729}
730
/// State of an in-flight request.
enum RequestState<R> {
    /// The request is waiting for its response; holds the request.
    Waiting(R),
    /// The request timed out; the requester has already been answered with a timeout error.
    TimedOut,
}
737
/// A message queued for sending to the remote peer.
pub(crate) enum OutgoingMessage<N: NetworkPrimitives> {
    /// A regular eth message, sent with `start_send_unpin`.
    Eth(EthMessage<N>),
    /// A broadcast message, sent with `start_send_broadcast`.
    Broadcast(EthBroadcastMessage<N>),
    /// A raw capability message, sent with `start_send_raw`.
    Raw(RawCapabilityMessage),
}
747
748impl<N: NetworkPrimitives> From<EthMessage<N>> for OutgoingMessage<N> {
749 fn from(value: EthMessage<N>) -> Self {
750 Self::Eth(value)
751 }
752}
753
754impl<N: NetworkPrimitives> From<EthBroadcastMessage<N>> for OutgoingMessage<N> {
755 fn from(value: EthBroadcastMessage<N>) -> Self {
756 Self::Broadcast(value)
757 }
758}
759
760#[inline]
762fn calculate_new_timeout(current_timeout: Duration, estimated_rtt: Duration) -> Duration {
763 let new_timeout = estimated_rtt.mul_f64(SAMPLE_IMPACT) * TIMEOUT_SCALING;
764
765 let smoothened_timeout = current_timeout.mul_f64(1.0 - SAMPLE_IMPACT) + new_timeout;
767
768 smoothened_timeout.clamp(MINIMUM_TIMEOUT, MAXIMUM_TIMEOUT)
769}
770
/// Queue of outgoing messages that mirrors its length changes in a metrics gauge.
pub(crate) struct QueuedOutgoingMessages<N: NetworkPrimitives> {
    /// The queued messages, in sending order.
    messages: VecDeque<OutgoingMessage<N>>,
    /// Gauge that is incremented on every push and decremented on every successful pop.
    count: Gauge,
}
776
777impl<N: NetworkPrimitives> QueuedOutgoingMessages<N> {
778 pub(crate) const fn new(metric: Gauge) -> Self {
779 Self { messages: VecDeque::new(), count: metric }
780 }
781
782 pub(crate) fn push_back(&mut self, message: OutgoingMessage<N>) {
783 self.messages.push_back(message);
784 self.count.increment(1);
785 }
786
787 pub(crate) fn pop_front(&mut self) -> Option<OutgoingMessage<N>> {
788 self.messages.pop_front().inspect(|_| self.count.decrement(1))
789 }
790
791 pub(crate) fn shrink_to_fit(&mut self) {
792 self.messages.shrink_to_fit();
793 }
794}
795
796#[cfg(test)]
797mod tests {
798 use super::*;
799 use crate::session::{handle::PendingSessionEvent, start_pending_incoming_session};
800 use reth_chainspec::MAINNET;
801 use reth_ecies::stream::ECIESStream;
802 use reth_eth_wire::{
803 EthNetworkPrimitives, EthStream, GetBlockBodies, HelloMessageWithProtocols, P2PStream,
804 Status, StatusBuilder, UnauthedEthStream, UnauthedP2PStream,
805 };
806 use reth_network_peers::pk2id;
807 use reth_network_types::session::config::PROTOCOL_BREACH_REQUEST_TIMEOUT;
808 use reth_primitives::{EthereumHardfork, ForkFilter};
809 use secp256k1::{SecretKey, SECP256K1};
810 use tokio::{
811 net::{TcpListener, TcpStream},
812 sync::mpsc,
813 };
814
815 fn eth_hello(server_key: &SecretKey) -> HelloMessageWithProtocols {
817 HelloMessageWithProtocols::builder(pk2id(&server_key.public_key(SECP256K1))).build()
818 }
819
    /// Test helper that sets up client streams and incoming [`ActiveSession`]s which all share
    /// one channel towards the (simulated) session manager.
    struct SessionBuilder<N: NetworkPrimitives = EthNetworkPrimitives> {
        /// Unused placeholder for remote capabilities.
        _remote_capabilities: Arc<Capabilities>,
        /// Sender handed (as a clone) to every session created by this builder.
        active_session_tx: mpsc::Sender<ActiveSessionMessage<N>>,
        /// Receiving side on which tests observe the messages emitted by sessions.
        active_session_rx: ReceiverStream<ActiveSessionMessage<N>>,
        /// Command senders of all sessions created so far; kept so the command channels stay
        /// open.
        to_sessions: Vec<mpsc::Sender<SessionCommand<N>>>,
        /// Secret key used for the incoming (server) side of the handshake.
        secret_key: SecretKey,
        /// Peer id derived from the public key belonging to `secret_key`.
        local_peer_id: PeerId,
        /// Hello message used for the handshake.
        hello: HelloMessageWithProtocols,
        /// Status message used for the handshake.
        status: Status,
        /// Fork filter used for the handshake.
        fork_filter: ForkFilter,
        /// Counter for the next session id.
        next_id: usize,
    }
832
    impl<N: NetworkPrimitives> SessionBuilder<N> {
        /// Returns the next session id and advances the counter.
        fn next_id(&mut self) -> SessionId {
            let id = self.next_id;
            self.next_id += 1;
            SessionId(id)
        }

        /// Returns a future that connects to `local_addr` as a client, performs the ECIES,
        /// p2p and eth handshakes and then hands the authenticated stream to `f`.
        fn with_client_stream<F, O>(
            &self,
            local_addr: SocketAddr,
            f: F,
        ) -> Pin<Box<dyn Future<Output = ()> + Send>>
        where
            F: FnOnce(EthStream<P2PStream<ECIESStream<TcpStream>>, N>) -> O + Send + 'static,
            O: Future<Output = ()> + Send + Sync,
        {
            let status = self.status;
            let fork_filter = self.fork_filter.clone();
            let local_peer_id = self.local_peer_id;
            let mut hello = self.hello.clone();
            // the client uses a fresh random key and announces the matching id
            let key = SecretKey::new(&mut rand::thread_rng());
            hello.id = pk2id(&key.public_key(SECP256K1));
            Box::pin(async move {
                let outgoing = TcpStream::connect(local_addr).await.unwrap();
                let sink = ECIESStream::connect(outgoing, key, local_peer_id).await.unwrap();

                let (p2p_stream, _) = UnauthedP2PStream::new(sink).handshake(hello).await.unwrap();

                let (client_stream, _) = UnauthedEthStream::new(p2p_stream)
                    .handshake(status, fork_filter)
                    .await
                    .unwrap();
                f(client_stream).await
            })
        }

        /// Runs the pending-session handshake for an accepted `stream` and, once it is
        /// established, wraps the connection in an [`ActiveSession`].
        ///
        /// Panics if the first pending-session event is anything but `Established`.
        async fn connect_incoming(&mut self, stream: TcpStream) -> ActiveSession<N> {
            // NOTE(review): this uses the local address of the accepted stream as the remote
            // address — presumably good enough for these tests; confirm.
            let remote_addr = stream.local_addr().unwrap();
            let session_id = self.next_id();
            let (_disconnect_tx, disconnect_rx) = oneshot::channel();
            let (pending_sessions_tx, pending_sessions_rx) = mpsc::channel(1);

            tokio::task::spawn(start_pending_incoming_session(
                disconnect_rx,
                session_id,
                stream,
                pending_sessions_tx,
                remote_addr,
                self.secret_key,
                self.hello.clone(),
                self.status,
                self.fork_filter.clone(),
                Default::default(),
            ));

            let mut stream = ReceiverStream::new(pending_sessions_rx);

            match stream.next().await.unwrap() {
                PendingSessionEvent::Established {
                    session_id,
                    remote_addr,
                    peer_id,
                    capabilities,
                    conn,
                    ..
                } => {
                    let (_to_session_tx, messages_rx) = mpsc::channel(10);
                    let (commands_to_session, commands_rx) = mpsc::channel(10);
                    let poll_sender = PollSender::new(self.active_session_tx.clone());

                    // keep the command sender so that the session's command channel stays open
                    self.to_sessions.push(commands_to_session);

                    ActiveSession {
                        next_id: 0,
                        remote_peer_id: peer_id,
                        remote_addr,
                        remote_capabilities: Arc::clone(&capabilities),
                        session_id,
                        commands_rx: ReceiverStream::new(commands_rx),
                        to_session_manager: MeteredPollSender::new(
                            poll_sender,
                            "network_active_session",
                        ),
                        pending_message_to_session: None,
                        internal_request_tx: ReceiverStream::new(messages_rx).fuse(),
                        inflight_requests: Default::default(),
                        conn,
                        queued_outgoing: QueuedOutgoingMessages::new(Gauge::noop()),
                        received_requests_from_remote: Default::default(),
                        internal_request_timeout_interval: tokio::time::interval(
                            INITIAL_REQUEST_TIMEOUT,
                        ),
                        internal_request_timeout: Arc::new(AtomicU64::new(
                            INITIAL_REQUEST_TIMEOUT.as_millis() as u64,
                        )),
                        protocol_breach_request_timeout: PROTOCOL_BREACH_REQUEST_TIMEOUT,
                        terminate_message: None,
                    }
                }
                ev => {
                    panic!("unexpected message {ev:?}")
                }
            }
        }
    }
939
940 impl Default for SessionBuilder {
941 fn default() -> Self {
942 let (active_session_tx, active_session_rx) = mpsc::channel(100);
943
944 let (secret_key, pk) = SECP256K1.generate_keypair(&mut rand::thread_rng());
945 let local_peer_id = pk2id(&pk);
946
947 Self {
948 next_id: 0,
949 _remote_capabilities: Arc::new(Capabilities::from(vec![])),
950 active_session_tx,
951 active_session_rx: ReceiverStream::new(active_session_rx),
952 to_sessions: vec![],
953 hello: eth_hello(&secret_key),
954 secret_key,
955 local_peer_id,
956 status: StatusBuilder::default().build(),
957 fork_filter: MAINNET
958 .hardfork_fork_filter(EthereumHardfork::Frontier)
959 .expect("The Frontier fork filter should exist on mainnet"),
960 }
961 }
962 }
963
    /// Starts a disconnect on the session side and checks that the client observes an error
    /// carrying the same disconnect reason.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_disconnect() {
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let expected_disconnect = DisconnectReason::UselessPeer;

        let fut = builder.with_client_stream(local_addr, move |mut client_stream| async move {
            // the first item the client reads must be the disconnect, surfaced as an error
            let msg = client_stream.next().await.unwrap().unwrap_err();
            assert_eq!(msg.as_disconnected().unwrap(), expected_disconnect);
        });

        tokio::task::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let mut session = builder.connect_incoming(incoming).await;

            session.start_disconnect(expected_disconnect).unwrap();
            session.await
        });

        fut.await;
    }
988
    /// Drops the client stream right after the handshake; the test passes once the session
    /// future has completed.
    #[tokio::test(flavor = "multi_thread")]
    async fn handle_dropped_stream() {
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let fut = builder.with_client_stream(local_addr, move |client_stream| async move {
            drop(client_stream);
            tokio::time::sleep(Duration::from_secs(1)).await
        });

        // signals that the session future has resolved
        let (tx, rx) = oneshot::channel();

        tokio::task::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let session = builder.connect_incoming(incoming).await;
            session.await;

            tx.send(()).unwrap();
        });

        tokio::task::spawn(fut);

        rx.await.unwrap();
    }
1015
    /// Lets the client send 100 empty `NewPooledTransactionHashes66` messages and then finish
    /// (which drops its stream); the test passes once the session future has completed.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_send_many_messages() {
        reth_tracing::init_test_tracing();
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let num_messages = 100;

        let fut = builder.with_client_stream(local_addr, move |mut client_stream| async move {
            for _ in 0..num_messages {
                client_stream
                    .send(EthMessage::NewPooledTransactionHashes66(Vec::new().into()))
                    .await
                    .unwrap();
            }
        });

        // signals that the session future has resolved
        let (tx, rx) = oneshot::channel();

        tokio::task::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let session = builder.connect_incoming(incoming).await;
            session.await;

            tx.send(()).unwrap();
        });

        tokio::task::spawn(fut);

        rx.await.unwrap();
    }
1049
    /// Sends a request to a client that never answers and checks that the request first fails
    /// with `RequestError::Timeout` and that the session then reports a `ProtocolBreach`.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_request_timeout() {
        reth_tracing::init_test_tracing();

        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let request_timeout = Duration::from_millis(100);
        let drop_timeout = Duration::from_millis(1500);

        let fut = builder.with_client_stream(local_addr, move |client_stream| async move {
            // keep the stream alive without ever reading from or writing to it
            let _client_stream = client_stream;
            tokio::time::sleep(drop_timeout * 60).await;
        });
        tokio::task::spawn(fut);

        let (incoming, _) = listener.accept().await.unwrap();
        let mut session = builder.connect_incoming(incoming).await;
        // shorten all timeouts so that the test finishes quickly
        session
            .internal_request_timeout
            .store(request_timeout.as_millis() as u64, Ordering::Relaxed);
        session.protocol_breach_request_timeout = drop_timeout;
        session.internal_request_timeout_interval =
            tokio::time::interval_at(tokio::time::Instant::now(), request_timeout);
        let (tx, rx) = oneshot::channel();
        let req = PeerRequest::GetBlockBodies { request: GetBlockBodies(vec![]), response: tx };
        // a deadline of "now" makes the request count as timed out at the next check
        session.on_internal_peer_request(req, Instant::now());
        tokio::spawn(session);

        let err = rx.await.unwrap().unwrap_err();
        assert_eq!(err, RequestError::Timeout);

        // the next message from the session must be the protocol breach
        let msg = builder.active_session_rx.next().await.unwrap();
        match msg {
            ActiveSessionMessage::ProtocolBreach { .. } => {}
            ev => unreachable!("{ev:?}"),
        }
    }
1091
    /// Lets the client wait up to five seconds for a message before it disconnects; the test
    /// passes once the session future has completed.
    ///
    /// Presumably this checks that the session stays alive while idle — confirm the intent.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_keep_alive() {
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let fut = builder.with_client_stream(local_addr, move |mut client_stream| async move {
            let _ = tokio::time::timeout(Duration::from_secs(5), client_stream.next()).await;
            client_stream.into_inner().disconnect(DisconnectReason::UselessPeer).await.unwrap();
        });

        // signals that the session future has resolved
        let (tx, rx) = oneshot::channel();

        tokio::task::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let session = builder.connect_incoming(incoming).await;
            session.await;

            tx.send(()).unwrap();
        });

        tokio::task::spawn(fut);

        rx.await.unwrap();
    }
1118
    /// Fills the channel to the session manager completely before the session runs and checks
    /// that a message received from the client is still delivered once the channel is drained.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_send_at_capacity() {
        let mut builder = SessionBuilder::default();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();

        let fut = builder.with_client_stream(local_addr, move |mut client_stream| async move {
            client_stream
                .send(EthMessage::NewPooledTransactionHashes68(Default::default()))
                .await
                .unwrap();
            // keep the connection open while the test runs
            let _ = tokio::time::timeout(Duration::from_secs(100), client_stream.next()).await;
        });
        tokio::task::spawn(fut);

        let (incoming, _) = listener.accept().await.unwrap();
        let session = builder.connect_incoming(incoming).await;

        // fill the channel until it rejects further messages
        let mut num_fill_messages = 0;
        loop {
            if builder
                .active_session_tx
                .try_send(ActiveSessionMessage::ProtocolBreach { peer_id: PeerId::random() })
                .is_err()
            {
                break
            }
            num_fill_messages += 1;
        }

        tokio::task::spawn(async move {
            session.await;
        });

        // give the session time to receive the client's message while the channel is full
        tokio::time::sleep(Duration::from_millis(100)).await;

        // drain the filler messages
        for _ in 0..num_fill_messages {
            let message = builder.active_session_rx.next().await.unwrap();
            match message {
                ActiveSessionMessage::ProtocolBreach { .. } => {}
                ev => unreachable!("{ev:?}"),
            }
        }

        // now the client's message must come through
        let message = builder.active_session_rx.next().await.unwrap();
        match message {
            ActiveSessionMessage::ValidMessage {
                message: PeerMessage::PooledTransactions(_),
                ..
            } => {}
            _ => unreachable!(),
        }
    }
1175
1176 #[test]
1177 fn timeout_calculation_sanity_tests() {
1178 let rtt = Duration::from_secs(5);
1179 let timeout = rtt * TIMEOUT_SCALING;
1181
1182 assert_eq!(calculate_new_timeout(timeout, rtt), timeout);
1184
1185 assert!(calculate_new_timeout(timeout, rtt / 2) < timeout);
1187 assert!(calculate_new_timeout(timeout, rtt / 2) > timeout / 2);
1188 assert!(calculate_new_timeout(timeout, rtt * 2) > timeout);
1189 assert!(calculate_new_timeout(timeout, rtt * 2) < timeout * 2);
1190 }
1191}