reth_trie/
trie.rs

1use crate::{
2    hashed_cursor::{HashedCursorFactory, HashedStorageCursor},
3    node_iter::{TrieElement, TrieNodeIter},
4    prefix_set::{PrefixSet, TriePrefixSets},
5    progress::{IntermediateStateRootState, StateRootProgress},
6    stats::TrieTracker,
7    trie_cursor::TrieCursorFactory,
8    updates::{StorageTrieUpdates, TrieUpdates},
9    walker::TrieWalker,
10    HashBuilder, Nibbles, TRIE_ACCOUNT_RLP_MAX_SIZE,
11};
12use alloy_consensus::EMPTY_ROOT_HASH;
13use alloy_primitives::{keccak256, Address, B256};
14use alloy_rlp::{BufMut, Encodable};
15use reth_execution_errors::{StateRootError, StorageRootError};
16use tracing::trace;
17
18#[cfg(feature = "metrics")]
19use crate::metrics::{StateRootMetrics, TrieRootMetrics};
20
21/// `StateRoot` is used to compute the root node of a state trie.
22#[derive(Debug)]
23pub struct StateRoot<T, H> {
24    /// The factory for trie cursors.
25    pub trie_cursor_factory: T,
26    /// The factory for hashed cursors.
27    pub hashed_cursor_factory: H,
28    /// A set of prefix sets that have changed.
29    pub prefix_sets: TriePrefixSets,
30    /// Previous intermediate state.
31    previous_state: Option<IntermediateStateRootState>,
32    /// The number of updates after which the intermediate progress should be returned.
33    threshold: u64,
34    #[cfg(feature = "metrics")]
35    /// State root metrics.
36    metrics: StateRootMetrics,
37}
38
39impl<T, H> StateRoot<T, H> {
40    /// Creates [`StateRoot`] with `trie_cursor_factory` and `hashed_cursor_factory`. All other
41    /// parameters are set to reasonable defaults.
42    ///
43    /// The cursors created by given factories are then used to walk through the accounts and
44    /// calculate the state root value with.
45    pub fn new(trie_cursor_factory: T, hashed_cursor_factory: H) -> Self {
46        Self {
47            trie_cursor_factory,
48            hashed_cursor_factory,
49            prefix_sets: TriePrefixSets::default(),
50            previous_state: None,
51            threshold: 100_000,
52            #[cfg(feature = "metrics")]
53            metrics: StateRootMetrics::default(),
54        }
55    }
56
57    /// Set the prefix sets.
58    pub fn with_prefix_sets(mut self, prefix_sets: TriePrefixSets) -> Self {
59        self.prefix_sets = prefix_sets;
60        self
61    }
62
63    /// Set the threshold.
64    pub const fn with_threshold(mut self, threshold: u64) -> Self {
65        self.threshold = threshold;
66        self
67    }
68
69    /// Set the threshold to maximum value so that intermediate progress is not returned.
70    pub const fn with_no_threshold(mut self) -> Self {
71        self.threshold = u64::MAX;
72        self
73    }
74
75    /// Set the previously recorded intermediate state.
76    pub fn with_intermediate_state(mut self, state: Option<IntermediateStateRootState>) -> Self {
77        self.previous_state = state;
78        self
79    }
80
81    /// Set the hashed cursor factory.
82    pub fn with_hashed_cursor_factory<HF>(self, hashed_cursor_factory: HF) -> StateRoot<T, HF> {
83        StateRoot {
84            trie_cursor_factory: self.trie_cursor_factory,
85            hashed_cursor_factory,
86            prefix_sets: self.prefix_sets,
87            threshold: self.threshold,
88            previous_state: self.previous_state,
89            #[cfg(feature = "metrics")]
90            metrics: self.metrics,
91        }
92    }
93
94    /// Set the trie cursor factory.
95    pub fn with_trie_cursor_factory<TF>(self, trie_cursor_factory: TF) -> StateRoot<TF, H> {
96        StateRoot {
97            trie_cursor_factory,
98            hashed_cursor_factory: self.hashed_cursor_factory,
99            prefix_sets: self.prefix_sets,
100            threshold: self.threshold,
101            previous_state: self.previous_state,
102            #[cfg(feature = "metrics")]
103            metrics: self.metrics,
104        }
105    }
106}
107
108impl<T, H> StateRoot<T, H>
109where
110    T: TrieCursorFactory + Clone,
111    H: HashedCursorFactory + Clone,
112{
113    /// Walks the intermediate nodes of existing state trie (if any) and hashed entries. Feeds the
114    /// nodes into the hash builder. Collects the updates in the process.
115    ///
116    /// Ignores the threshold.
117    ///
118    /// # Returns
119    ///
120    /// The intermediate progress of state root computation and the trie updates.
121    pub fn root_with_updates(self) -> Result<(B256, TrieUpdates), StateRootError> {
122        match self.with_no_threshold().calculate(true)? {
123            StateRootProgress::Complete(root, _, updates) => Ok((root, updates)),
124            StateRootProgress::Progress(..) => unreachable!(), // unreachable threshold
125        }
126    }
127
128    /// Walks the intermediate nodes of existing state trie (if any) and hashed entries. Feeds the
129    /// nodes into the hash builder.
130    ///
131    /// # Returns
132    ///
133    /// The state root hash.
134    pub fn root(self) -> Result<B256, StateRootError> {
135        match self.calculate(false)? {
136            StateRootProgress::Complete(root, _, _) => Ok(root),
137            StateRootProgress::Progress(..) => unreachable!(), // update retenion is disabled
138        }
139    }
140
141    /// Walks the intermediate nodes of existing state trie (if any) and hashed entries. Feeds the
142    /// nodes into the hash builder. Collects the updates in the process.
143    ///
144    /// # Returns
145    ///
146    /// The intermediate progress of state root computation.
147    pub fn root_with_progress(self) -> Result<StateRootProgress, StateRootError> {
148        self.calculate(true)
149    }
150
151    fn calculate(self, retain_updates: bool) -> Result<StateRootProgress, StateRootError> {
152        trace!(target: "trie::state_root", "calculating state root");
153        let mut tracker = TrieTracker::default();
154        let mut trie_updates = TrieUpdates::default();
155
156        let trie_cursor = self.trie_cursor_factory.account_trie_cursor()?;
157
158        let hashed_account_cursor = self.hashed_cursor_factory.hashed_account_cursor()?;
159        let (mut hash_builder, mut account_node_iter) = match self.previous_state {
160            Some(state) => {
161                let hash_builder = state.hash_builder.with_updates(retain_updates);
162                let walker = TrieWalker::state_trie_from_stack(
163                    trie_cursor,
164                    state.walker_stack,
165                    self.prefix_sets.account_prefix_set,
166                )
167                .with_deletions_retained(retain_updates);
168                let node_iter = TrieNodeIter::state_trie(walker, hashed_account_cursor)
169                    .with_last_hashed_key(state.last_account_key);
170                (hash_builder, node_iter)
171            }
172            None => {
173                let hash_builder = HashBuilder::default().with_updates(retain_updates);
174                let walker =
175                    TrieWalker::state_trie(trie_cursor, self.prefix_sets.account_prefix_set)
176                        .with_deletions_retained(retain_updates);
177                let node_iter = TrieNodeIter::state_trie(walker, hashed_account_cursor);
178                (hash_builder, node_iter)
179            }
180        };
181
182        let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
183        let mut hashed_entries_walked = 0;
184        let mut updated_storage_nodes = 0;
185        while let Some(node) = account_node_iter.try_next()? {
186            match node {
187                TrieElement::Branch(node) => {
188                    tracker.inc_branch();
189                    hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
190                }
191                TrieElement::Leaf(hashed_address, account) => {
192                    tracker.inc_leaf();
193                    hashed_entries_walked += 1;
194                    let is_private = false; // account leaves are always public. Their storage leaves can be private.
195
196                    // We assume we can always calculate a storage root without
197                    // OOMing. This opens us up to a potential DOS vector if
198                    // a contract had too many storage entries and they were
199                    // all buffered w/o us returning and committing our intermediate
200                    // progress.
201                    // TODO: We can consider introducing the TrieProgress::Progress/Complete
202                    // abstraction inside StorageRoot, but let's give it a try as-is for now.
203                    let storage_root_calculator = StorageRoot::new_hashed(
204                        self.trie_cursor_factory.clone(),
205                        self.hashed_cursor_factory.clone(),
206                        hashed_address,
207                        self.prefix_sets
208                            .storage_prefix_sets
209                            .get(&hashed_address)
210                            .cloned()
211                            .unwrap_or_default(),
212                        #[cfg(feature = "metrics")]
213                        self.metrics.storage_trie.clone(),
214                    );
215
216                    let storage_root = if retain_updates {
217                        let (root, storage_slots_walked, updates) =
218                            storage_root_calculator.root_with_updates()?;
219                        hashed_entries_walked += storage_slots_walked;
220                        // We only walk over hashed address once, so it's safe to insert.
221                        updated_storage_nodes += updates.len();
222                        trie_updates.insert_storage_updates(hashed_address, updates);
223                        root
224                    } else {
225                        storage_root_calculator.root()?
226                    };
227
228                    account_rlp.clear();
229                    let account = account.into_trie_account(storage_root);
230                    account.encode(&mut account_rlp as &mut dyn BufMut);
231                    hash_builder.add_leaf(
232                        Nibbles::unpack(hashed_address),
233                        &account_rlp,
234                        is_private,
235                    );
236
237                    // Decide if we need to return intermediate progress.
238                    let total_updates_len = updated_storage_nodes +
239                        account_node_iter.walker.removed_keys_len() +
240                        hash_builder.updates_len();
241                    if retain_updates && total_updates_len as u64 >= self.threshold {
242                        let (walker_stack, walker_deleted_keys) = account_node_iter.walker.split();
243                        trie_updates.removed_nodes.extend(walker_deleted_keys);
244                        let (hash_builder, hash_builder_updates) = hash_builder.split();
245                        trie_updates.account_nodes.extend(hash_builder_updates);
246
247                        let state = IntermediateStateRootState {
248                            hash_builder,
249                            walker_stack,
250                            last_account_key: hashed_address,
251                        };
252
253                        return Ok(StateRootProgress::Progress(
254                            Box::new(state),
255                            hashed_entries_walked,
256                            trie_updates,
257                        ));
258                    }
259                }
260            }
261        }
262
263        let root = hash_builder.root();
264
265        let removed_keys = account_node_iter.walker.take_removed_keys();
266        trie_updates.finalize(hash_builder, removed_keys, self.prefix_sets.destroyed_accounts);
267
268        let stats = tracker.finish();
269
270        #[cfg(feature = "metrics")]
271        self.metrics.state_trie.record(stats);
272
273        trace!(
274            target: "trie::state_root",
275            %root,
276            duration = ?stats.duration(),
277            branches_added = stats.branches_added(),
278            leaves_added = stats.leaves_added(),
279            "calculated state root"
280        );
281
282        Ok(StateRootProgress::Complete(root, hashed_entries_walked, trie_updates))
283    }
284}
285
286/// `StorageRoot` is used to compute the root node of an account storage trie.
287#[derive(Debug)]
288pub struct StorageRoot<T, H> {
289    /// A reference to the database transaction.
290    pub trie_cursor_factory: T,
291    /// The factory for hashed cursors.
292    pub hashed_cursor_factory: H,
293    /// The hashed address of an account.
294    pub hashed_address: B256,
295    /// The set of storage slot prefixes that have changed.
296    pub prefix_set: PrefixSet,
297    /// Storage root metrics.
298    #[cfg(feature = "metrics")]
299    metrics: TrieRootMetrics,
300}
301
302impl<T, H> StorageRoot<T, H> {
303    /// Creates a new storage root calculator given a raw address.
304    pub fn new(
305        trie_cursor_factory: T,
306        hashed_cursor_factory: H,
307        address: Address,
308        prefix_set: PrefixSet,
309        #[cfg(feature = "metrics")] metrics: TrieRootMetrics,
310    ) -> Self {
311        Self::new_hashed(
312            trie_cursor_factory,
313            hashed_cursor_factory,
314            keccak256(address),
315            prefix_set,
316            #[cfg(feature = "metrics")]
317            metrics,
318        )
319    }
320
321    /// Creates a new storage root calculator given a hashed address.
322    pub const fn new_hashed(
323        trie_cursor_factory: T,
324        hashed_cursor_factory: H,
325        hashed_address: B256,
326        prefix_set: PrefixSet,
327        #[cfg(feature = "metrics")] metrics: TrieRootMetrics,
328    ) -> Self {
329        Self {
330            trie_cursor_factory,
331            hashed_cursor_factory,
332            hashed_address,
333            prefix_set,
334            #[cfg(feature = "metrics")]
335            metrics,
336        }
337    }
338
339    /// Set the changed prefixes.
340    pub fn with_prefix_set(mut self, prefix_set: PrefixSet) -> Self {
341        self.prefix_set = prefix_set;
342        self
343    }
344
345    /// Set the hashed cursor factory.
346    pub fn with_hashed_cursor_factory<HF>(self, hashed_cursor_factory: HF) -> StorageRoot<T, HF> {
347        StorageRoot {
348            trie_cursor_factory: self.trie_cursor_factory,
349            hashed_cursor_factory,
350            hashed_address: self.hashed_address,
351            prefix_set: self.prefix_set,
352            #[cfg(feature = "metrics")]
353            metrics: self.metrics,
354        }
355    }
356
357    /// Set the trie cursor factory.
358    pub fn with_trie_cursor_factory<TF>(self, trie_cursor_factory: TF) -> StorageRoot<TF, H> {
359        StorageRoot {
360            trie_cursor_factory,
361            hashed_cursor_factory: self.hashed_cursor_factory,
362            hashed_address: self.hashed_address,
363            prefix_set: self.prefix_set,
364            #[cfg(feature = "metrics")]
365            metrics: self.metrics,
366        }
367    }
368}
369
370impl<T, H> StorageRoot<T, H>
371where
372    T: TrieCursorFactory,
373    H: HashedCursorFactory,
374{
375    /// Walks the hashed storage table entries for a given address and calculates the storage root.
376    ///
377    /// # Returns
378    ///
379    /// The storage root and storage trie updates for a given address.
380    pub fn root_with_updates(self) -> Result<(B256, usize, StorageTrieUpdates), StorageRootError> {
381        self.calculate(true)
382    }
383
384    /// Walks the hashed storage table entries for a given address and calculates the storage root.
385    ///
386    /// # Returns
387    ///
388    /// The storage root.
389    pub fn root(self) -> Result<B256, StorageRootError> {
390        let (root, _, _) = self.calculate(false)?;
391        Ok(root)
392    }
393
394    /// Walks the hashed storage table entries for a given address and calculates the storage root.
395    ///
396    /// # Returns
397    ///
398    /// The storage root, number of walked entries and trie updates
399    /// for a given address if requested.
400    pub fn calculate(
401        self,
402        retain_updates: bool,
403    ) -> Result<(B256, usize, StorageTrieUpdates), StorageRootError> {
404        trace!(target: "trie::storage_root", hashed_address = ?self.hashed_address, "calculating storage root");
405
406        let mut hashed_storage_cursor =
407            self.hashed_cursor_factory.hashed_storage_cursor(self.hashed_address)?;
408
409        // short circuit on empty storage
410        if hashed_storage_cursor.is_storage_empty()? {
411            return Ok((EMPTY_ROOT_HASH, 0, StorageTrieUpdates::deleted()));
412        }
413
414        let mut tracker = TrieTracker::default();
415        let trie_cursor = self.trie_cursor_factory.storage_trie_cursor(self.hashed_address)?;
416        let walker = TrieWalker::storage_trie(trie_cursor, self.prefix_set)
417            .with_deletions_retained(retain_updates);
418
419        let mut hash_builder = HashBuilder::default().with_updates(retain_updates);
420
421        let mut storage_node_iter = TrieNodeIter::storage_trie(walker, hashed_storage_cursor);
422        while let Some(node) = storage_node_iter.try_next()? {
423            match node {
424                TrieElement::Branch(node) => {
425                    tracker.inc_branch();
426                    hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
427                }
428                TrieElement::Leaf(hashed_slot, value) => {
429                    tracker.inc_leaf();
430                    hash_builder.add_leaf(
431                        Nibbles::unpack(hashed_slot),
432                        alloy_rlp::encode_fixed_size(&value.value).as_ref(),
433                        value.is_private,
434                    );
435                }
436            }
437        }
438
439        let root = hash_builder.root();
440
441        let mut trie_updates = StorageTrieUpdates::default();
442        let removed_keys = storage_node_iter.walker.take_removed_keys();
443        trie_updates.finalize(hash_builder, removed_keys);
444
445        let stats = tracker.finish();
446
447        #[cfg(feature = "metrics")]
448        self.metrics.record(stats);
449
450        trace!(
451            target: "trie::storage_root",
452            %root,
453            hashed_address = %self.hashed_address,
454            duration = ?stats.duration(),
455            branches_added = stats.branches_added(),
456            leaves_added = stats.leaves_added(),
457            "calculated storage root"
458        );
459
460        let storage_slots_walked = stats.leaves_added() as usize;
461        Ok((root, storage_slots_walked, trie_updates))
462    }
463}
464
/// Trie type for differentiating between various trie calculations.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TrieType {
    /// State trie type.
    State,
    /// Storage trie type.
    Storage,
}

impl TrieType {
    /// Returns a static lowercase label for this trie type.
    ///
    /// Only compiled when the `metrics` feature is enabled.
    #[cfg(feature = "metrics")]
    pub(crate) const fn as_str(&self) -> &'static str {
        match self {
            Self::State => "state",
            Self::Storage => "storage",
        }
    }
}