1use super::message::MAX_MESSAGE_SIZE;
7use crate::{
8 message::{EthBroadcastMessage, ProtocolBroadcastMessage},
9 EthMessage, EthMessageID, EthNetworkPrimitives, EthVersion, NetworkPrimitives, ProtocolMessage,
10 RawCapabilityMessage, SnapMessageId, SnapProtocolMessage,
11};
12use alloy_rlp::{Bytes, BytesMut, Encodable};
13use core::fmt::Debug;
14use futures::{Sink, SinkExt};
15use pin_project::pin_project;
16use std::{
17 marker::PhantomData,
18 pin::Pin,
19 task::{ready, Context, Poll},
20};
21use tokio_stream::Stream;
22
/// Errors produced while decoding or encoding messages on an [`EthSnapStream`].
#[derive(thiserror::Error, Debug)]
pub enum EthSnapStreamError {
    /// A frame carried an eth message id but could not be decoded as an eth message for the
    /// given version. The string holds the underlying decoder error's text.
    #[error("invalid message for version {0:?}: {1}")]
    InvalidMessage(EthVersion, String),

    /// A frame's message id is past the last snap message id for the configured eth version.
    #[error("unknown message id: {0}")]
    UnknownMessageId(u8),

    /// A frame exceeded the size limit: `(actual length, maximum allowed length)`.
    #[error("message too large: {0} > {1}")]
    MessageTooLarge(usize, usize),

