1use super::{
10 broadcast::NewBlockHashes, BlockBodies, BlockHeaders, GetBlockBodies, GetBlockHeaders,
11 GetNodeData, GetPooledTransactions, GetReceipts, NewBlock, NewPooledTransactionHashes66,
12 NewPooledTransactionHashes68, NodeData, PooledTransactions, Receipts, Status, Transactions,
13};
14use crate::{EthNetworkPrimitives, EthVersion, NetworkPrimitives, SharedTransactions};
15use alloy_primitives::bytes::{Buf, BufMut};
16use alloy_rlp::{length_of_length, Decodable, Encodable, Header};
17use std::{fmt::Debug, sync::Arc};
18
/// Upper bound on the size of a single protocol message: 10 MiB (`10 * 1024 * 1024`).
///
/// NOTE(review): unit presumably bytes and enforced by callers outside this module — confirm
/// at the usage sites.
pub const MAX_MESSAGE_SIZE: usize = 10 * 1024 * 1024;
22
23#[derive(thiserror::Error, Debug)]
25pub enum MessageError {
26 #[error("message id {1:?} is invalid for version {0:?}")]
28 Invalid(EthVersion, EthMessageID),
29 #[error("RLP error: {0}")]
31 RlpError(#[from] alloy_rlp::Error),
32}
33
34#[derive(Clone, Debug, PartialEq, Eq)]
36#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
37pub struct ProtocolMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
38 pub message_type: EthMessageID,
40 #[cfg_attr(
42 feature = "serde",
43 serde(bound = "EthMessage<N>: serde::Serialize + serde::de::DeserializeOwned")
44 )]
45 pub message: EthMessage<N>,
46}
47
48impl<N: NetworkPrimitives> ProtocolMessage<N> {
49 pub fn decode_message(version: EthVersion, buf: &mut &[u8]) -> Result<Self, MessageError> {
51 let message_type = EthMessageID::decode(buf)?;
52
53 let message = match message_type {
54 EthMessageID::Status => EthMessage::Status(Status::decode(buf)?),
55 EthMessageID::NewBlockHashes => {
56 if version.is_eth69() {
57 return Err(MessageError::Invalid(version, EthMessageID::NewBlockHashes));
58 }
59 EthMessage::NewBlockHashes(NewBlockHashes::decode(buf)?)
60 }
61 EthMessageID::NewBlock => {
62 if version.is_eth69() {
63 return Err(MessageError::Invalid(version, EthMessageID::NewBlock));
64 }
65 EthMessage::NewBlock(Box::new(NewBlock::decode(buf)?))
66 }
67 EthMessageID::Transactions => EthMessage::Transactions(Transactions::decode(buf)?),
68 EthMessageID::NewPooledTransactionHashes => {
69 if version >= EthVersion::Eth68 {
70 EthMessage::NewPooledTransactionHashes68(NewPooledTransactionHashes68::decode(
71 buf,
72 )?)
73 } else {
74 EthMessage::NewPooledTransactionHashes66(NewPooledTransactionHashes66::decode(
75 buf,
76 )?)
77 }
78 }
79 EthMessageID::GetBlockHeaders => EthMessage::GetBlockHeaders(RequestPair::decode(buf)?),
80 EthMessageID::BlockHeaders => EthMessage::BlockHeaders(RequestPair::decode(buf)?),
81 EthMessageID::GetBlockBodies => EthMessage::GetBlockBodies(RequestPair::decode(buf)?),
82 EthMessageID::BlockBodies => EthMessage::BlockBodies(RequestPair::decode(buf)?),
83 EthMessageID::GetPooledTransactions => {
84 EthMessage::GetPooledTransactions(RequestPair::decode(buf)?)
85 }
86 EthMessageID::PooledTransactions => {
87 EthMessage::PooledTransactions(RequestPair::decode(buf)?)
88 }
89 EthMessageID::GetNodeData => {
90 if version >= EthVersion::Eth67 {
91 return Err(MessageError::Invalid(version, EthMessageID::GetNodeData))
92 }
93 EthMessage::GetNodeData(RequestPair::decode(buf)?)
94 }
95 EthMessageID::NodeData => {
96 if version >= EthVersion::Eth67 {
97 return Err(MessageError::Invalid(version, EthMessageID::GetNodeData))
98 }
99 EthMessage::NodeData(RequestPair::decode(buf)?)
100 }
101 EthMessageID::GetReceipts => EthMessage::GetReceipts(RequestPair::decode(buf)?),
102 EthMessageID::Receipts => EthMessage::Receipts(RequestPair::decode(buf)?),
103 };
104 Ok(Self { message_type, message })
105 }
106}
107
108impl<N: NetworkPrimitives> Encodable for ProtocolMessage<N> {
109 fn encode(&self, out: &mut dyn BufMut) {
112 self.message_type.encode(out);
113 self.message.encode(out);
114 }
115 fn length(&self) -> usize {
116 self.message_type.length() + self.message.length()
117 }
118}
119
120impl<N: NetworkPrimitives> From<EthMessage<N>> for ProtocolMessage<N> {
121 fn from(message: EthMessage<N>) -> Self {
122 Self { message_type: message.message_id(), message }
123 }
124}
125
126#[derive(Clone, Debug)]
128pub struct ProtocolBroadcastMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
129 pub message_type: EthMessageID,
131 pub message: EthBroadcastMessage<N>,
134}
135
136impl<N: NetworkPrimitives> Encodable for ProtocolBroadcastMessage<N> {
137 fn encode(&self, out: &mut dyn BufMut) {
140 self.message_type.encode(out);
141 self.message.encode(out);
142 }
143 fn length(&self) -> usize {
144 self.message_type.length() + self.message.length()
145 }
146}
147
148impl<N: NetworkPrimitives> From<EthBroadcastMessage<N>> for ProtocolBroadcastMessage<N> {
149 fn from(message: EthBroadcastMessage<N>) -> Self {
150 Self { message_type: message.message_id(), message }
151 }
152}
153
154#[derive(Clone, Debug, PartialEq, Eq)]
172#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
173pub enum EthMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
174 Status(Status),
176 NewBlockHashes(NewBlockHashes),
178 #[cfg_attr(
180 feature = "serde",
181 serde(bound = "N::Block: serde::Serialize + serde::de::DeserializeOwned")
182 )]
183 NewBlock(Box<NewBlock<N::Block>>),
184 #[cfg_attr(
186 feature = "serde",
187 serde(bound = "N::BroadcastedTransaction: serde::Serialize + serde::de::DeserializeOwned")
188 )]
189 Transactions(Transactions<N::BroadcastedTransaction>),
190 NewPooledTransactionHashes66(NewPooledTransactionHashes66),
192 NewPooledTransactionHashes68(NewPooledTransactionHashes68),
194 GetBlockHeaders(RequestPair<GetBlockHeaders>),
197 #[cfg_attr(
199 feature = "serde",
200 serde(bound = "N::BlockHeader: serde::Serialize + serde::de::DeserializeOwned")
201 )]
202 BlockHeaders(RequestPair<BlockHeaders<N::BlockHeader>>),
203 GetBlockBodies(RequestPair<GetBlockBodies>),
205 #[cfg_attr(
207 feature = "serde",
208 serde(bound = "N::BlockBody: serde::Serialize + serde::de::DeserializeOwned")
209 )]
210 BlockBodies(RequestPair<BlockBodies<N::BlockBody>>),
211 GetPooledTransactions(RequestPair<GetPooledTransactions>),
213 #[cfg_attr(
215 feature = "serde",
216 serde(bound = "N::PooledTransaction: serde::Serialize + serde::de::DeserializeOwned")
217 )]
218 PooledTransactions(RequestPair<PooledTransactions<N::PooledTransaction>>),
219 GetNodeData(RequestPair<GetNodeData>),
221 NodeData(RequestPair<NodeData>),
223 GetReceipts(RequestPair<GetReceipts>),
225 Receipts(RequestPair<Receipts>),
227}
228
229impl<N: NetworkPrimitives> EthMessage<N> {
230 pub const fn message_id(&self) -> EthMessageID {
232 match self {
233 Self::Status(_) => EthMessageID::Status,
234 Self::NewBlockHashes(_) => EthMessageID::NewBlockHashes,
235 Self::NewBlock(_) => EthMessageID::NewBlock,
236 Self::Transactions(_) => EthMessageID::Transactions,
237 Self::NewPooledTransactionHashes66(_) | Self::NewPooledTransactionHashes68(_) => {
238 EthMessageID::NewPooledTransactionHashes
239 }
240 Self::GetBlockHeaders(_) => EthMessageID::GetBlockHeaders,
241 Self::BlockHeaders(_) => EthMessageID::BlockHeaders,
242 Self::GetBlockBodies(_) => EthMessageID::GetBlockBodies,
243 Self::BlockBodies(_) => EthMessageID::BlockBodies,
244 Self::GetPooledTransactions(_) => EthMessageID::GetPooledTransactions,
245 Self::PooledTransactions(_) => EthMessageID::PooledTransactions,
246 Self::GetNodeData(_) => EthMessageID::GetNodeData,
247 Self::NodeData(_) => EthMessageID::NodeData,
248 Self::GetReceipts(_) => EthMessageID::GetReceipts,
249 Self::Receipts(_) => EthMessageID::Receipts,
250 }
251 }
252}
253
254impl<N: NetworkPrimitives> Encodable for EthMessage<N> {
255 fn encode(&self, out: &mut dyn BufMut) {
256 match self {
257 Self::Status(status) => status.encode(out),
258 Self::NewBlockHashes(new_block_hashes) => new_block_hashes.encode(out),
259 Self::NewBlock(new_block) => new_block.encode(out),
260 Self::Transactions(transactions) => transactions.encode(out),
261 Self::NewPooledTransactionHashes66(hashes) => hashes.encode(out),
262 Self::NewPooledTransactionHashes68(hashes) => hashes.encode(out),
263 Self::GetBlockHeaders(request) => request.encode(out),
264 Self::BlockHeaders(headers) => headers.encode(out),
265 Self::GetBlockBodies(request) => request.encode(out),
266 Self::BlockBodies(bodies) => bodies.encode(out),
267 Self::GetPooledTransactions(request) => request.encode(out),
268 Self::PooledTransactions(transactions) => transactions.encode(out),
269 Self::GetNodeData(request) => request.encode(out),
270 Self::NodeData(data) => data.encode(out),
271 Self::GetReceipts(request) => request.encode(out),
272 Self::Receipts(receipts) => receipts.encode(out),
273 }
274 }
275 fn length(&self) -> usize {
276 match self {
277 Self::Status(status) => status.length(),
278 Self::NewBlockHashes(new_block_hashes) => new_block_hashes.length(),
279 Self::NewBlock(new_block) => new_block.length(),
280 Self::Transactions(transactions) => transactions.length(),
281 Self::NewPooledTransactionHashes66(hashes) => hashes.length(),
282 Self::NewPooledTransactionHashes68(hashes) => hashes.length(),
283 Self::GetBlockHeaders(request) => request.length(),
284 Self::BlockHeaders(headers) => headers.length(),
285 Self::GetBlockBodies(request) => request.length(),
286 Self::BlockBodies(bodies) => bodies.length(),
287 Self::GetPooledTransactions(request) => request.length(),
288 Self::PooledTransactions(transactions) => transactions.length(),
289 Self::GetNodeData(request) => request.length(),
290 Self::NodeData(data) => data.length(),
291 Self::GetReceipts(request) => request.length(),
292 Self::Receipts(receipts) => receipts.length(),
293 }
294 }
295}
296
297#[derive(Clone, Debug, PartialEq, Eq)]
305pub enum EthBroadcastMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
306 NewBlock(Arc<NewBlock<N::Block>>),
308 Transactions(SharedTransactions<N::BroadcastedTransaction>),
310}
311
312impl<N: NetworkPrimitives> EthBroadcastMessage<N> {
315 pub const fn message_id(&self) -> EthMessageID {
317 match self {
318 Self::NewBlock(_) => EthMessageID::NewBlock,
319 Self::Transactions(_) => EthMessageID::Transactions,
320 }
321 }
322}
323
324impl<N: NetworkPrimitives> Encodable for EthBroadcastMessage<N> {
325 fn encode(&self, out: &mut dyn BufMut) {
326 match self {
327 Self::NewBlock(new_block) => new_block.encode(out),
328 Self::Transactions(transactions) => transactions.encode(out),
329 }
330 }
331
332 fn length(&self) -> usize {
333 match self {
334 Self::NewBlock(new_block) => new_block.length(),
335 Self::Transactions(transactions) => transactions.length(),
336 }
337 }
338}
339
/// Numeric IDs of the eth messages, stored as a `u8` discriminant.
///
/// The numbering has a gap: `0x0b` and `0x0c` are not assigned to any variant.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum EthMessageID {
    /// `0x00`
    Status = 0x00,
    /// `0x01`
    NewBlockHashes = 0x01,
    /// `0x02`
    Transactions = 0x02,
    /// `0x03`
    GetBlockHeaders = 0x03,
    /// `0x04`
    BlockHeaders = 0x04,
    /// `0x05`
    GetBlockBodies = 0x05,
    /// `0x06`
    BlockBodies = 0x06,
    /// `0x07`
    NewBlock = 0x07,
    /// `0x08`
    NewPooledTransactionHashes = 0x08,
    /// `0x09`
    GetPooledTransactions = 0x09,
    /// `0x0a`
    PooledTransactions = 0x0a,
    /// `0x0d`
    GetNodeData = 0x0d,
    /// `0x0e`
    NodeData = 0x0e,
    /// `0x0f`
    GetReceipts = 0x0f,
    /// `0x10`
    Receipts = 0x10,
}
376
377impl EthMessageID {
378 pub const fn max() -> u8 {
380 Self::Receipts as u8
381 }
382}
383
384impl Encodable for EthMessageID {
385 fn encode(&self, out: &mut dyn BufMut) {
386 out.put_u8(*self as u8);
387 }
388 fn length(&self) -> usize {
389 1
390 }
391}
392
393impl Decodable for EthMessageID {
394 fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
395 let id = match buf.first().ok_or(alloy_rlp::Error::InputTooShort)? {
396 0x00 => Self::Status,
397 0x01 => Self::NewBlockHashes,
398 0x02 => Self::Transactions,
399 0x03 => Self::GetBlockHeaders,
400 0x04 => Self::BlockHeaders,
401 0x05 => Self::GetBlockBodies,
402 0x06 => Self::BlockBodies,
403 0x07 => Self::NewBlock,
404 0x08 => Self::NewPooledTransactionHashes,
405 0x09 => Self::GetPooledTransactions,
406 0x0a => Self::PooledTransactions,
407 0x0d => Self::GetNodeData,
408 0x0e => Self::NodeData,
409 0x0f => Self::GetReceipts,
410 0x10 => Self::Receipts,
411 _ => return Err(alloy_rlp::Error::Custom("Invalid message ID")),
412 };
413 buf.advance(1);
414 Ok(id)
415 }
416}
417
418impl TryFrom<usize> for EthMessageID {
419 type Error = &'static str;
420
421 fn try_from(value: usize) -> Result<Self, Self::Error> {
422 match value {
423 0x00 => Ok(Self::Status),
424 0x01 => Ok(Self::NewBlockHashes),
425 0x02 => Ok(Self::Transactions),
426 0x03 => Ok(Self::GetBlockHeaders),
427 0x04 => Ok(Self::BlockHeaders),
428 0x05 => Ok(Self::GetBlockBodies),
429 0x06 => Ok(Self::BlockBodies),
430 0x07 => Ok(Self::NewBlock),
431 0x08 => Ok(Self::NewPooledTransactionHashes),
432 0x09 => Ok(Self::GetPooledTransactions),
433 0x0a => Ok(Self::PooledTransactions),
434 0x0d => Ok(Self::GetNodeData),
435 0x0e => Ok(Self::NodeData),
436 0x0f => Ok(Self::GetReceipts),
437 0x10 => Ok(Self::Receipts),
438 _ => Err("Invalid message ID"),
439 }
440 }
441}
442
/// A message tagged with a numeric request ID.
///
/// The ID is presumably used to pair a response with its request — confirm at call sites.
/// On the wire it is the RLP list `[request_id, message]`.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct RequestPair<T> {
    /// ID of the request this message belongs to.
    pub request_id: u64,

    /// The wrapped message.
    pub message: T,
}
455
456impl<T> Encodable for RequestPair<T>
458where
459 T: Encodable,
460{
461 fn encode(&self, out: &mut dyn alloy_rlp::BufMut) {
462 let header =
463 Header { list: true, payload_length: self.request_id.length() + self.message.length() };
464
465 header.encode(out);
466 self.request_id.encode(out);
467 self.message.encode(out);
468 }
469
470 fn length(&self) -> usize {
471 let mut length = 0;
472 length += self.request_id.length();
473 length += self.message.length();
474 length += length_of_length(length);
475 length
476 }
477}
478
479impl<T> Decodable for RequestPair<T>
481where
482 T: Decodable,
483{
484 fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
485 let header = Header::decode(buf)?;
486
487 let initial_length = buf.len();
488 let request_id = u64::decode(buf)?;
489 let message = T::decode(buf)?;
490
491 let consumed_len = initial_length - buf.len();
494 if consumed_len != header.payload_length {
495 return Err(alloy_rlp::Error::UnexpectedLength)
496 }
497
498 Ok(Self { request_id, message })
499 }
500}
501
#[cfg(test)]
mod tests {
    use super::MessageError;
    use crate::{
        message::RequestPair, EthMessage, EthMessageID, EthNetworkPrimitives, EthVersion,
        GetNodeData, NodeData, ProtocolMessage,
    };
    use alloy_primitives::hex;
    use alloy_rlp::{Decodable, Encodable, Error};

