reth_trie_parallel/
proof.rs

1use crate::{root::ParallelStateRootError, stats::ParallelTrieTracker, StorageRootTargets};
2use alloy_primitives::{
3    map::{B256HashMap, HashMap},
4    B256,
5};
6use alloy_rlp::{BufMut, Encodable};
7use itertools::Itertools;
8use reth_db::DatabaseError;
9use reth_execution_errors::StorageRootError;
10use reth_provider::{
11    providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, ProviderError,
12    StateCommitmentProvider,
13};
14use reth_trie::{
15    hashed_cursor::{HashedCursorFactory, HashedPostStateCursorFactory},
16    node_iter::{TrieElement, TrieNodeIter},
17    prefix_set::{PrefixSetMut, TriePrefixSetsMut},
18    proof::StorageProof,
19    trie_cursor::{InMemoryTrieCursorFactory, TrieCursorFactory},
20    updates::TrieUpdatesSorted,
21    walker::TrieWalker,
22    HashBuilder, HashedPostStateSorted, MultiProof, MultiProofTargets, Nibbles, StorageMultiProof,
23    TrieAccount, TRIE_ACCOUNT_RLP_MAX_SIZE,
24};
25use reth_trie_common::proof::ProofRetainer;
26use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
27use std::sync::Arc;
28use tracing::{debug, error};
29
30#[cfg(feature = "metrics")]
31use crate::metrics::ParallelStateRootMetrics;
32
33/// TODO:
34#[derive(Debug)]
35pub struct ParallelProof<Factory> {
36    /// Consistent view of the database.
37    view: ConsistentDbView<Factory>,
38    /// The sorted collection of cached in-memory intermediate trie nodes that
39    /// can be reused for computation.
40    pub nodes_sorted: Arc<TrieUpdatesSorted>,
41    /// The sorted in-memory overlay hashed state.
42    pub state_sorted: Arc<HashedPostStateSorted>,
43    /// The collection of prefix sets for the computation. Since the prefix sets _always_
44    /// invalidate the in-memory nodes, not all keys from `state_sorted` might be present here,
45    /// if we have cached nodes for them.
46    pub prefix_sets: Arc<TriePrefixSetsMut>,
47    /// Flag indicating whether to include branch node hash masks in the proof.
48    collect_branch_node_hash_masks: bool,
49    /// Parallel state root metrics.
50    #[cfg(feature = "metrics")]
51    metrics: ParallelStateRootMetrics,
52}
53
54impl<Factory> ParallelProof<Factory> {
55    /// Create new state proof generator.
56    pub fn new(
57        view: ConsistentDbView<Factory>,
58        nodes_sorted: Arc<TrieUpdatesSorted>,
59        state_sorted: Arc<HashedPostStateSorted>,
60        prefix_sets: Arc<TriePrefixSetsMut>,
61    ) -> Self {
62        Self {
63            view,
64            nodes_sorted,
65            state_sorted,
66            prefix_sets,
67            collect_branch_node_hash_masks: false,
68            #[cfg(feature = "metrics")]
69            metrics: ParallelStateRootMetrics::default(),
70        }
71    }
72
73    /// Set the flag indicating whether to include branch node hash masks in the proof.
74    pub const fn with_branch_node_hash_masks(mut self, branch_node_hash_masks: bool) -> Self {
75        self.collect_branch_node_hash_masks = branch_node_hash_masks;
76        self
77    }
78}
79
80impl<Factory> ParallelProof<Factory>
81where
82    Factory: DatabaseProviderFactory<Provider: BlockReader>
83        + StateCommitmentProvider
84        + Clone
85        + Send
86        + Sync
87        + 'static,
88{
89    /// Generate a state multiproof according to specified targets.
90    pub fn multiproof(
91        self,
92        targets: MultiProofTargets,
93    ) -> Result<MultiProof, ParallelStateRootError> {
94        let mut tracker = ParallelTrieTracker::default();
95
96        // Extend prefix sets with targets
97        let mut prefix_sets = (*self.prefix_sets).clone();
98        prefix_sets.extend(TriePrefixSetsMut {
99            account_prefix_set: PrefixSetMut::from(targets.keys().copied().map(Nibbles::unpack)),
100            storage_prefix_sets: targets
101                .iter()
102                .filter(|&(_hashed_address, slots)| (!slots.is_empty()))
103                .map(|(hashed_address, slots)| {
104                    (*hashed_address, PrefixSetMut::from(slots.iter().map(Nibbles::unpack)))
105                })
106                .collect(),
107            destroyed_accounts: Default::default(),
108        });
109        let prefix_sets = prefix_sets.freeze();
110
111        let storage_root_targets = StorageRootTargets::new(
112            prefix_sets.account_prefix_set.iter().map(|nibbles| B256::from_slice(&nibbles.pack())),
113            prefix_sets.storage_prefix_sets.clone(),
114        );
115
116        // Pre-calculate storage roots for accounts which were changed.
117        tracker.set_precomputed_storage_roots(storage_root_targets.len() as u64);
118        debug!(target: "trie::parallel_state_root", len = storage_root_targets.len(), "pre-generating storage proofs");
119        let mut storage_proofs =
120            B256HashMap::with_capacity_and_hasher(storage_root_targets.len(), Default::default());
121        for (hashed_address, prefix_set) in
122            storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address)
123        {
124            let view = self.view.clone();
125            let target_slots = targets.get(&hashed_address).cloned().unwrap_or_default();
126
127            let trie_nodes_sorted = self.nodes_sorted.clone();
128            let hashed_state_sorted = self.state_sorted.clone();
129
130            let (tx, rx) = std::sync::mpsc::sync_channel(1);
131
132            rayon::spawn_fifo(move || {
133                let result = (|| -> Result<_, ParallelStateRootError> {
134                    let provider_ro = view.provider_ro()?;
135                    let trie_cursor_factory = InMemoryTrieCursorFactory::new(
136                        DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
137                        &trie_nodes_sorted,
138                    );
139                    let hashed_cursor_factory = HashedPostStateCursorFactory::new(
140                        DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
141                        &hashed_state_sorted,
142                    );
143
144                    StorageProof::new_hashed(
145                        trie_cursor_factory,
146                        hashed_cursor_factory,
147                        hashed_address,
148                    )
149                    .with_prefix_set_mut(PrefixSetMut::from(prefix_set.iter().cloned()))
150                    .with_branch_node_hash_masks(self.collect_branch_node_hash_masks)
151                    .storage_multiproof(target_slots)
152                    .map_err(|e| ParallelStateRootError::Other(e.to_string()))
153                })();
154                if let Err(err) = tx.send(result) {
155                    error!(target: "trie::parallel", ?hashed_address, err_content = ?err.0,  "Failed to send proof result");
156                }
157            });
158            storage_proofs.insert(hashed_address, rx);
159        }
160
161        let provider_ro = self.view.provider_ro()?;
162        let trie_cursor_factory = InMemoryTrieCursorFactory::new(
163            DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
164            &self.nodes_sorted,
165        );
166        let hashed_cursor_factory = HashedPostStateCursorFactory::new(
167            DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
168            &self.state_sorted,
169        );
170
171        // Create the walker.
172        let walker = TrieWalker::new(
173            trie_cursor_factory.account_trie_cursor().map_err(ProviderError::Database)?,
174            prefix_sets.account_prefix_set,
175        )
176        .with_deletions_retained(true);
177
178        // Create a hash builder to rebuild the root node since it is not available in the database.
179        let retainer: ProofRetainer = targets.keys().map(Nibbles::unpack).collect();
180        let mut hash_builder = HashBuilder::default()
181            .with_proof_retainer(retainer)
182            .with_updates(self.collect_branch_node_hash_masks);
183
184        // Initialize all storage multiproofs as empty.
185        // Storage multiproofs for non empty tries will be overwritten if necessary.
186        let mut storages: B256HashMap<_> =
187            targets.keys().map(|key| (*key, StorageMultiProof::empty())).collect();
188        let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
189        let mut account_node_iter = TrieNodeIter::new(
190            walker,
191            hashed_cursor_factory.hashed_account_cursor().map_err(ProviderError::Database)?,
192        );
193        while let Some(account_node) =
194            account_node_iter.try_next().map_err(ProviderError::Database)?
195        {
196            match account_node {
197                TrieElement::Branch(node) => {
198                    hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
199                }
200                TrieElement::Leaf(hashed_address, account) => {
201                    let storage_multiproof = match storage_proofs.remove(&hashed_address) {
202                        Some(rx) => rx.recv().map_err(|_| {
203                            ParallelStateRootError::StorageRoot(StorageRootError::Database(
204                                DatabaseError::Other(format!(
205                                    "channel closed for {hashed_address}"
206                                )),
207                            ))
208                        })??,
209                        // Since we do not store all intermediate nodes in the database, there might
210                        // be a possibility of re-adding a non-modified leaf to the hash builder.
211                        None => {
212                            tracker.inc_missed_leaves();
213                            StorageProof::new_hashed(
214                                trie_cursor_factory.clone(),
215                                hashed_cursor_factory.clone(),
216                                hashed_address,
217                            )
218                            .with_prefix_set_mut(Default::default())
219                            .storage_multiproof(
220                                targets.get(&hashed_address).cloned().unwrap_or_default(),
221                            )
222                            .map_err(|e| {
223                                ParallelStateRootError::StorageRoot(StorageRootError::Database(
224                                    DatabaseError::Other(e.to_string()),
225                                ))
226                            })?
227                        }
228                    };
229
230                    // Encode account
231                    account_rlp.clear();
232                    let account = TrieAccount::from((account, storage_multiproof.root));
233                    account.encode(&mut account_rlp as &mut dyn BufMut);
234
235                    hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);
236
237                    // We might be adding leaves that are not necessarily our proof targets.
238                    if targets.contains_key(&hashed_address) {
239                        storages.insert(hashed_address, storage_multiproof);
240                    }
241                }
242            }
243        }
244        let _ = hash_builder.root();
245
246        #[cfg(feature = "metrics")]
247        self.metrics.record_state_trie(tracker.finish());
248
249        let account_subtree = hash_builder.take_proof_nodes();
250        let branch_node_hash_masks = if self.collect_branch_node_hash_masks {
251            hash_builder
252                .updated_branch_nodes
253                .unwrap_or_default()
254                .into_iter()
255                .map(|(path, node)| (path, node.hash_mask))
256                .collect()
257        } else {
258            HashMap::default()
259        };
260
261        Ok(MultiProof { account_subtree, branch_node_hash_masks, storages })
262    }
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::{
        keccak256,
        map::{B256HashSet, DefaultHashBuilder},
        Address, U256,
    };
    use rand::Rng;
    use reth_primitives::{Account, StorageEntry};
    use reth_provider::{test_utils::create_test_provider_factory, HashingWriter};
    use reth_trie::proof::Proof;

    /// Randomized equivalence check: for the same targets over the same database state, the
    /// parallel multiproof must equal the one produced by the sequential `Proof::multiproof`.
    #[test]
    fn random_parallel_proof() {
        let factory = create_test_provider_factory();
        let consistent_view = ConsistentDbView::new(factory.clone(), None);

        // 100 random accounts; each one has a 70% chance of getting 100 random storage slots.
        let mut rng = rand::thread_rng();
        let mut state = HashMap::<_, _, DefaultHashBuilder>::default();
        for _ in 0..100 {
            let address = Address::random();
            let balance = U256::from(rng.gen::<u64>());
            let account = Account { balance, ..Default::default() };
            let storage: HashMap<B256, U256, DefaultHashBuilder> = if rng.gen_bool(0.7) {
                (0..100)
                    .map(|_| {
                        let slot = B256::from(U256::from(rng.gen::<u64>()));
                        let value = U256::from(rng.gen::<u64>());
                        (slot, value)
                    })
                    .collect()
            } else {
                HashMap::default()
            };
            state.insert(address, (account, storage));
        }

        // Persist the random state into the hashed tables.
        {
            let provider_rw = factory.provider_rw().unwrap();
            let accounts =
                state.iter().map(|(address, (account, _))| (*address, Some(*account)));
            provider_rw.insert_account_for_hashing(accounts).unwrap();
            let storages = state.iter().map(|(address, (_, storage))| {
                let entries = storage.iter().map(|(slot, value)| StorageEntry {
                    key: *slot,
                    value: *value,
                    is_private: false,
                });
                (*address, entries)
            });
            provider_rw.insert_storage_for_hashing(storages).unwrap();
            provider_rw.commit().unwrap();
        }

        // Target up to five slots of each of the first ten accounts that have any storage.
        let mut targets = MultiProofTargets::default();
        for (address, (_, storage)) in state.iter().take(10) {
            let target_slots: B256HashSet = storage.keys().take(5).copied().collect();
            if target_slots.is_empty() {
                continue;
            }
            targets.insert(keccak256(*address), target_slots);
        }

        let provider_rw = factory.provider_rw().unwrap();
        let trie_cursor_factory = DatabaseTrieCursorFactory::new(provider_rw.tx_ref());
        let hashed_cursor_factory = DatabaseHashedCursorFactory::new(provider_rw.tx_ref());

        let parallel = ParallelProof::new(
            consistent_view,
            Default::default(),
            Default::default(),
            Default::default(),
        )
        .multiproof(targets.clone())
        .unwrap();
        let sequential =
            Proof::new(trie_cursor_factory, hashed_cursor_factory).multiproof(targets).unwrap();
        assert_eq!(parallel, sequential);
    }
}