    /// RLP failure. Used for empty frames and for snap payloads that fail to decode.
    #[error("rlp error: {0}")]
    Rlp(#[from] alloy_rlp::Error),

    /// A `Status` message was decoded or submitted for sending. This wrapper rejects `Status`
    /// in both directions.
    #[error("status message received outside handshake")]
    StatusNotInHandshake,
}
46
/// A message carried by an [`EthSnapStream`]: either an eth protocol message or a snap
/// protocol message.
#[derive(Debug)]
pub enum EthSnapMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
    /// A message of the eth protocol.
    Eth(EthMessage<N>),
    /// A message of the snap protocol.
    Snap(SnapProtocolMessage),
}
55
/// Stream/sink adapter that multiplexes eth and snap protocol messages over one underlying
/// byte stream.
///
/// Eth messages keep their native ids. Snap message ids are shifted so they start directly
/// after the last eth message id of the configured [`EthVersion`].
#[pin_project]
#[derive(Debug, Clone)]
pub struct EthSnapStream<S, N = EthNetworkPrimitives> {
    /// Version-aware encoder/decoder that maps ids between the two protocols.
    eth_snap: EthSnapStreamInner<N>,
    /// The wrapped byte stream/sink.
    #[pin]
    inner: S,
}
67
68impl<S, N> EthSnapStream<S, N>
69where
70 N: NetworkPrimitives,
71{
72 pub const fn new(stream: S, eth_version: EthVersion) -> Self {
74 Self { eth_snap: EthSnapStreamInner::new(eth_version), inner: stream }
75 }
76
77 #[inline]
79 pub const fn eth_version(&self) -> EthVersion {
80 self.eth_snap.eth_version()
81 }
82
83 #[inline]
85 pub const fn inner(&self) -> &S {
86 &self.inner
87 }
88
89 #[inline]
91 pub const fn inner_mut(&mut self) -> &mut S {
92 &mut self.inner
93 }
94
95 #[inline]
97 pub fn into_inner(self) -> S {
98 self.inner
99 }
100}
101
102impl<S, E, N> EthSnapStream<S, N>
103where
104 S: Sink<Bytes, Error = E> + Unpin,
105 EthSnapStreamError: From<E>,
106 N: NetworkPrimitives,
107{
108 pub fn start_send_broadcast(
110 &mut self,
111 item: EthBroadcastMessage<N>,
112 ) -> Result<(), EthSnapStreamError> {
113 self.inner.start_send_unpin(Bytes::from(alloy_rlp::encode(
114 ProtocolBroadcastMessage::from(item),
115 )))?;
116
117 Ok(())
118 }
119
120 pub fn start_send_raw(&mut self, msg: RawCapabilityMessage) -> Result<(), EthSnapStreamError> {
122 let mut bytes = Vec::with_capacity(msg.payload.len() + 1);
123 msg.id.encode(&mut bytes);
124 bytes.extend_from_slice(&msg.payload);
125
126 self.inner.start_send_unpin(bytes.into())?;
127 Ok(())
128 }
129}
130
131impl<S, E, N> Stream for EthSnapStream<S, N>
132where
133 S: Stream<Item = Result<BytesMut, E>> + Unpin,
134 EthSnapStreamError: From<E>,
135 N: NetworkPrimitives,
136{
137 type Item = Result<EthSnapMessage<N>, EthSnapStreamError>;
138
139 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
140 let this = self.project();
141 let res = ready!(this.inner.poll_next(cx));
142
143 match res {
144 Some(Ok(bytes)) => Poll::Ready(Some(this.eth_snap.decode_message(bytes))),
145 Some(Err(err)) => Poll::Ready(Some(Err(err.into()))),
146 None => Poll::Ready(None),
147 }
148 }
149}
150
151impl<S, E, N> Sink<EthSnapMessage<N>> for EthSnapStream<S, N>
152where
153 S: Sink<Bytes, Error = E> + Unpin,
154 EthSnapStreamError: From<E>,
155 N: NetworkPrimitives,
156{
157 type Error = EthSnapStreamError;
158
159 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
160 self.project().inner.poll_ready(cx).map_err(Into::into)
161 }
162
163 fn start_send(mut self: Pin<&mut Self>, item: EthSnapMessage<N>) -> Result<(), Self::Error> {
164 let mut this = self.as_mut().project();
165
166 let bytes = match item {
167 EthSnapMessage::Eth(eth_msg) => this.eth_snap.encode_eth_message(eth_msg)?,
168 EthSnapMessage::Snap(snap_msg) => this.eth_snap.encode_snap_message(snap_msg),
169 };
170
171 this.inner.start_send_unpin(bytes)?;
172 Ok(())
173 }
174
175 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
176 self.project().inner.poll_flush(cx).map_err(Into::into)
177 }
178
179 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
180 self.project().inner.poll_close(cx).map_err(Into::into)
181 }
182}
183
/// Encoder/decoder used by [`EthSnapStream`]. It maps message ids between the eth and snap
/// protocols for a fixed [`EthVersion`].
#[derive(Debug, Clone)]
struct EthSnapStreamInner<N> {
    /// Eth version this codec was created with. It determines the last eth message id, and
    /// therefore where snap ids begin.
    eth_version: EthVersion,
    /// Ties the codec to the network primitives `N` without storing a value of it.
    _pd: PhantomData<N>,
}
194
195impl<N> EthSnapStreamInner<N>
196where
197 N: NetworkPrimitives,
198{
199 const fn new(eth_version: EthVersion) -> Self {
201 Self { eth_version, _pd: PhantomData }
202 }
203
204 #[inline]
205 const fn eth_version(&self) -> EthVersion {
206 self.eth_version
207 }
208
209 fn decode_message(&self, bytes: BytesMut) -> Result<EthSnapMessage<N>, EthSnapStreamError> {
211 if bytes.len() > MAX_MESSAGE_SIZE {
212 return Err(EthSnapStreamError::MessageTooLarge(bytes.len(), MAX_MESSAGE_SIZE));
213 }
214
215 if bytes.is_empty() {
216 return Err(EthSnapStreamError::Rlp(alloy_rlp::Error::InputTooShort));
217 }
218
219 let message_id = bytes[0];
220
221 if message_id <= EthMessageID::max(self.eth_version) {
227 let mut buf = bytes.as_ref();
228 match ProtocolMessage::decode_message(self.eth_version, &mut buf) {
229 Ok(protocol_msg) => {
230 if matches!(protocol_msg.message, EthMessage::Status(_)) {
231 return Err(EthSnapStreamError::StatusNotInHandshake);
232 }
233 Ok(EthSnapMessage::Eth(protocol_msg.message))
234 }
235 Err(err) => {
236 Err(EthSnapStreamError::InvalidMessage(self.eth_version, err.to_string()))
237 }
238 }
239 } else if message_id > EthMessageID::max(self.eth_version) &&
240 message_id <=
241 EthMessageID::max(self.eth_version) + 1 + SnapMessageId::TrieNodes as u8
242 {
243 let adjusted_message_id = message_id - (EthMessageID::max(self.eth_version) + 1);
250 let mut buf = &bytes[1..];
251
252 match SnapProtocolMessage::decode(adjusted_message_id, &mut buf) {
253 Ok(snap_msg) => Ok(EthSnapMessage::Snap(snap_msg)),
254 Err(err) => Err(EthSnapStreamError::Rlp(err)),
255 }
256 } else {
257 Err(EthSnapStreamError::UnknownMessageId(message_id))
258 }
259 }
260
261 fn encode_eth_message(&self, item: EthMessage<N>) -> Result<Bytes, EthSnapStreamError> {
263 if matches!(item, EthMessage::Status(_)) {
264 return Err(EthSnapStreamError::StatusNotInHandshake);
265 }
266
267 let protocol_msg = ProtocolMessage::from(item);
268 let mut buf = Vec::new();
269 protocol_msg.encode(&mut buf);
270 Ok(Bytes::from(buf))
271 }
272
273 fn encode_snap_message(&self, message: SnapProtocolMessage) -> Bytes {
276 let encoded = message.encode();
277
278 let message_id = encoded[0];
279 let adjusted_id = message_id + EthMessageID::max(self.eth_version) + 1;
280
281 let mut adjusted = Vec::with_capacity(encoded.len());
282 adjusted.push(adjusted_id);
283 adjusted.extend_from_slice(&encoded[1..]);
284
285 Bytes::from(adjusted)
286 }
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{EthMessage, SnapProtocolMessage};
    use alloy_eips::BlockHashOrNumber;
    use alloy_primitives::B256;
    use alloy_rlp::Encodable;
    use reth_eth_wire_types::{
        message::RequestPair, GetAccountRangeMessage, GetBlockHeaders, HeadersDirection,
    };

    /// Builds a `GetBlockHeaders` request together with its canonical eth wire encoding.
    fn create_eth_message() -> (EthMessage<EthNetworkPrimitives>, BytesMut) {
        let eth_msg = EthMessage::<EthNetworkPrimitives>::GetBlockHeaders(RequestPair {
            request_id: 1,
            message: GetBlockHeaders {
                start_block: BlockHashOrNumber::Number(1),
                limit: 10,
                skip: 0,
                direction: HeadersDirection::Rising,
            },
        });

        let protocol_msg = ProtocolMessage::from(eth_msg.clone());
        let mut buf = Vec::new();
        protocol_msg.encode(&mut buf);

        (eth_msg, BytesMut::from(&buf[..]))
    }

    /// Builds a `GetAccountRange` request together with its eth67-offset snap encoding.
    fn create_snap_message() -> (SnapProtocolMessage, BytesMut) {
        let snap_msg = SnapProtocolMessage::GetAccountRange(GetAccountRangeMessage {
            request_id: 1,
            root_hash: B256::default(),
            starting_hash: B256::default(),
            limit_hash: B256::default(),
            response_bytes: 1000,
        });

        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth67);
        let encoded = inner.encode_snap_message(snap_msg.clone());

        (snap_msg, BytesMut::from(&encoded[..]))
    }

    #[test]
    fn test_eth_message_roundtrip() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth67);
        let (eth_msg, eth_bytes) = create_eth_message();

        // Encoding through the codec must yield exactly the canonical wire bytes.
        let encoded = inner.encode_eth_message(eth_msg.clone()).expect("non-status encodes");
        assert_eq!(&encoded[..], &eth_bytes[..]);

        let Ok(EthSnapMessage::Eth(decoded_msg)) = inner.decode_message(eth_bytes) else {
            panic!("expected an eth message");
        };
        assert_eq!(decoded_msg, eth_msg);

        let re_encoded = inner.encode_eth_message(decoded_msg.clone()).unwrap();
        let Ok(EthSnapMessage::Eth(final_msg)) =
            inner.decode_message(BytesMut::from(&re_encoded[..]))
        else {
            panic!("expected re-encoded eth message to decode");
        };
        assert_eq!(final_msg, decoded_msg);
    }

    #[test]
    fn test_snap_protocol() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth67);
        let (snap_msg, snap_bytes) = create_snap_message();

        let encoded_bytes = inner.encode_snap_message(snap_msg.clone());
        assert_eq!(&encoded_bytes[..], &snap_bytes[..]);
        // Snap ids sit directly after the last eth id.
        assert_eq!(
            encoded_bytes[0],
            EthMessageID::max(EthVersion::Eth67) + 1 + SnapMessageId::GetAccountRange as u8
        );

        let Ok(EthSnapMessage::Snap(decoded_msg)) = inner.decode_message(snap_bytes) else {
            panic!("expected a snap message");
        };
        assert_eq!(decoded_msg, snap_msg);

        let re_encoded = inner.encode_snap_message(decoded_msg.clone());
        let Ok(EthSnapMessage::Snap(final_msg)) =
            inner.decode_message(BytesMut::from(&re_encoded[..]))
        else {
            panic!("expected re-encoded snap message to decode");
        };
        assert_eq!(final_msg, decoded_msg);
    }

    #[test]
    fn test_message_id_boundaries() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth67);

        let eth_max_id = EthMessageID::max(EthVersion::Eth67);
        let snap_min_id = eth_max_id + 1;
        let snap_max_id = snap_min_id + SnapMessageId::TrieNodes as u8;

        // Highest eth id with a non-list payload: must be routed to the eth decoder, which
        // reports `InvalidMessage`, not to the snap decoder.
        let eth_boundary = inner.decode_message(BytesMut::from(&[eth_max_id, 0, 0][..]));
        assert!(
            matches!(eth_boundary, Err(EthSnapStreamError::InvalidMessage(EthVersion::Eth67, _))),
            "unexpected eth boundary result: {eth_boundary:?}"
        );

        // Lowest snap id with a non-list payload: must be routed to the snap decoder, which
        // reports an RLP error.
        let snap_boundary = inner.decode_message(BytesMut::from(&[snap_min_id, 0, 0][..]));
        assert!(
            matches!(snap_boundary, Err(EthSnapStreamError::Rlp(_))),
            "unexpected snap boundary result: {snap_boundary:?}"
        );

        // One past the last snap id is unknown to both protocols.
        let beyond = inner.decode_message(BytesMut::from(&[snap_max_id + 1, 0, 0][..]));
        assert!(
            matches!(beyond, Err(EthSnapStreamError::UnknownMessageId(id)) if id == snap_max_id + 1),
            "unexpected out-of-range result: {beyond:?}"
        );
    }

    #[test]
    fn test_rejects_empty_and_oversized_frames() {
        let inner = EthSnapStreamInner::<EthNetworkPrimitives>::new(EthVersion::Eth67);

        assert!(matches!(
            inner.decode_message(BytesMut::new()),
            Err(EthSnapStreamError::Rlp(alloy_rlp::Error::InputTooShort))
        ));

        let oversized = BytesMut::zeroed(MAX_MESSAGE_SIZE + 1);
        assert!(matches!(
            inner.decode_message(oversized),
            Err(EthSnapStreamError::MessageTooLarge(len, max))
                if len == MAX_MESSAGE_SIZE + 1 && max == MAX_MESSAGE_SIZE
        ));
    }
}