    /// Encodes `value` into a fresh buffer.
    fn encode<T: Encodable>(value: T) -> Vec<u8> {
        let mut buf = vec![];
        value.encode(&mut buf);
        buf
    }

    /// `GetNodeData` and `NodeData` must be rejected when decoding under eth/67.
    #[test]
    fn test_removed_message_at_eth67() {
        let get_node_data = EthMessage::<EthNetworkPrimitives>::GetNodeData(RequestPair {
            request_id: 1337,
            message: GetNodeData(vec![]),
        });
        let buf = encode(ProtocolMessage {
            message_type: EthMessageID::GetNodeData,
            message: get_node_data,
        });
        let msg =
            ProtocolMessage::<EthNetworkPrimitives>::decode_message(EthVersion::Eth67, &mut &buf[..]);
        assert!(matches!(msg, Err(MessageError::Invalid(..))));

        let node_data = EthMessage::<EthNetworkPrimitives>::NodeData(RequestPair {
            request_id: 1337,
            message: NodeData(vec![]),
        });
        let buf =
            encode(ProtocolMessage { message_type: EthMessageID::NodeData, message: node_data });
        let msg =
            ProtocolMessage::<EthNetworkPrimitives>::decode_message(EthVersion::Eth67, &mut &buf[..]);
        assert!(matches!(msg, Err(MessageError::Invalid(..))));
    }

