1use crate::{
2 blinded::{BlindedProvider, BlindedProviderFactory, DefaultBlindedProviderFactory},
3 RevealedSparseTrie, SparseTrie,
4};
5use alloy_primitives::{
6 hex,
7 map::{B256HashMap, B256HashSet},
8 Bytes, B256,
9};
10use alloy_rlp::{Decodable, Encodable};
11use reth_execution_errors::{
12 SparseStateTrieErrorKind, SparseStateTrieResult, SparseTrieError, SparseTrieErrorKind,
13};
14use reth_primitives_traits::Account;
15use reth_tracing::tracing::trace;
16use reth_trie_common::{
17 updates::{StorageTrieUpdates, TrieUpdates},
18 MultiProof, MultiProofTargets, Nibbles, TrieAccount, TrieNode, EMPTY_ROOT_HASH,
19 TRIE_ACCOUNT_RLP_MAX_SIZE,
20};
21use std::{fmt, iter::Peekable};
22
23pub struct SparseStateTrie<F: BlindedProviderFactory = DefaultBlindedProviderFactory> {
25 provider_factory: F,
27 state: SparseTrie<F::AccountNodeProvider>,
29 storages: B256HashMap<SparseTrie<F::StorageNodeProvider>>,
31 revealed: B256HashMap<B256HashSet>,
33 retain_updates: bool,
35 account_rlp_buf: Vec<u8>,
37}
38
39impl Default for SparseStateTrie {
40 fn default() -> Self {
41 Self {
42 provider_factory: Default::default(),
43 state: Default::default(),
44 storages: Default::default(),
45 revealed: Default::default(),
46 retain_updates: false,
47 account_rlp_buf: Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE),
48 }
49 }
50}
51
52impl<P: BlindedProviderFactory> fmt::Debug for SparseStateTrie<P> {
53 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54 f.debug_struct("SparseStateTrie")
55 .field("state", &self.state)
56 .field("storages", &self.storages)
57 .field("revealed", &self.revealed)
58 .field("retain_updates", &self.retain_updates)
59 .field("account_rlp_buf", &hex::encode(&self.account_rlp_buf))
60 .finish_non_exhaustive()
61 }
62}
63
64impl SparseStateTrie {
65 pub fn from_state(state: SparseTrie) -> Self {
67 Self { state, ..Default::default() }
68 }
69}
70
71impl<F: BlindedProviderFactory> SparseStateTrie<F> {
72 pub fn new(provider_factory: F) -> Self {
74 Self {
75 provider_factory,
76 state: Default::default(),
77 storages: Default::default(),
78 revealed: Default::default(),
79 retain_updates: false,
80 account_rlp_buf: Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE),
81 }
82 }
83
84 pub const fn with_updates(mut self, retain_updates: bool) -> Self {
86 self.retain_updates = retain_updates;
87 self
88 }
89
/// Returns `true` if `account` has an entry in the revealed set.
pub fn is_account_revealed(&self, account: &B256) -> bool {
    self.revealed.contains_key(account)
}
94
95 pub fn is_storage_slot_revealed(&self, account: &B256, slot: &B256) -> bool {
97 self.revealed.get(account).is_some_and(|slots| slots.contains(slot))
98 }
99
100 pub fn storage_trie_mut(
102 &mut self,
103 address: &B256,
104 ) -> Option<&mut RevealedSparseTrie<F::StorageNodeProvider>> {
105 self.storages.get_mut(address).and_then(|e| e.as_revealed_mut())
106 }
107
/// Removes the storage trie tracked for `address` and returns it, or `None` if there is none.
pub fn take_storage_trie(
    &mut self,
    address: &B256,
) -> Option<SparseTrie<F::StorageNodeProvider>> {
    self.storages.remove(address)
}
115
/// Stores `storage_trie` as the storage trie of `address`.
///
/// Any trie previously tracked for that address is replaced and dropped.
pub fn insert_storage_trie(
    &mut self,
    address: B256,
    storage_trie: SparseTrie<F::StorageNodeProvider>,
) {
    self.storages.insert(address, storage_trie);
}
124
125 pub fn reveal_account(
131 &mut self,
132 account: B256,
133 proof: impl IntoIterator<Item = (Nibbles, Bytes)>,
134 ) -> SparseStateTrieResult<()> {
135 assert!(!self.retain_updates);
136
137 if self.is_account_revealed(&account) {
138 return Ok(());
139 }
140
141 let mut proof = proof.into_iter().peekable();
142
143 let Some(root_node) = self.validate_root_node(&mut proof)? else { return Ok(()) };
144
145 let trie = self.state.reveal_root_with_provider(
147 self.provider_factory.account_node_provider(),
148 root_node,
149 None,
150 self.retain_updates,
151 )?;
152
153 for (path, bytes) in proof {
155 let node = TrieNode::decode(&mut &bytes[..])?;
156 trie.reveal_node(path, node, None)?;
157 }
158
159 self.revealed.entry(account).or_default();
161
162 Ok(())
163 }
164
165 pub fn reveal_storage_slot(
171 &mut self,
172 account: B256,
173 slot: B256,
174 proof: impl IntoIterator<Item = (Nibbles, Bytes)>,
175 ) -> SparseStateTrieResult<()> {
176 assert!(!self.retain_updates);
177
178 if self.is_storage_slot_revealed(&account, &slot) {
179 return Ok(());
180 }
181
182 let mut proof = proof.into_iter().peekable();
183
184 let Some(root_node) = self.validate_root_node(&mut proof)? else { return Ok(()) };
185
186 let trie = self.storages.entry(account).or_default().reveal_root_with_provider(
188 self.provider_factory.storage_node_provider(account),
189 root_node,
190 None,
191 self.retain_updates,
192 )?;
193
194 for (path, bytes) in proof {
196 let node = TrieNode::decode(&mut &bytes[..])?;
197 trie.reveal_node(path, node, None)?;
198 }
199
200 self.revealed.entry(account).or_default().insert(slot);
202
203 Ok(())
204 }
205
206 pub fn reveal_multiproof(
209 &mut self,
210 targets: MultiProofTargets,
211 multiproof: MultiProof,
212 ) -> SparseStateTrieResult<()> {
213 let account_subtree = multiproof.account_subtree.into_nodes_sorted();
214 let mut account_nodes = account_subtree.into_iter().peekable();
215
216 if let Some(root_node) = self.validate_root_node(&mut account_nodes)? {
217 let trie = self.state.reveal_root_with_provider(
219 self.provider_factory.account_node_provider(),
220 root_node,
221 multiproof.branch_node_hash_masks.get(&Nibbles::default()).copied(),
222 self.retain_updates,
223 )?;
224
225 for (path, bytes) in account_nodes {
227 let node = TrieNode::decode(&mut &bytes[..])?;
228 let hash_mask = if let TrieNode::Branch(_) = node {
229 multiproof.branch_node_hash_masks.get(&path).copied()
230 } else {
231 None
232 };
233
234 trace!(target: "trie::sparse", ?path, ?node, ?hash_mask, "Revealing account node");
235 trie.reveal_node(path, node, hash_mask)?;
236 }
237 }
238
239 for (account, storage_subtree) in multiproof.storages {
240 let subtree = storage_subtree.subtree.into_nodes_sorted();
241 let mut nodes = subtree.into_iter().peekable();
242
243 if let Some(root_node) = self.validate_root_node(&mut nodes)? {
244 let trie = self.storages.entry(account).or_default().reveal_root_with_provider(
246 self.provider_factory.storage_node_provider(account),
247 root_node,
248 storage_subtree.branch_node_hash_masks.get(&Nibbles::default()).copied(),
249 self.retain_updates,
250 )?;
251
252 for (path, bytes) in nodes {
254 let node = TrieNode::decode(&mut &bytes[..])?;
255 let hash_mask = if let TrieNode::Branch(_) = node {
256 storage_subtree.branch_node_hash_masks.get(&path).copied()
257 } else {
258 None
259 };
260
261 trace!(target: "trie::sparse", ?account, ?path, ?node, ?hash_mask, "Revealing storage node");
262 trie.reveal_node(path, node, hash_mask)?;
263 }
264 }
265 }
266
267 for (account, slots) in targets {
268 self.revealed.entry(account).or_default().extend(slots);
269 }
270
271 Ok(())
272 }
273
274 fn validate_root_node<I: Iterator<Item = (Nibbles, Bytes)>>(
276 &self,
277 proof: &mut Peekable<I>,
278 ) -> SparseStateTrieResult<Option<TrieNode>> {
279 let mut proof = proof.into_iter().peekable();
280
281 let Some((path, node)) = proof.next() else { return Ok(None) };
283 if !path.is_empty() {
284 return Err(SparseStateTrieErrorKind::InvalidRootNode { path, node }.into())
285 }
286
287 let root_node = TrieNode::decode(&mut &node[..])?;
289 if matches!(root_node, TrieNode::EmptyRoot) && proof.peek().is_some() {
290 return Err(SparseStateTrieErrorKind::InvalidRootNode { path, node }.into())
291 }
292
293 Ok(Some(root_node))
294 }
295
296 pub fn wipe_storage(&mut self, address: B256) -> SparseStateTrieResult<()> {
298 if let Some(trie) = self.storages.get_mut(&address) {
299 trie.wipe()?;
300 }
301 Ok(())
302 }
303
/// Forwards `level` to the account trie's `calculate_below_level`.
///
/// Storage tries are not touched by this call.
pub fn calculate_below_level(&mut self, level: usize) {
    self.state.calculate_below_level(level);
}
308
309 pub fn storage_root(&mut self, account: B256) -> Option<B256> {
311 self.storages.get_mut(&account).and_then(|trie| trie.root())
312 }
313
/// Returns the root reported by the account trie, if any.
// NOTE(review): `None` presumably means the account trie is still blind — confirm against
// `SparseTrie::root`.
pub fn root(&mut self) -> Option<B256> {
    self.state.root()
}
318
319 pub fn take_trie_updates(&mut self) -> Option<TrieUpdates> {
323 self.state.as_revealed_mut().map(|state| {
324 let updates = state.take_updates();
325 TrieUpdates {
326 account_nodes: updates.updated_nodes,
327 removed_nodes: updates.removed_nodes,
328 storage_tries: self
329 .storages
330 .iter_mut()
331 .map(|(address, trie)| {
332 let trie = trie.as_revealed_mut().unwrap();
333 let updates = trie.take_updates();
334 let updates = StorageTrieUpdates {
335 is_deleted: updates.wiped,
336 storage_nodes: updates.updated_nodes,
337 removed_nodes: updates.removed_nodes,
338 };
339 (*address, updates)
340 })
341 .filter(|(_, updates)| !updates.is_empty())
342 .collect(),
343 }
344 })
345 }
346}
347impl<F> SparseStateTrie<F>
348where
349 F: BlindedProviderFactory,
350 SparseTrieError: From<<F::AccountNodeProvider as BlindedProvider>::Error>
351 + From<<F::StorageNodeProvider as BlindedProvider>::Error>,
352{
353 pub fn update_account_leaf(
355 &mut self,
356 path: Nibbles,
357 value: Vec<u8>,
358 ) -> SparseStateTrieResult<()> {
359 self.state.update_leaf(path, value)?;
360 Ok(())
361 }
362
363 pub fn update_storage_leaf(
365 &mut self,
366 address: B256,
367 slot: Nibbles,
368 value: Vec<u8>,
369 ) -> SparseStateTrieResult<()> {
370 let storage_trie = self.storages.get_mut(&address).ok_or(SparseTrieErrorKind::Blind)?;
371 storage_trie.update_leaf(slot, value)?;
372 Ok(())
373 }
374
375 pub fn update_account(&mut self, address: B256, account: Account) -> SparseStateTrieResult<()> {
380 let nibbles = Nibbles::unpack(address);
381 let storage_root = if let Some(storage_trie) = self.storages.get_mut(&address) {
382 trace!(target: "trie::sparse", ?address, "Calculating storage root to update account");
383 storage_trie.root().ok_or(SparseTrieErrorKind::Blind)?
384 } else if self.revealed.contains_key(&address) {
385 trace!(target: "trie::sparse", ?address, "Retrieving storage root from account leaf to update account");
386 let state = self.state.as_revealed_mut().ok_or(SparseTrieErrorKind::Blind)?;
387 if let Some(value) = state.get_leaf_value(&nibbles) {
389 TrieAccount::decode(&mut &value[..])?.storage_root
391 } else {
392 EMPTY_ROOT_HASH
394 }
395 } else {
396 return Err(SparseTrieErrorKind::Blind.into())
397 };
398
399 if account.is_empty() && storage_root == EMPTY_ROOT_HASH {
400 trace!(target: "trie::sparse", ?address, "Removing account");
401 self.remove_account_leaf(&nibbles)
402 } else {
403 trace!(target: "trie::sparse", ?address, "Updating account");
404 self.account_rlp_buf.clear();
405 TrieAccount::from((account, storage_root)).encode(&mut self.account_rlp_buf);
406 self.update_account_leaf(nibbles, self.account_rlp_buf.clone())
407 }
408 }
409
410 pub fn remove_account_leaf(&mut self, path: &Nibbles) -> SparseStateTrieResult<()> {
412 self.state.remove_leaf(path)?;
413 Ok(())
414 }
415
416 pub fn remove_storage_leaf(
418 &mut self,
419 address: B256,
420 slot: &Nibbles,
421 ) -> SparseStateTrieResult<()> {
422 let storage_trie = self.storages.get_mut(&address).ok_or(SparseTrieErrorKind::Blind)?;
423 storage_trie.remove_leaf(slot)?;
424 Ok(())
425 }
426}
427
428#[cfg(test)]
429mod tests {
430 use super::*;
431 use alloy_primitives::{
432 b256,
433 map::{HashMap, HashSet},
434 Bytes, U256,
435 };
436 use alloy_rlp::EMPTY_STRING_CODE;
437 use arbitrary::Arbitrary;
438 use assert_matches::assert_matches;
439 use rand::{rngs::StdRng, Rng, SeedableRng};
440 use reth_primitives_traits::Account;
441 use reth_trie::{updates::StorageTrieUpdates, HashBuilder, TrieAccount, EMPTY_ROOT_HASH};
442 use reth_trie_common::{proof::ProofRetainer, StorageMultiProof, TrieMask};
443
/// A proof whose first node has a non-empty path must be rejected as an invalid root.
#[test]
fn validate_root_node_first_node_not_root() {
    let sparse = SparseStateTrie::default();
    let mut nodes =
        [(Nibbles::from_nibbles([0x1]), Bytes::from([EMPTY_STRING_CODE]))].into_iter().peekable();

    let result = sparse.validate_root_node(&mut nodes).map_err(|e| e.into_kind());

    assert_matches!(result, Err(SparseStateTrieErrorKind::InvalidRootNode { .. }));
}
453
/// An empty root followed by any further node must be rejected as an invalid root.
#[test]
fn validate_root_node_invalid_proof_with_empty_root() {
    let sparse = SparseStateTrie::default();
    let empty_root = (Nibbles::default(), Bytes::from([EMPTY_STRING_CODE]));
    let trailing_node = (Nibbles::from_nibbles([0x1]), Bytes::new());
    let mut nodes = [empty_root, trailing_node].into_iter().peekable();

    let result = sparse.validate_root_node(&mut nodes).map_err(|e| e.into_kind());

    assert_matches!(result, Err(SparseStateTrieErrorKind::InvalidRootNode { .. }));
}
466
/// Revealing an account with the proof of an empty trie turns the blind account trie into a
/// revealed empty one.
#[test]
fn reveal_account_empty() {
    // Build the proof of an empty trie: it consists of the root node only.
    let mut hash_builder = HashBuilder::default()
        .with_proof_retainer(ProofRetainer::from_iter([Nibbles::default()]));
    hash_builder.root();
    let proofs = hash_builder.take_proof_nodes();
    assert_eq!(proofs.len(), 1);

    let mut sparse = SparseStateTrie::default();
    assert_eq!(sparse.state, SparseTrie::Blind);

    let account = B256::default();
    sparse.reveal_account(account, proofs.into_inner()).unwrap();
    assert_eq!(sparse.state, SparseTrie::revealed_empty());
}
481
/// Revealing a storage slot with the proof of an empty trie creates a revealed empty storage
/// trie for the account.
#[test]
fn reveal_storage_slot_empty() {
    // Build the proof of an empty trie: it consists of the root node only.
    let mut hash_builder = HashBuilder::default()
        .with_proof_retainer(ProofRetainer::from_iter([Nibbles::default()]));
    hash_builder.root();
    let proofs = hash_builder.take_proof_nodes();
    assert_eq!(proofs.len(), 1);

    let mut sparse = SparseStateTrie::default();
    assert!(sparse.storages.is_empty());

    let account = B256::default();
    let slot = B256::default();
    sparse.reveal_storage_slot(account, slot, proofs.into_inner()).unwrap();

    let expected = HashMap::from_iter([(account, SparseTrie::revealed_empty())]);
    assert_eq!(sparse.storages, expected);
}
501
/// Reveals a multiproof for two accounts (each with the same two-slot storage proof), applies
/// a handful of account/storage changes, and checks the updates returned by
/// `SparseStateTrie::take_trie_updates`.
#[test]
fn take_trie_updates() {
    reth_tracing::init_test_tracing();

    // Fixed seed so that the generated values are the same on every run.
    let mut rng = StdRng::seed_from_u64(1);

    // Random bytes used as the entropy source for the arbitrary accounts below.
    let mut bytes = [0u8; 1024];
    rng.fill(bytes.as_mut_slice());

    // Three storage slots; only the first two are part of the initial storage trie.
    let slot_1 = b256!("1000000000000000000000000000000000000000000000000000000000000000");
    let slot_path_1 = Nibbles::unpack(slot_1);
    let value_1 = U256::from(rng.gen::<u64>());
    let slot_2 = b256!("1100000000000000000000000000000000000000000000000000000000000000");
    let slot_path_2 = Nibbles::unpack(slot_2);
    let value_2 = U256::from(rng.gen::<u64>());
    let slot_3 = b256!("2000000000000000000000000000000000000000000000000000000000000000");
    let slot_path_3 = Nibbles::unpack(slot_3);
    let value_3 = U256::from(rng.gen::<u64>());

    // Build the storage trie from slots 1 and 2, retaining the proof nodes for both.
    let mut storage_hash_builder =
        HashBuilder::default().with_proof_retainer(ProofRetainer::from_iter([
            slot_path_1.clone(),
            slot_path_2.clone(),
        ]));
    storage_hash_builder.add_leaf(slot_path_1, &alloy_rlp::encode_fixed_size(&value_1));
    storage_hash_builder.add_leaf(slot_path_2, &alloy_rlp::encode_fixed_size(&value_2));

    let storage_root = storage_hash_builder.root();
    let storage_proof_nodes = storage_hash_builder.take_proof_nodes();
    let storage_branch_node_hash_masks = HashMap::from_iter([
        (Nibbles::default(), TrieMask::new(0b010)),
        (Nibbles::from_nibbles([0x1]), TrieMask::new(0b11)),
    ]);

    // Two accounts: the first one with the storage root computed above, the second one with an
    // empty storage root.
    let address_1 = b256!("1000000000000000000000000000000000000000000000000000000000000000");
    let address_path_1 = Nibbles::unpack(address_1);
    let account_1 = Account::arbitrary(&mut arbitrary::Unstructured::new(&bytes)).unwrap();
    let mut trie_account_1 = TrieAccount::from((account_1, storage_root));
    let address_2 = b256!("1100000000000000000000000000000000000000000000000000000000000000");
    let address_path_2 = Nibbles::unpack(address_2);
    let account_2 = Account::arbitrary(&mut arbitrary::Unstructured::new(&bytes)).unwrap();
    let mut trie_account_2 = TrieAccount::from((account_2, EMPTY_ROOT_HASH));

    // Build the account trie from both accounts, retaining the proof nodes for both.
    let mut hash_builder =
        HashBuilder::default().with_proof_retainer(ProofRetainer::from_iter([
            address_path_1.clone(),
            address_path_2.clone(),
        ]));
    hash_builder.add_leaf(address_path_1.clone(), &alloy_rlp::encode(trie_account_1));
    hash_builder.add_leaf(address_path_2.clone(), &alloy_rlp::encode(trie_account_2));

    let root = hash_builder.root();
    let proof_nodes = hash_builder.take_proof_nodes();

    // Reveal everything with update retention enabled. Both accounts receive the same storage
    // proof nodes and hash masks.
    let mut sparse = SparseStateTrie::default().with_updates(true);
    sparse
        .reveal_multiproof(
            HashMap::from_iter([
                (address_1, HashSet::from_iter([slot_1, slot_2])),
                (address_2, HashSet::from_iter([slot_1, slot_2])),
            ])
            .into(),
            MultiProof {
                account_subtree: proof_nodes,
                branch_node_hash_masks: HashMap::from_iter([(
                    Nibbles::from_nibbles([0x1]),
                    TrieMask::new(0b00),
                )]),
                storages: HashMap::from_iter([
                    (
                        address_1,
                        StorageMultiProof {
                            root,
                            subtree: storage_proof_nodes.clone(),
                            branch_node_hash_masks: storage_branch_node_hash_masks.clone(),
                        },
                    ),
                    (
                        address_2,
                        StorageMultiProof {
                            root,
                            subtree: storage_proof_nodes,
                            branch_node_hash_masks: storage_branch_node_hash_masks,
                        },
                    ),
                ]),
            },
        )
        .unwrap();

    // The revealed sparse trie must report the same root as the hash builder.
    assert_eq!(sparse.root(), Some(root));

    // Insert a third account that was not part of the proof.
    let address_3 = b256!("2000000000000000000000000000000000000000000000000000000000000000");
    let address_path_3 = Nibbles::unpack(address_3);
    let account_3 = Account { nonce: account_1.nonce + 1, ..account_1 };
    let trie_account_3 = TrieAccount::from((account_3, EMPTY_ROOT_HASH));

    sparse.update_account_leaf(address_path_3, alloy_rlp::encode(trie_account_3)).unwrap();

    // Add a third slot to the first account's storage and write back its new storage root.
    sparse.update_storage_leaf(address_1, slot_path_3, alloy_rlp::encode(value_3)).unwrap();
    trie_account_1.storage_root = sparse.storage_root(address_1).unwrap();
    sparse.update_account_leaf(address_path_1, alloy_rlp::encode(trie_account_1)).unwrap();

    // Wipe the second account's storage and write back its new storage root.
    sparse.wipe_storage(address_2).unwrap();
    trie_account_2.storage_root = sparse.storage_root(address_2).unwrap();
    sparse.update_account_leaf(address_path_2, alloy_rlp::encode(trie_account_2)).unwrap();

    // Recompute the account trie root before taking the updates.
    sparse.root();

    let sparse_updates = sparse.take_trie_updates().unwrap();
    // Only the wiped storage trie of the second account shows up in the updates.
    pretty_assertions::assert_eq!(
        sparse_updates,
        TrieUpdates {
            account_nodes: HashMap::default(),
            storage_tries: HashMap::from_iter([(
                b256!("1100000000000000000000000000000000000000000000000000000000000000"),
                StorageTrieUpdates {
                    is_deleted: true,
                    storage_nodes: HashMap::default(),
                    removed_nodes: HashSet::default()
                }
            )]),
            removed_nodes: HashSet::default()
        }
    );
}
630}