// reth_trie_parallel/proof.rs

use crate::{
    metrics::ParallelTrieMetrics,
    proof_task::{ProofTaskKind, ProofTaskManagerHandle, StorageProofInput},
    root::ParallelStateRootError,
    stats::ParallelTrieTracker,
    StorageRootTargets,
};
use alloy_primitives::{
    map::{B256Map, B256Set, HashMap},
    B256,
};
use alloy_rlp::{BufMut, Encodable};
use itertools::Itertools;
use reth_execution_errors::StorageRootError;
use reth_provider::{
    providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, FactoryTx,
    ProviderError, StateCommitmentProvider,
};
use reth_storage_errors::db::DatabaseError;
use reth_trie::{
    hashed_cursor::{HashedCursorFactory, HashedPostStateCursorFactory},
    node_iter::{TrieElement, TrieNodeIter},
    prefix_set::{PrefixSet, PrefixSetMut, TriePrefixSetsMut},
    proof::StorageProof,
    trie_cursor::{InMemoryTrieCursorFactory, TrieCursorFactory},
    updates::TrieUpdatesSorted,
    walker::TrieWalker,
    DecodedMultiProof, DecodedStorageMultiProof, HashBuilder, HashedPostStateSorted, MultiProof,
    MultiProofTargets, Nibbles, StorageMultiProof, TRIE_ACCOUNT_RLP_MAX_SIZE,
};
use reth_trie_common::proof::ProofRetainer;
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
use std::sync::{mpsc::Receiver, Arc};
use tracing::debug;

/// Parallel proof calculator.
///
/// This can collect proofs for many targets in parallel, spawning a task for each hashed address
/// that has proof targets.
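///
/// A minimal usage sketch (hypothetical setup: the consistent view, sorted overlays, prefix
/// sets, and proof task handle must be constructed by the caller):
///
/// ```ignore
/// let proof = ParallelProof::new(view, nodes_sorted, state_sorted, prefix_sets, task_handle)
///     .with_branch_node_masks(true)
///     .decoded_multiproof(targets)?;
/// ```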
#[derive(Debug)]
pub struct ParallelProof<Factory: DatabaseProviderFactory> {
    /// Consistent view of the database.
    view: ConsistentDbView<Factory>,
    /// The sorted collection of cached in-memory intermediate trie nodes that
    /// can be reused for computation.
    pub nodes_sorted: Arc<TrieUpdatesSorted>,
    /// The sorted in-memory overlay hashed state.
    pub state_sorted: Arc<HashedPostStateSorted>,
    /// The collection of prefix sets for the computation. Since the prefix sets _always_
    /// invalidate the in-memory nodes, not all keys from `state_sorted` might be present here
    /// if we have cached nodes for them.
    pub prefix_sets: Arc<TriePrefixSetsMut>,
    /// Flag indicating whether to include branch node masks in the proof.
    collect_branch_node_masks: bool,
    /// Handle to the storage proof task.
    storage_proof_task_handle: ProofTaskManagerHandle<FactoryTx<Factory>>,
    #[cfg(feature = "metrics")]
    metrics: ParallelTrieMetrics,
}

impl<Factory: DatabaseProviderFactory> ParallelProof<Factory> {
    /// Create new state proof generator.
    pub fn new(
        view: ConsistentDbView<Factory>,
        nodes_sorted: Arc<TrieUpdatesSorted>,
        state_sorted: Arc<HashedPostStateSorted>,
        prefix_sets: Arc<TriePrefixSetsMut>,
        storage_proof_task_handle: ProofTaskManagerHandle<FactoryTx<Factory>>,
    ) -> Self {
        Self {
            view,
            nodes_sorted,
            state_sorted,
            prefix_sets,
            collect_branch_node_masks: false,
            storage_proof_task_handle,
            #[cfg(feature = "metrics")]
            metrics: ParallelTrieMetrics::new_with_labels(&[("type", "proof")]),
        }
    }

    /// Set the flag indicating whether to include branch node masks in the proof.
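    ///
    /// When enabled, the `branch_node_hash_masks` and `branch_node_tree_masks` of the resulting
    /// proofs are populated; otherwise they are left empty.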
    pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
        self.collect_branch_node_masks = branch_node_masks;
        self
    }
}

impl<Factory> ParallelProof<Factory>
where
    Factory:
        DatabaseProviderFactory<Provider: BlockReader> + StateCommitmentProvider + Clone + 'static,
{
    /// Spawns a storage proof on the storage proof task and returns a receiver for the result.
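    ///
    /// The returned receiver yields the proof result once the task completes. If the task
    /// manager shuts down before a result is sent, the channel closes and `recv` returns an
    /// error, which callers map to a [`DatabaseError::Other`].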
    fn spawn_storage_proof(
        &self,
        hashed_address: B256,
        prefix_set: PrefixSet,
        target_slots: B256Set,
    ) -> Receiver<Result<StorageMultiProof, ParallelStateRootError>> {
        let input = StorageProofInput::new(
            hashed_address,
            prefix_set,
            target_slots,
            self.collect_branch_node_masks,
        );

        let (sender, receiver) = std::sync::mpsc::channel();
        let _ =
            self.storage_proof_task_handle.queue_task(ProofTaskKind::StorageProof(input, sender));
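        // The queueing result is intentionally ignored: if the task cannot be queued, the
        // sender is dropped and the failure surfaces as a closed-channel error on `recv`.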
        receiver
    }

    /// Generate a storage multiproof according to the specified targets and hashed address.
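    ///
    /// A sketch of the call shape (hypothetical values; the address and slots must already be
    /// keccak256-hashed):
    ///
    /// ```ignore
    /// let mut slots = B256Set::default();
    /// slots.insert(keccak256(slot));
    /// let storage_proof = parallel_proof.storage_proof(keccak256(address), slots)?;
    /// ```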
    pub fn storage_proof(
        self,
        hashed_address: B256,
        target_slots: B256Set,
    ) -> Result<StorageMultiProof, ParallelStateRootError> {
        let total_targets = target_slots.len();
        let prefix_set = PrefixSetMut::from(target_slots.iter().map(Nibbles::unpack));
        let prefix_set = prefix_set.freeze();

        debug!(
            target: "trie::parallel_proof",
            total_targets,
            ?hashed_address,
            "Starting storage proof generation"
        );

        let receiver = self.spawn_storage_proof(hashed_address, prefix_set, target_slots);
        let proof_result = receiver.recv().map_err(|_| {
            ParallelStateRootError::StorageRoot(StorageRootError::Database(DatabaseError::Other(
                format!("channel closed for {hashed_address}"),
            )))
        })?;

        debug!(
            target: "trie::parallel_proof",
            total_targets,
            ?hashed_address,
            "Storage proof generation completed"
        );

        proof_result
    }
    /// Generate a [`DecodedStorageMultiProof`] for the given targets by first calling
    /// `storage_proof`, then decoding the proof nodes.
    pub fn decoded_storage_proof(
        self,
        hashed_address: B256,
        target_slots: B256Set,
    ) -> Result<DecodedStorageMultiProof, ParallelStateRootError> {
        let proof = self.storage_proof(hashed_address, target_slots)?;

        // Now decode the nodes of the proof.
        let proof = proof.try_into()?;

        Ok(proof)
    }

    /// Generate a state multiproof according to the specified targets.
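    ///
    /// Storage proofs for all changed accounts are queued on the proof task up front and
    /// awaited lazily while the account trie is walked, so storage and account work overlap.
    ///
    /// A minimal call sketch (hypothetical `hashed_address`; an empty slot set requests an
    /// account proof without specific storage slot targets):
    ///
    /// ```ignore
    /// let mut targets = MultiProofTargets::default();
    /// targets.insert(hashed_address, B256Set::default());
    /// let multiproof = parallel_proof.multiproof(targets)?;
    /// ```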
    pub fn multiproof(
        self,
        targets: MultiProofTargets,
    ) -> Result<MultiProof, ParallelStateRootError> {
        let mut tracker = ParallelTrieTracker::default();

        // Extend prefix sets with targets
        let mut prefix_sets = (*self.prefix_sets).clone();
        prefix_sets.extend(TriePrefixSetsMut {
            account_prefix_set: PrefixSetMut::from(targets.keys().copied().map(Nibbles::unpack)),
            storage_prefix_sets: targets
                .iter()
                .filter(|&(_hashed_address, slots)| !slots.is_empty())
                .map(|(hashed_address, slots)| {
                    (*hashed_address, PrefixSetMut::from(slots.iter().map(Nibbles::unpack)))
                })
                .collect(),
            destroyed_accounts: Default::default(),
        });
        let prefix_sets = prefix_sets.freeze();

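        // Account prefix-set entries are full 64-nibble paths unpacked from hashed addresses,
        // so packing them back recovers the `B256` keys for the storage root targets.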
        let storage_root_targets = StorageRootTargets::new(
            prefix_sets.account_prefix_set.iter().map(|nibbles| B256::from_slice(&nibbles.pack())),
            prefix_sets.storage_prefix_sets.clone(),
        );
        let storage_root_targets_len = storage_root_targets.len();

        debug!(
            target: "trie::parallel_proof",
            total_targets = storage_root_targets_len,
            "Starting parallel proof generation"
        );

        // Pre-calculate storage roots for accounts which were changed.
        tracker.set_precomputed_storage_roots(storage_root_targets_len as u64);

        // Store the receivers for the storage proof outcomes, keyed by hashed address. This way
        // we can lazily await each outcome when we iterate over the trie.
        let mut storage_proofs =
            B256Map::with_capacity_and_hasher(storage_root_targets.len(), Default::default());

        for (hashed_address, prefix_set) in
            storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address)
        {
            let target_slots = targets.get(&hashed_address).cloned().unwrap_or_default();
            let receiver = self.spawn_storage_proof(hashed_address, prefix_set, target_slots);

            // store the receiver for that result with the hashed address so we can await this in
            // place when we iterate over the trie
            storage_proofs.insert(hashed_address, receiver);
        }

        let provider_ro = self.view.provider_ro()?;
        let trie_cursor_factory = InMemoryTrieCursorFactory::new(
            DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
            &self.nodes_sorted,
        );
        let hashed_cursor_factory = HashedPostStateCursorFactory::new(
            DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
            &self.state_sorted,
        );

        // Create the walker.
        let walker = TrieWalker::state_trie(
            trie_cursor_factory.account_trie_cursor().map_err(ProviderError::Database)?,
            prefix_sets.account_prefix_set,
        )
        .with_deletions_retained(true);

        // Create a hash builder to rebuild the root node since it is not available in the database.
        let retainer: ProofRetainer = targets.keys().map(Nibbles::unpack).collect();
        let mut hash_builder = HashBuilder::default()
            .with_proof_retainer(retainer)
            .with_updates(self.collect_branch_node_masks);

        // Initialize all storage multiproofs as empty.
        // Storage multiproofs for non-empty tries will be overwritten if necessary.
        let mut storages: B256Map<_> =
            targets.keys().map(|key| (*key, StorageMultiProof::empty())).collect();
        let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
        let mut account_node_iter = TrieNodeIter::state_trie(
            walker,
            hashed_cursor_factory.hashed_account_cursor().map_err(ProviderError::Database)?,
        );
        while let Some(account_node) =
            account_node_iter.try_next().map_err(ProviderError::Database)?
        {
            match account_node {
                TrieElement::Branch(node) => {
                    hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
                }
                TrieElement::Leaf(hashed_address, account) => {
                    let storage_multiproof = match storage_proofs.remove(&hashed_address) {
                        Some(rx) => rx.recv().map_err(|_| {
                            ParallelStateRootError::StorageRoot(StorageRootError::Database(
                                DatabaseError::Other(format!(
                                    "channel closed for {hashed_address}"
                                )),
                            ))
                        })??,
                        // Since we do not store all intermediate nodes in the database, it is
                        // possible to re-add an unmodified leaf to the hash builder.
                        None => {
                            tracker.inc_missed_leaves();
                            StorageProof::new_hashed(
                                trie_cursor_factory.clone(),
                                hashed_cursor_factory.clone(),
                                hashed_address,
                            )
                            .with_prefix_set_mut(Default::default())
                            .storage_multiproof(
                                targets.get(&hashed_address).cloned().unwrap_or_default(),
                            )
                            .map_err(|e| {
                                ParallelStateRootError::StorageRoot(StorageRootError::Database(
                                    DatabaseError::Other(e.to_string()),
                                ))
                            })?
                        }
                    };

                    // Encode account
                    account_rlp.clear();
                    let account = account.into_trie_account(storage_multiproof.root);
                    account.encode(&mut account_rlp as &mut dyn BufMut);
                    let is_private = false; // account leaves are always public. Their storage leaves can be private.
                    hash_builder.add_leaf(
                        Nibbles::unpack(hashed_address),
                        &account_rlp,
                        is_private,
                    );

                    // We might be adding leaves that are not necessarily our proof targets.
                    if targets.contains_key(&hashed_address) {
                        storages.insert(hashed_address, storage_multiproof);
                    }
                }
            }
        }
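        // Calling `root` finalizes the hash builder so that the retained proof nodes and any
        // updated branch nodes are fully populated; the root hash itself is not needed here.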
        let _ = hash_builder.root();

        let stats = tracker.finish();
        #[cfg(feature = "metrics")]
        self.metrics.record(stats);

        let account_subtree = hash_builder.take_proof_nodes();
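        // Split the branch nodes updated while building into per-path hash masks and tree masks
        // for consumers of the multiproof.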
        let (branch_node_hash_masks, branch_node_tree_masks) = if self.collect_branch_node_masks {
            let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
            (
                updated_branch_nodes
                    .iter()
                    .map(|(path, node)| (path.clone(), node.hash_mask))
                    .collect(),
                updated_branch_nodes
                    .into_iter()
                    .map(|(path, node)| (path, node.tree_mask))
                    .collect(),
            )
        } else {
            (HashMap::default(), HashMap::default())
        };

        debug!(
            target: "trie::parallel_proof",
            total_targets = storage_root_targets_len,
            duration = ?stats.duration(),
            branches_added = stats.branches_added(),
            leaves_added = stats.leaves_added(),
            missed_leaves = stats.missed_leaves(),
            precomputed_storage_roots = stats.precomputed_storage_roots(),
            "Calculated proof"
        );

        Ok(MultiProof { account_subtree, branch_node_hash_masks, branch_node_tree_masks, storages })
    }

    /// Returns a [`DecodedMultiProof`] for the given targets.
    ///
    /// Uses `multiproof` first to get the proof, and then decodes the nodes of the multiproof.
    pub fn decoded_multiproof(
        self,
        targets: MultiProofTargets,
    ) -> Result<DecodedMultiProof, ParallelStateRootError> {
        let multiproof = self.multiproof(targets)?;

        // Now decode the nodes of the multiproof.
        let multiproof = multiproof.try_into()?;

        Ok(multiproof)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::proof_task::{ProofTaskCtx, ProofTaskManager};
    use alloy_primitives::{
        keccak256,
        map::{B256Set, DefaultHashBuilder},
        Address, U256,
    };
    use rand::Rng;
    use reth_primitives_traits::{Account, StorageEntry};
    use reth_provider::{test_utils::create_test_provider_factory, HashingWriter};
    use reth_trie::proof::Proof;
    use tokio::runtime::Runtime;

    #[test]
    fn random_parallel_proof() {
        let factory = create_test_provider_factory();
        let consistent_view = ConsistentDbView::new(factory.clone(), None);

        let mut rng = rand::rng();
        let state = (0..100)
            .map(|_| {
                let address = Address::random();
                let account =
                    Account { balance: U256::from(rng.random::<u64>()), ..Default::default() };
                let mut storage = HashMap::<B256, U256, DefaultHashBuilder>::default();
                let has_storage = rng.random_bool(0.7);
                if has_storage {
                    for _ in 0..100 {
                        storage.insert(
                            B256::from(U256::from(rng.random::<u64>())),
                            U256::from(rng.random::<u64>()),
                        );
                    }
                }
                (address, (account, storage))
            })
            .collect::<HashMap<_, _, DefaultHashBuilder>>();

        {
            let provider_rw = factory.provider_rw().unwrap();
            provider_rw
                .insert_account_for_hashing(
                    state.iter().map(|(address, (account, _))| (*address, Some(*account))),
                )
                .unwrap();
            provider_rw
                .insert_storage_for_hashing(state.iter().map(|(address, (_, storage))| {
                    (
                        *address,
                        storage.iter().map(|(slot, value)| StorageEntry {
                            key: *slot,
                            value: *value,
                            is_private: false,
                        }),
                    )
                }))
                .unwrap();
            provider_rw.commit().unwrap();
        }

        let mut targets = MultiProofTargets::default();
        for (address, (_, storage)) in state.iter().take(10) {
            let hashed_address = keccak256(*address);
            let mut target_slots = B256Set::default();

            for (slot, _) in storage.iter().take(5) {
                target_slots.insert(*slot);
            }

            if !target_slots.is_empty() {
                targets.insert(hashed_address, target_slots);
            }
        }

        let provider_rw = factory.provider_rw().unwrap();
        let trie_cursor_factory = DatabaseTrieCursorFactory::new(provider_rw.tx_ref());
        let hashed_cursor_factory = DatabaseHashedCursorFactory::new(provider_rw.tx_ref());

        let rt = Runtime::new().unwrap();

        let task_ctx =
            ProofTaskCtx::new(Default::default(), Default::default(), Default::default());
        let proof_task =
            ProofTaskManager::new(rt.handle().clone(), consistent_view.clone(), task_ctx, 1);
        let proof_task_handle = proof_task.handle();

        // keep the join handle around to make sure it does not return any errors
        // after we compute the state root
        let join_handle = rt.spawn_blocking(move || proof_task.run());

        let parallel_result = ParallelProof::new(
            consistent_view,
            Default::default(),
            Default::default(),
            Default::default(),
            proof_task_handle.clone(),
        )
        .multiproof(targets.clone())
        .unwrap();

        let sequential_result =
            Proof::new(trie_cursor_factory, hashed_cursor_factory).multiproof(targets).unwrap();

        // to help narrow down what is wrong, first compare the account subtries
        assert_eq!(parallel_result.account_subtree, sequential_result.account_subtree);

        // then compare length of all storage subtries
        assert_eq!(parallel_result.storages.len(), sequential_result.storages.len());

        // then compare each storage subtrie
        for (hashed_address, storage_proof) in &parallel_result.storages {
            let sequential_storage_proof = sequential_result.storages.get(hashed_address).unwrap();
            assert_eq!(storage_proof, sequential_storage_proof);
        }

        // then compare the entire thing for any mask differences
        assert_eq!(parallel_result, sequential_result);

        // drop the handle to terminate the task and then block on the proof task handle to make
        // sure it does not return any errors
        drop(proof_task_handle);
        rt.block_on(join_handle).unwrap().expect("The proof task should not return an error");
    }
}