    #[test]
    fn request_pair_encode() {
        let request_pair = RequestPair { request_id: 1337, message: vec![5u8] };

        // c5: list, 5-byte payload | 820539: 1337 | c105: list [5]
        let expected = hex!("c5820539c105");
        let got = encode(request_pair);
        assert_eq!(expected[..], got, "expected: {expected:X?}, got: {got:X?}",);
    }

    #[test]
    fn request_pair_decode() {
        let raw_pair = &hex!("c5820539c105")[..];

        let expected = RequestPair { request_id: 1337, message: vec![5u8] };

        let got = RequestPair::<Vec<u8>>::decode(&mut &*raw_pair).unwrap();
        assert_eq!(expected.length(), raw_pair.len());
        assert_eq!(expected, got);
    }

    #[test]
    fn malicious_request_pair_decode() {
        // The outer header claims a 5-byte payload (c5), but its fields take 6 bytes:
        // 820539 (request id 1337, 3 bytes) + c20505 (list [5, 5], 3 bytes).
        let raw_pair = &hex!("c5820539c20505")[..];

        let result = RequestPair::<Vec<u8>>::decode(&mut &*raw_pair);
        assert!(matches!(result, Err(Error::UnexpectedLength)));
    }

    #[test]
    fn empty_block_bodies_protocol() {
        let empty_block_bodies =
            ProtocolMessage::from(EthMessage::<EthNetworkPrimitives>::BlockBodies(RequestPair {
                request_id: 0,
                message: Default::default(),
            }));
        let mut buf = Vec::new();
        empty_block_bodies.encode(&mut buf);
        let decoded =
            ProtocolMessage::decode_message(EthVersion::Eth68, &mut buf.as_slice()).unwrap();
        assert_eq!(empty_block_bodies, decoded);
    }

    /// Checks that the enum discriminants, `Decodable`, `TryFrom<usize>`, `Encodable` and
    /// `max()` all agree on the byte <-> ID mapping, for every possible byte.
    #[test]
    fn message_id_codecs_agree() {
        for byte in 0..=u8::MAX {
            let raw = [byte];
            let decoded = EthMessageID::decode(&mut &raw[..]);
            let converted = EthMessageID::try_from(usize::from(byte));
            match (decoded, converted) {
                (Ok(from_decode), Ok(from_try)) => {
                    assert_eq!(from_decode, from_try);
                    assert_eq!(from_decode as u8, byte);
                    assert!(byte <= EthMessageID::max());
                    assert_eq!(from_decode.length(), 1);
                    assert_eq!(encode(from_decode), raw.to_vec());
                }
                (Err(_), Err(_)) => {}
                (d, c) => panic!("decode/try_from disagree for {byte:#04x}: {d:?} vs {c:?}"),
            }
        }
    }
}