use crate::{
    metrics::ParallelTrieMetrics,
    proof_task::{ProofTaskKind, ProofTaskManagerHandle, StorageProofInput},
    root::ParallelStateRootError,
    stats::ParallelTrieTracker,
    StorageRootTargets,
};
use alloy_primitives::{
    map::{B256Map, B256Set, HashMap},
    B256,
};
use alloy_rlp::{BufMut, Encodable};
use itertools::Itertools;
use reth_execution_errors::StorageRootError;
use reth_provider::{
    providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, FactoryTx,
    ProviderError, StateCommitmentProvider,
};
use reth_storage_errors::db::DatabaseError;
use reth_trie::{
    hashed_cursor::{HashedCursorFactory, HashedPostStateCursorFactory},
    node_iter::{TrieElement, TrieNodeIter},
    prefix_set::{PrefixSet, PrefixSetMut, TriePrefixSetsMut},
    proof::StorageProof,
    trie_cursor::{InMemoryTrieCursorFactory, TrieCursorFactory},
    updates::TrieUpdatesSorted,
    walker::TrieWalker,
    DecodedMultiProof, DecodedStorageMultiProof, HashBuilder, HashedPostStateSorted, MultiProof,
    MultiProofTargets, Nibbles, StorageMultiProof, TRIE_ACCOUNT_RLP_MAX_SIZE,
};
use reth_trie_common::proof::ProofRetainer;
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
use std::sync::{mpsc::Receiver, Arc};
use tracing::debug;

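/// Calculator for Merkle multiproofs that computes storage proofs in parallel.
///
/// Storage proofs for the targeted accounts are queued on the proof task manager and collected
/// while the account trie is walked, so the storage trie work happens concurrently.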
#[derive(Debug)]
pub struct ParallelProof<Factory: DatabaseProviderFactory> {
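    /// Consistent view of the database.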
    view: ConsistentDbView<Factory>,
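    /// Sorted in-memory overlay of intermediate trie nodes.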
    pub nodes_sorted: Arc<TrieUpdatesSorted>,
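    /// Sorted in-memory overlay of the hashed post state.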
    pub state_sorted: Arc<HashedPostStateSorted>,
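    /// The collection of prefix sets for the computation.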
    pub prefix_sets: Arc<TriePrefixSetsMut>,
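    /// Flag indicating whether to include branch node hash and tree masks in the proof.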
    collect_branch_node_masks: bool,
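    /// Handle used to queue storage proof tasks on the proof task manager.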
    storage_proof_task_handle: ProofTaskManagerHandle<FactoryTx<Factory>>,
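    /// Parallel trie metrics for proof generation.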
    #[cfg(feature = "metrics")]
    metrics: ParallelTrieMetrics,
}

impl<Factory: DatabaseProviderFactory> ParallelProof<Factory> {
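    /// Creates a new [`ParallelProof`]. Branch node mask collection is disabled by default.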
    pub fn new(
        view: ConsistentDbView<Factory>,
        nodes_sorted: Arc<TrieUpdatesSorted>,
        state_sorted: Arc<HashedPostStateSorted>,
        prefix_sets: Arc<TriePrefixSetsMut>,
        storage_proof_task_handle: ProofTaskManagerHandle<FactoryTx<Factory>>,
    ) -> Self {
        Self {
            view,
            nodes_sorted,
            state_sorted,
            prefix_sets,
            collect_branch_node_masks: false,
            storage_proof_task_handle,
            #[cfg(feature = "metrics")]
            metrics: ParallelTrieMetrics::new_with_labels(&[("type", "proof")]),
        }
    }

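    /// Sets the flag indicating whether to include branch node hash and tree masks in the proof.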
    pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
        self.collect_branch_node_masks = branch_node_masks;
        self
    }
}

impl<Factory> ParallelProof<Factory>
where
    Factory:
        DatabaseProviderFactory<Provider: BlockReader> + StateCommitmentProvider + Clone + 'static,
{
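    /// Queues a storage proof task for the given hashed address and returns the receiver for
    /// its result.
    ///
    /// Errors from `queue_task` are ignored here; if the task was never queued, the receiver
    /// reports a closed channel which callers turn into an error.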
    fn spawn_storage_proof(
        &self,
        hashed_address: B256,
        prefix_set: PrefixSet,
        target_slots: B256Set,
    ) -> Receiver<Result<StorageMultiProof, ParallelStateRootError>> {
        let input = StorageProofInput::new(
            hashed_address,
            prefix_set,
            target_slots,
            self.collect_branch_node_masks,
        );

        let (sender, receiver) = std::sync::mpsc::channel();
        let _ =
            self.storage_proof_task_handle.queue_task(ProofTaskKind::StorageProof(input, sender));
        receiver
    }

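    /// Generates a storage multiproof for the given hashed address and target slots on the
    /// proof task manager, blocking until the result is received.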
    pub fn storage_proof(
        self,
        hashed_address: B256,
        target_slots: B256Set,
    ) -> Result<StorageMultiProof, ParallelStateRootError> {
        let total_targets = target_slots.len();
        let prefix_set = PrefixSetMut::from(target_slots.iter().map(Nibbles::unpack));
        let prefix_set = prefix_set.freeze();

        debug!(
            target: "trie::parallel_proof",
            total_targets,
            ?hashed_address,
            "Starting storage proof generation"
        );

        let receiver = self.spawn_storage_proof(hashed_address, prefix_set, target_slots);
        let proof_result = receiver.recv().map_err(|_| {
            ParallelStateRootError::StorageRoot(StorageRootError::Database(DatabaseError::Other(
                format!("channel closed for {hashed_address}"),
            )))
        })?;

        debug!(
            target: "trie::parallel_proof",
            total_targets,
            ?hashed_address,
            "Storage proof generation completed"
        );

        proof_result
    }

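    /// Generates a storage multiproof for the given hashed address and target slots, returning
    /// the decoded proof.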
    pub fn decoded_storage_proof(
        self,
        hashed_address: B256,
        target_slots: B256Set,
    ) -> Result<DecodedStorageMultiProof, ParallelStateRootError> {
        let proof = self.storage_proof(hashed_address, target_slots)?;

        let proof = proof.try_into()?;

        Ok(proof)
    }

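    /// Generates a [`MultiProof`] for the given proof targets.
    ///
    /// Storage proofs for the targeted accounts are computed in parallel on the proof task
    /// manager while the account trie is walked on the calling thread.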
    pub fn multiproof(
        self,
        targets: MultiProofTargets,
    ) -> Result<MultiProof, ParallelStateRootError> {
        let mut tracker = ParallelTrieTracker::default();

        let mut prefix_sets = (*self.prefix_sets).clone();
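        // Extend the prefix sets with the proof targets so the walker also visits paths that
        // are requested but unchanged.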
        prefix_sets.extend(TriePrefixSetsMut {
            account_prefix_set: PrefixSetMut::from(targets.keys().copied().map(Nibbles::unpack)),
            storage_prefix_sets: targets
                .iter()
                .filter(|&(_hashed_address, slots)| !slots.is_empty())
                .map(|(hashed_address, slots)| {
                    (*hashed_address, PrefixSetMut::from(slots.iter().map(Nibbles::unpack)))
                })
                .collect(),
            destroyed_accounts: Default::default(),
        });
        let prefix_sets = prefix_sets.freeze();

        let storage_root_targets = StorageRootTargets::new(
            prefix_sets.account_prefix_set.iter().map(|nibbles| B256::from_slice(&nibbles.pack())),
            prefix_sets.storage_prefix_sets.clone(),
        );
        let storage_root_targets_len = storage_root_targets.len();

        debug!(
            target: "trie::parallel_proof",
            total_targets = storage_root_targets_len,
            "Starting parallel proof generation"
        );

        tracker.set_precomputed_storage_roots(storage_root_targets_len as u64);

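        // Receivers for the storage proof outcomes, keyed by hashed address, so the results can
        // be awaited lazily once the matching account leaf is reached.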
        let mut storage_proofs =
            B256Map::with_capacity_and_hasher(storage_root_targets.len(), Default::default());

        for (hashed_address, prefix_set) in
            storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address)
        {
            let target_slots = targets.get(&hashed_address).cloned().unwrap_or_default();
            let receiver = self.spawn_storage_proof(hashed_address, prefix_set, target_slots);

            storage_proofs.insert(hashed_address, receiver);
        }

        let provider_ro = self.view.provider_ro()?;
        let trie_cursor_factory = InMemoryTrieCursorFactory::new(
            DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
            &self.nodes_sorted,
        );
        let hashed_cursor_factory = HashedPostStateCursorFactory::new(
            DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
            &self.state_sorted,
        );

        let walker = TrieWalker::state_trie(
            trie_cursor_factory.account_trie_cursor().map_err(ProviderError::Database)?,
            prefix_sets.account_prefix_set,
        )
        .with_deletions_retained(true);

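        // Create the hash builder, retaining proof nodes only for the targeted account paths.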
        let retainer: ProofRetainer = targets.keys().map(Nibbles::unpack).collect();
        let mut hash_builder = HashBuilder::default()
            .with_proof_retainer(retainer)
            .with_updates(self.collect_branch_node_masks);

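        // Initialize all target storage multiproofs as empty; they are overwritten below for
        // accounts that are actually present in the trie.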
        let mut storages: B256Map<_> =
            targets.keys().map(|key| (*key, StorageMultiProof::empty())).collect();
        let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
        let mut account_node_iter = TrieNodeIter::state_trie(
            walker,
            hashed_cursor_factory.hashed_account_cursor().map_err(ProviderError::Database)?,
        );
        while let Some(account_node) =
            account_node_iter.try_next().map_err(ProviderError::Database)?
        {
            match account_node {
                TrieElement::Branch(node) => {
                    hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
                }
                TrieElement::Leaf(hashed_address, account) => {
                    let storage_multiproof = match storage_proofs.remove(&hashed_address) {
                        Some(rx) => rx.recv().map_err(|_| {
                            ParallelStateRootError::StorageRoot(StorageRootError::Database(
                                DatabaseError::Other(format!(
                                    "channel closed for {hashed_address}"
                                )),
                            ))
                        })??,
                        None => {
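                            // No storage proof task was spawned for this account, so compute
                            // the storage multiproof inline and record the missed leaf.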
                            tracker.inc_missed_leaves();
                            StorageProof::new_hashed(
                                trie_cursor_factory.clone(),
                                hashed_cursor_factory.clone(),
                                hashed_address,
                            )
                            .with_prefix_set_mut(Default::default())
                            .storage_multiproof(
                                targets.get(&hashed_address).cloned().unwrap_or_default(),
                            )
                            .map_err(|e| {
                                ParallelStateRootError::StorageRoot(StorageRootError::Database(
                                    DatabaseError::Other(e.to_string()),
                                ))
                            })?
                        }
                    };

                    account_rlp.clear();
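                    // Encode the account with the storage root taken from its multiproof.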
                    let account = account.into_trie_account(storage_multiproof.root);
                    account.encode(&mut account_rlp as &mut dyn BufMut);
                    let is_private = false;
                    hash_builder.add_leaf(
                        Nibbles::unpack(hashed_address),
                        &account_rlp,
                        is_private,
                    );

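                    // Keep the storage multiproof only if the account is an explicit target.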
                    if targets.contains_key(&hashed_address) {
                        storages.insert(hashed_address, storage_multiproof);
                    }
                }
            }
        }
        let _ = hash_builder.root();

        let stats = tracker.finish();
        #[cfg(feature = "metrics")]
        self.metrics.record(stats);

        let account_subtree = hash_builder.take_proof_nodes();
        let (branch_node_hash_masks, branch_node_tree_masks) = if self.collect_branch_node_masks {
            let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
            (
                updated_branch_nodes
                    .iter()
                    .map(|(path, node)| (path.clone(), node.hash_mask))
                    .collect(),
                updated_branch_nodes
                    .into_iter()
                    .map(|(path, node)| (path, node.tree_mask))
                    .collect(),
            )
        } else {
            (HashMap::default(), HashMap::default())
        };

        debug!(
            target: "trie::parallel_proof",
            total_targets = storage_root_targets_len,
            duration = ?stats.duration(),
            branches_added = stats.branches_added(),
            leaves_added = stats.leaves_added(),
            missed_leaves = stats.missed_leaves(),
            precomputed_storage_roots = stats.precomputed_storage_roots(),
            "Calculated proof"
        );

        Ok(MultiProof { account_subtree, branch_node_hash_masks, branch_node_tree_masks, storages })
    }

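    /// Generates a [`DecodedMultiProof`] for the given proof targets by decoding the result of
    /// [`Self::multiproof`].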
    pub fn decoded_multiproof(
        self,
        targets: MultiProofTargets,
    ) -> Result<DecodedMultiProof, ParallelStateRootError> {
        let multiproof = self.multiproof(targets)?;

        let multiproof = multiproof.try_into()?;

        Ok(multiproof)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::proof_task::{ProofTaskCtx, ProofTaskManager};
    use alloy_primitives::{
        keccak256,
        map::{B256Set, DefaultHashBuilder},
        Address, U256,
    };
    use rand::Rng;
    use reth_primitives_traits::{Account, StorageEntry};
    use reth_provider::{test_utils::create_test_provider_factory, HashingWriter};
    use reth_trie::proof::Proof;
    use tokio::runtime::Runtime;

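    // Sanity check: a multiproof computed in parallel must match the sequentially computed one.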
    #[test]
    fn random_parallel_proof() {
        let factory = create_test_provider_factory();
        let consistent_view = ConsistentDbView::new(factory.clone(), None);

        let mut rng = rand::rng();
        let state = (0..100)
            .map(|_| {
                let address = Address::random();
                let account =
                    Account { balance: U256::from(rng.random::<u64>()), ..Default::default() };
                let mut storage = HashMap::<B256, U256, DefaultHashBuilder>::default();
                let has_storage = rng.random_bool(0.7);
                if has_storage {
                    for _ in 0..100 {
                        storage.insert(
                            B256::from(U256::from(rng.random::<u64>())),
                            U256::from(rng.random::<u64>()),
                        );
                    }
                }
                (address, (account, storage))
            })
            .collect::<HashMap<_, _, DefaultHashBuilder>>();

        {
            let provider_rw = factory.provider_rw().unwrap();
            provider_rw
                .insert_account_for_hashing(
                    state.iter().map(|(address, (account, _))| (*address, Some(*account))),
                )
                .unwrap();
            provider_rw
                .insert_storage_for_hashing(state.iter().map(|(address, (_, storage))| {
                    (
                        *address,
                        storage.iter().map(|(slot, value)| StorageEntry {
                            key: *slot,
                            value: *value,
                            is_private: false,
                        }),
                    )
                }))
                .unwrap();
            provider_rw.commit().unwrap();
        }

        let mut targets = MultiProofTargets::default();
        for (address, (_, storage)) in state.iter().take(10) {
            let hashed_address = keccak256(*address);
            let mut target_slots = B256Set::default();

            for (slot, _) in storage.iter().take(5) {
                target_slots.insert(*slot);
            }

            if !target_slots.is_empty() {
                targets.insert(hashed_address, target_slots);
            }
        }

        let provider_rw = factory.provider_rw().unwrap();
        let trie_cursor_factory = DatabaseTrieCursorFactory::new(provider_rw.tx_ref());
        let hashed_cursor_factory = DatabaseHashedCursorFactory::new(provider_rw.tx_ref());

        let rt = Runtime::new().unwrap();

        let task_ctx =
            ProofTaskCtx::new(Default::default(), Default::default(), Default::default());
        let proof_task =
            ProofTaskManager::new(rt.handle().clone(), consistent_view.clone(), task_ctx, 1);
        let proof_task_handle = proof_task.handle();

        let join_handle = rt.spawn_blocking(move || proof_task.run());

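        // Compute the multiproof in parallel via the proof task manager.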
        let parallel_result = ParallelProof::new(
            consistent_view,
            Default::default(),
            Default::default(),
            Default::default(),
            proof_task_handle.clone(),
        )
        .multiproof(targets.clone())
        .unwrap();

        let sequential_result =
            Proof::new(trie_cursor_factory, hashed_cursor_factory).multiproof(targets).unwrap();

        assert_eq!(parallel_result.account_subtree, sequential_result.account_subtree);

        assert_eq!(parallel_result.storages.len(), sequential_result.storages.len());

        for (hashed_address, storage_proof) in &parallel_result.storages {
            let sequential_storage_proof = sequential_result.storages.get(hashed_address).unwrap();
            assert_eq!(storage_proof, sequential_storage_proof);
        }

        assert_eq!(parallel_result, sequential_result);

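        // Drop the handle so the proof task manager can shut down, then make sure the task
        // exited without an error.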
        drop(proof_task_handle);
        rt.block_on(join_handle).unwrap().expect("The proof task should not return an error");
    }
}