1use crate::{
4 errors::{P2PHandshakeError, P2PStreamError},
5 p2pstream::MAX_RESERVED_MESSAGE_ID,
6 protocol::{ProtoVersion, Protocol},
7 version::ParseVersionError,
8 Capability, EthMessageID, EthVersion,
9};
10use alloy_primitives::bytes::Bytes;
11use derive_more::{Deref, DerefMut};
12use reth_eth_wire_types::{EthMessage, EthNetworkPrimitives, NetworkPrimitives};
13#[cfg(feature = "serde")]
14use serde::{Deserialize, Serialize};
15use std::{
16 borrow::Cow,
17 collections::{BTreeSet, HashMap},
18};
19
/// A capability message consisting of a numeric message id and an opaque byte payload.
///
/// This type does not interpret `payload`; pairing it with a matching `id` is up to whoever
/// constructs the message.
#[derive(Debug, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RawCapabilityMessage {
    /// Identifier of the message.
    pub id: usize,
    /// Payload bytes of the message.
    pub payload: Bytes,
}
29
30impl RawCapabilityMessage {
31 pub const fn new(id: usize, payload: Bytes) -> Self {
33 Self { id, payload }
34 }
35
36 pub const fn eth(id: EthMessageID, payload: Bytes) -> Self {
42 Self::new(id as usize, payload)
43 }
44}
45
/// A message exchanged over a shared capability: either a typed `eth` message or a message in
/// raw (id + payload) form.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum CapabilityMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
    /// A typed `eth` sub-protocol message.
    // The explicit serde bound constrains `EthMessage<N>` itself instead of letting serde derive
    // bounds on the type parameter `N`.
    #[cfg_attr(
        feature = "serde",
        serde(bound = "EthMessage<N>: Serialize + serde::de::DeserializeOwned")
    )]
    Eth(EthMessage<N>),
    /// Any other message, carried as a [`RawCapabilityMessage`].
    Other(RawCapabilityMessage),
}
60
/// A capability shared with a peer, together with its message-id offset.
///
/// The offset is the first message id assigned to this capability in the connection's shared
/// message-id space; the capability occupies [`SharedCapability::num_messages`] ids from there.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SharedCapability {
    /// The `eth` capability.
    Eth {
        /// Negotiated `eth` version.
        version: EthVersion,
        /// Message id of the first `eth` message.
        offset: u8,
    },
    /// A capability other than `eth`, identified by its [`Capability`].
    UnknownCapability {
        /// The shared capability (name and version).
        cap: Capability,
        /// Message id of this capability's first message.
        offset: u8,
        /// Number of message ids this capability occupies, starting at `offset`.
        messages: u8,
    },
}
93
94impl SharedCapability {
95 pub(crate) fn new(
100 name: &str,
101 version: u8,
102 offset: u8,
103 messages: u8,
104 ) -> Result<Self, SharedCapabilityError> {
105 if offset <= MAX_RESERVED_MESSAGE_ID {
106 return Err(SharedCapabilityError::ReservedMessageIdOffset(offset))
107 }
108
109 match name {
110 "eth" => Ok(Self::eth(EthVersion::try_from(version)?, offset)),
111 _ => Ok(Self::UnknownCapability {
112 cap: Capability::new(name.to_string(), version as usize),
113 offset,
114 messages,
115 }),
116 }
117 }
118
119 pub(crate) const fn eth(version: EthVersion, offset: u8) -> Self {
121 Self::Eth { version, offset }
122 }
123
124 pub const fn capability(&self) -> Cow<'_, Capability> {
126 match self {
127 Self::Eth { version, .. } => Cow::Owned(Capability::eth(*version)),
128 Self::UnknownCapability { cap, .. } => Cow::Borrowed(cap),
129 }
130 }
131
132 #[inline]
134 pub fn name(&self) -> &str {
135 match self {
136 Self::Eth { .. } => "eth",
137 Self::UnknownCapability { cap, .. } => cap.name.as_ref(),
138 }
139 }
140
141 #[inline]
143 pub const fn is_eth(&self) -> bool {
144 matches!(self, Self::Eth { .. })
145 }
146
147 pub const fn version(&self) -> u8 {
149 match self {
150 Self::Eth { version, .. } => *version as u8,
151 Self::UnknownCapability { cap, .. } => cap.version as u8,
152 }
153 }
154
155 pub const fn eth_version(&self) -> Option<EthVersion> {
157 match self {
158 Self::Eth { version, .. } => Some(*version),
159 _ => None,
160 }
161 }
162
163 pub const fn message_id_offset(&self) -> u8 {
168 match self {
169 Self::Eth { offset, .. } | Self::UnknownCapability { offset, .. } => *offset,
170 }
171 }
172
173 pub const fn relative_message_id_offset(&self) -> u8 {
176 self.message_id_offset() - MAX_RESERVED_MESSAGE_ID - 1
177 }
178
179 pub const fn num_messages(&self) -> u8 {
181 match self {
182 Self::Eth { version: _version, .. } => EthMessageID::max() + 1,
183 Self::UnknownCapability { messages, .. } => *messages,
184 }
185 }
186}
187
/// The capabilities shared with a peer, each paired with its message-id offset.
///
/// Derefs to the underlying `Vec<SharedCapability>`. When built through
/// [`SharedCapabilities::try_new`], entries appear in order of non-decreasing offset.
#[derive(Debug, Clone, Deref, DerefMut, PartialEq, Eq)]
pub struct SharedCapabilities(Vec<SharedCapability>);
193
194impl SharedCapabilities {
195 #[inline]
197 pub fn try_new(
198 local_protocols: Vec<Protocol>,
199 peer_capabilities: Vec<Capability>,
200 ) -> Result<Self, P2PStreamError> {
201 shared_capability_offsets(local_protocols, peer_capabilities).map(Self)
202 }
203
204 #[inline]
206 pub fn iter_caps(&self) -> impl Iterator<Item = &SharedCapability> {
207 self.0.iter()
208 }
209
210 #[inline]
212 pub fn eth(&self) -> Result<&SharedCapability, P2PStreamError> {
213 self.iter_caps().find(|c| c.is_eth()).ok_or(P2PStreamError::CapabilityNotShared)
214 }
215
216 #[inline]
218 pub fn eth_version(&self) -> Result<EthVersion, P2PStreamError> {
219 self.iter_caps()
220 .find_map(SharedCapability::eth_version)
221 .ok_or(P2PStreamError::CapabilityNotShared)
222 }
223
224 #[inline]
226 pub fn contains(&self, cap: &Capability) -> bool {
227 self.find(cap).is_some()
228 }
229
230 #[inline]
232 pub fn find(&self, cap: &Capability) -> Option<&SharedCapability> {
233 self.0.iter().find(|c| c.version() == cap.version as u8 && c.name() == cap.name)
234 }
235
236 #[inline]
245 pub fn find_by_relative_offset(&self, offset: u8) -> Option<&SharedCapability> {
246 self.find_by_offset(offset.saturating_add(MAX_RESERVED_MESSAGE_ID + 1))
247 }
248
249 #[inline]
257 pub fn find_by_offset(&self, offset: u8) -> Option<&SharedCapability> {
258 let mut iter = self.0.iter();
259 let mut cap = iter.next()?;
260 if offset < cap.message_id_offset() {
261 return None
263 }
264
265 for next in iter {
266 if offset < next.message_id_offset() {
267 return Some(cap)
268 }
269 cap = next
270 }
271
272 Some(cap)
273 }
274
275 #[inline]
277 pub fn ensure_matching_capability(
278 &self,
279 cap: &Capability,
280 ) -> Result<&SharedCapability, UnsupportedCapabilityError> {
281 self.find(cap).ok_or_else(|| UnsupportedCapabilityError { capability: cap.clone() })
282 }
283
284 #[inline]
286 pub fn len(&self) -> usize {
287 self.0.len()
288 }
289
290 #[inline]
292 pub fn is_empty(&self) -> bool {
293 self.0.is_empty()
294 }
295}
296
297#[inline]
306pub fn shared_capability_offsets(
307 local_protocols: Vec<Protocol>,
308 peer_capabilities: Vec<Capability>,
309) -> Result<Vec<SharedCapability>, P2PStreamError> {
310 let our_capabilities =
312 local_protocols.into_iter().map(Protocol::split).collect::<HashMap<_, _>>();
313
314 let mut shared_capabilities: HashMap<_, ProtoVersion> = HashMap::default();
316
317 let mut shared_capability_names = BTreeSet::new();
329
330 for peer_capability in peer_capabilities {
332 if let Some(messages) = our_capabilities.get(&peer_capability).copied() {
334 if shared_capabilities
337 .get(&peer_capability.name)
338 .is_none_or(|v| peer_capability.version > v.version)
339 {
340 shared_capabilities.insert(
341 peer_capability.name.clone(),
342 ProtoVersion { version: peer_capability.version, messages },
343 );
344 shared_capability_names.insert(peer_capability.name);
345 }
346 }
347 }
348
349 if shared_capabilities.is_empty() {
351 return Err(P2PStreamError::HandshakeError(P2PHandshakeError::NoSharedCapabilities))
352 }
353
354 let mut shared_with_offsets = Vec::new();
357
358 let mut offset = MAX_RESERVED_MESSAGE_ID + 1;
362 for name in shared_capability_names {
363 let proto_version = &shared_capabilities[&name];
364 let shared_capability = SharedCapability::new(
365 &name,
366 proto_version.version as u8,
367 offset,
368 proto_version.messages,
369 )?;
370 offset += shared_capability.num_messages();
371 shared_with_offsets.push(shared_capability);
372 }
373
374 if shared_with_offsets.is_empty() {
375 return Err(P2PStreamError::HandshakeError(P2PHandshakeError::NoSharedCapabilities))
376 }
377
378 Ok(shared_with_offsets)
379}
380
/// Errors that can occur when creating a [`SharedCapability`].
#[derive(Debug, thiserror::Error)]
pub enum SharedCapabilityError {
    /// The version could not be parsed, e.g. a negotiated `eth` version that does not convert
    /// into an [`EthVersion`].
    #[error(transparent)]
    UnsupportedVersion(#[from] ParseVersionError),
    /// The given message-id offset lies in the reserved range (`<= MAX_RESERVED_MESSAGE_ID`).
    #[error("message id offset `{0}` is reserved")]
    ReservedMessageIdOffset(u8),
}
392
/// Error returned by [`SharedCapabilities::ensure_matching_capability`] when the requested
/// capability is not among the shared capabilities.
#[derive(Debug, thiserror::Error)]
#[error("unsupported capability {capability}")]
pub struct UnsupportedCapabilityError {
    /// The capability that was looked up but not found.
    capability: Capability,
}
399
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Capabilities, Capability};

    /// Builds an `eth` shared capability at the first non-reserved offset and checks that it
    /// resolves to `expected` with the given raw version.
    fn assert_first_eth(raw_version: u8, messages: u8, expected: EthVersion) {
        let first_offset = MAX_RESERVED_MESSAGE_ID + 1;
        let shared = SharedCapability::new("eth", raw_version, first_offset, messages).unwrap();

        assert_eq!(shared.name(), "eth");
        assert_eq!(shared.version(), raw_version);
        assert_eq!(shared, SharedCapability::Eth { version: expected, offset: first_offset });
    }

    #[test]
    fn from_eth_68() {
        assert_first_eth(68, 13, EthVersion::Eth68);
    }

    #[test]
    fn from_eth_67() {
        assert_first_eth(67, 13, EthVersion::Eth67);
    }

    #[test]
    fn from_eth_66() {
        assert_first_eth(66, 15, EthVersion::Eth66);
    }

    #[test]
    fn capabilities_supports_eth() {
        let all_eth: Capabilities = vec![
            Capability::new_static("eth", 66),
            Capability::new_static("eth", 67),
            Capability::new_static("eth", 68),
        ]
        .into();

        assert!(all_eth.supports_eth());
        assert!(all_eth.supports_eth_v66());
        assert!(all_eth.supports_eth_v67());
        assert!(all_eth.supports_eth_v68());
    }

    #[test]
    fn test_peer_capability_version_zero() {
        let custom = Capability::new_static("TestName", 0);
        let local: Vec<Protocol> = vec![
            Protocol::new(custom.clone(), 0),
            EthVersion::Eth67.into(),
            EthVersion::Eth68.into(),
        ];
        let peer = vec![custom.clone()];

        let shared = shared_capability_offsets(local, peer).unwrap();
        assert_eq!(shared.len(), 1);
        assert_eq!(
            shared[0],
            SharedCapability::UnknownCapability { cap: custom, offset: 16, messages: 0 }
        )
    }

    #[test]
    fn test_peer_lower_capability_version() {
        let local: Vec<Protocol> =
            vec![EthVersion::Eth66.into(), EthVersion::Eth67.into(), EthVersion::Eth68.into()];
        let peer: Vec<Capability> = vec![EthVersion::Eth66.into()];

        let shared = shared_capability_offsets(local, peer).unwrap();

        assert_eq!(
            shared[0],
            SharedCapability::Eth { version: EthVersion::Eth66, offset: MAX_RESERVED_MESSAGE_ID + 1 }
        )
    }

    #[test]
    fn test_peer_capability_version_too_low() {
        let local: Vec<Protocol> = vec![EthVersion::Eth67.into()];
        let peer: Vec<Capability> = vec![EthVersion::Eth66.into()];

        let outcome = shared_capability_offsets(local, peer);

        assert!(matches!(
            outcome,
            Err(P2PStreamError::HandshakeError(P2PHandshakeError::NoSharedCapabilities))
        ))
    }

    #[test]
    fn test_peer_capability_version_too_high() {
        let local: Vec<Protocol> = vec![EthVersion::Eth66.into()];
        let peer: Vec<Capability> = vec![EthVersion::Eth67.into()];

        let outcome = shared_capability_offsets(local, peer);

        assert!(matches!(
            outcome,
            Err(P2PStreamError::HandshakeError(P2PHandshakeError::NoSharedCapabilities))
        ))
    }

    #[test]
    fn test_find_by_offset() {
        let local: Vec<Protocol> = vec![EthVersion::Eth66.into()];
        let peer: Vec<Capability> = vec![EthVersion::Eth66.into()];

        let shared = SharedCapabilities::try_new(local, peer).unwrap();

        assert_eq!(shared.find_by_relative_offset(0).unwrap().name(), "eth");
        assert_eq!(shared.find_by_offset(MAX_RESERVED_MESSAGE_ID + 1).unwrap().name(), "eth");
        // ids inside the reserved range never resolve to a shared capability
        assert!(shared.find_by_offset(MAX_RESERVED_MESSAGE_ID).is_none());
    }

    #[test]
    fn test_find_by_offset_many() {
        let cap = Capability::new_static("aaa", 1);
        let proto = Protocol::new(cap.clone(), 5);
        let local = vec![proto.clone(), EthVersion::Eth66.into()];
        let peer = vec![cap, EthVersion::Eth66.into()];

        let shared = SharedCapabilities::try_new(local, peer).unwrap();

        // "aaa" sorts before "eth", so it owns the first `proto.messages()` ids
        for (relative, absolute) in
            [(0, MAX_RESERVED_MESSAGE_ID + 1), (4, MAX_RESERVED_MESSAGE_ID + 5)]
        {
            let by_relative = shared.find_by_relative_offset(relative).unwrap();
            assert_eq!(by_relative.name(), proto.cap.name);
            let by_absolute = shared.find_by_offset(absolute).unwrap();
            assert_eq!(by_absolute.name(), proto.cap.name);
        }

        let after_custom = shared.find_by_relative_offset(1 + proto.messages()).unwrap();
        assert_eq!(after_custom.name(), "eth");
    }
}