1use crate::{root::ParallelStateRootError, stats::ParallelTrieTracker, StorageRootTargets};
2use alloy_primitives::{
3 map::{B256HashMap, HashMap},
4 B256,
5};
6use alloy_rlp::{BufMut, Encodable};
7use itertools::Itertools;
8use reth_db::DatabaseError;
9use reth_execution_errors::StorageRootError;
10use reth_provider::{
11 providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, ProviderError,
12 StateCommitmentProvider,
13};
14use reth_trie::{
15 hashed_cursor::{HashedCursorFactory, HashedPostStateCursorFactory},
16 node_iter::{TrieElement, TrieNodeIter},
17 prefix_set::{PrefixSetMut, TriePrefixSetsMut},
18 proof::StorageProof,
19 trie_cursor::{InMemoryTrieCursorFactory, TrieCursorFactory},
20 updates::TrieUpdatesSorted,
21 walker::TrieWalker,
22 HashBuilder, HashedPostStateSorted, MultiProof, MultiProofTargets, Nibbles, StorageMultiProof,
23 TrieAccount, TRIE_ACCOUNT_RLP_MAX_SIZE,
24};
25use reth_trie_common::proof::ProofRetainer;
26use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
27use std::sync::Arc;
28use tracing::{debug, error};
29
30#[cfg(feature = "metrics")]
31use crate::metrics::ParallelStateRootMetrics;
32
/// Parallel proof calculator.
///
/// Produces account multiproofs while generating the required storage proofs on the rayon
/// thread pool; see `multiproof`.
#[derive(Debug)]
pub struct ParallelProof<Factory> {
    /// Consistent view of the database, used to open the read-only providers for the proof.
    view: ConsistentDbView<Factory>,
    /// Sorted trie node updates layered on top of the database trie cursors.
    pub nodes_sorted: Arc<TrieUpdatesSorted>,
    /// Sorted hashed state layered on top of the database hashed cursors.
    pub state_sorted: Arc<HashedPostStateSorted>,
    /// Prefix sets of changed paths; `multiproof` extends a clone of these with the targets.
    pub prefix_sets: Arc<TriePrefixSetsMut>,
    /// Whether branch node hash masks are collected. `new` initializes this to `false`.
    collect_branch_node_hash_masks: bool,
    /// Parallel state root metrics, recorded at the end of `multiproof`.
    #[cfg(feature = "metrics")]
    metrics: ParallelStateRootMetrics,
}
53
54impl<Factory> ParallelProof<Factory> {
55 pub fn new(
57 view: ConsistentDbView<Factory>,
58 nodes_sorted: Arc<TrieUpdatesSorted>,
59 state_sorted: Arc<HashedPostStateSorted>,
60 prefix_sets: Arc<TriePrefixSetsMut>,
61 ) -> Self {
62 Self {
63 view,
64 nodes_sorted,
65 state_sorted,
66 prefix_sets,
67 collect_branch_node_hash_masks: false,
68 #[cfg(feature = "metrics")]
69 metrics: ParallelStateRootMetrics::default(),
70 }
71 }
72
73 pub const fn with_branch_node_hash_masks(mut self, branch_node_hash_masks: bool) -> Self {
75 self.collect_branch_node_hash_masks = branch_node_hash_masks;
76 self
77 }
78}
79
80impl<Factory> ParallelProof<Factory>
81where
82 Factory: DatabaseProviderFactory<Provider: BlockReader>
83 + StateCommitmentProvider
84 + Clone
85 + Send
86 + Sync
87 + 'static,
88{
89 pub fn multiproof(
91 self,
92 targets: MultiProofTargets,
93 ) -> Result<MultiProof, ParallelStateRootError> {
94 let mut tracker = ParallelTrieTracker::default();
95
96 let mut prefix_sets = (*self.prefix_sets).clone();
98 prefix_sets.extend(TriePrefixSetsMut {
99 account_prefix_set: PrefixSetMut::from(targets.keys().copied().map(Nibbles::unpack)),
100 storage_prefix_sets: targets
101 .iter()
102 .filter(|&(_hashed_address, slots)| (!slots.is_empty()))
103 .map(|(hashed_address, slots)| {
104 (*hashed_address, PrefixSetMut::from(slots.iter().map(Nibbles::unpack)))
105 })
106 .collect(),
107 destroyed_accounts: Default::default(),
108 });
109 let prefix_sets = prefix_sets.freeze();
110
111 let storage_root_targets = StorageRootTargets::new(
112 prefix_sets.account_prefix_set.iter().map(|nibbles| B256::from_slice(&nibbles.pack())),
113 prefix_sets.storage_prefix_sets.clone(),
114 );
115
116 tracker.set_precomputed_storage_roots(storage_root_targets.len() as u64);
118 debug!(target: "trie::parallel_state_root", len = storage_root_targets.len(), "pre-generating storage proofs");
119 let mut storage_proofs =
120 B256HashMap::with_capacity_and_hasher(storage_root_targets.len(), Default::default());
121 for (hashed_address, prefix_set) in
122 storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address)
123 {
124 let view = self.view.clone();
125 let target_slots = targets.get(&hashed_address).cloned().unwrap_or_default();
126
127 let trie_nodes_sorted = self.nodes_sorted.clone();
128 let hashed_state_sorted = self.state_sorted.clone();
129
130 let (tx, rx) = std::sync::mpsc::sync_channel(1);
131
132 rayon::spawn_fifo(move || {
133 let result = (|| -> Result<_, ParallelStateRootError> {
134 let provider_ro = view.provider_ro()?;
135 let trie_cursor_factory = InMemoryTrieCursorFactory::new(
136 DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
137 &trie_nodes_sorted,
138 );
139 let hashed_cursor_factory = HashedPostStateCursorFactory::new(
140 DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
141 &hashed_state_sorted,
142 );
143
144 StorageProof::new_hashed(
145 trie_cursor_factory,
146 hashed_cursor_factory,
147 hashed_address,
148 )
149 .with_prefix_set_mut(PrefixSetMut::from(prefix_set.iter().cloned()))
150 .with_branch_node_hash_masks(self.collect_branch_node_hash_masks)
151 .storage_multiproof(target_slots)
152 .map_err(|e| ParallelStateRootError::Other(e.to_string()))
153 })();
154 if let Err(err) = tx.send(result) {
155 error!(target: "trie::parallel", ?hashed_address, err_content = ?err.0, "Failed to send proof result");
156 }
157 });
158 storage_proofs.insert(hashed_address, rx);
159 }
160
161 let provider_ro = self.view.provider_ro()?;
162 let trie_cursor_factory = InMemoryTrieCursorFactory::new(
163 DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
164 &self.nodes_sorted,
165 );
166 let hashed_cursor_factory = HashedPostStateCursorFactory::new(
167 DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
168 &self.state_sorted,
169 );
170
171 let walker = TrieWalker::new(
173 trie_cursor_factory.account_trie_cursor().map_err(ProviderError::Database)?,
174 prefix_sets.account_prefix_set,
175 )
176 .with_deletions_retained(true);
177
178 let retainer: ProofRetainer = targets.keys().map(Nibbles::unpack).collect();
180 let mut hash_builder = HashBuilder::default()
181 .with_proof_retainer(retainer)
182 .with_updates(self.collect_branch_node_hash_masks);
183
184 let mut storages: B256HashMap<_> =
187 targets.keys().map(|key| (*key, StorageMultiProof::empty())).collect();
188 let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
189 let mut account_node_iter = TrieNodeIter::new(
190 walker,
191 hashed_cursor_factory.hashed_account_cursor().map_err(ProviderError::Database)?,
192 );
193 while let Some(account_node) =
194 account_node_iter.try_next().map_err(ProviderError::Database)?
195 {
196 match account_node {
197 TrieElement::Branch(node) => {
198 hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
199 }
200 TrieElement::Leaf(hashed_address, account) => {
201 let storage_multiproof = match storage_proofs.remove(&hashed_address) {
202 Some(rx) => rx.recv().map_err(|_| {
203 ParallelStateRootError::StorageRoot(StorageRootError::Database(
204 DatabaseError::Other(format!(
205 "channel closed for {hashed_address}"
206 )),
207 ))
208 })??,
209 None => {
212 tracker.inc_missed_leaves();
213 StorageProof::new_hashed(
214 trie_cursor_factory.clone(),
215 hashed_cursor_factory.clone(),
216 hashed_address,
217 )
218 .with_prefix_set_mut(Default::default())
219 .storage_multiproof(
220 targets.get(&hashed_address).cloned().unwrap_or_default(),
221 )
222 .map_err(|e| {
223 ParallelStateRootError::StorageRoot(StorageRootError::Database(
224 DatabaseError::Other(e.to_string()),
225 ))
226 })?
227 }
228 };
229
230 account_rlp.clear();
232 let account = TrieAccount::from((account, storage_multiproof.root));
233 account.encode(&mut account_rlp as &mut dyn BufMut);
234
235 hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);
236
237 if targets.contains_key(&hashed_address) {
239 storages.insert(hashed_address, storage_multiproof);
240 }
241 }
242 }
243 }
244 let _ = hash_builder.root();
245
246 #[cfg(feature = "metrics")]
247 self.metrics.record_state_trie(tracker.finish());
248
249 let account_subtree = hash_builder.take_proof_nodes();
250 let branch_node_hash_masks = if self.collect_branch_node_hash_masks {
251 hash_builder
252 .updated_branch_nodes
253 .unwrap_or_default()
254 .into_iter()
255 .map(|(path, node)| (path, node.hash_mask))
256 .collect()
257 } else {
258 HashMap::default()
259 };
260
261 Ok(MultiProof { account_subtree, branch_node_hash_masks, storages })
262 }
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::{
        keccak256,
        map::{B256HashSet, DefaultHashBuilder},
        Address, U256,
    };
    use rand::Rng;
    use reth_primitives::{Account, StorageEntry};
    use reth_provider::{test_utils::create_test_provider_factory, HashingWriter};
    use reth_trie::proof::Proof;

    /// The parallel multiproof must equal the sequential `Proof::multiproof` over the same
    /// randomly generated hashed state.
    #[test]
    fn random_parallel_proof() {
        let factory = create_test_provider_factory();
        let consistent_view = ConsistentDbView::new(factory.clone(), None);

        // 100 random accounts; each one gets 100 random storage slots with probability 0.7.
        let mut rng = rand::thread_rng();
        let mut state = HashMap::<_, _, DefaultHashBuilder>::default();
        for _ in 0..100 {
            let address = Address::random();
            let account = Account { balance: U256::from(rng.gen::<u64>()), ..Default::default() };
            let mut storage = HashMap::<B256, U256, DefaultHashBuilder>::default();
            if rng.gen_bool(0.7) {
                for _ in 0..100 {
                    let slot = B256::from(U256::from(rng.gen::<u64>()));
                    let value = U256::from(rng.gen::<u64>());
                    storage.insert(slot, value);
                }
            }
            state.insert(address, (account, storage));
        }

        // Persist the generated state into the hashed account and storage tables.
        {
            let provider_rw = factory.provider_rw().unwrap();
            let accounts =
                state.iter().map(|(address, (account, _))| (*address, Some(*account)));
            provider_rw.insert_account_for_hashing(accounts).unwrap();
            let storages = state.iter().map(|(address, (_, storage))| {
                let entries = storage.iter().map(|(slot, value)| StorageEntry {
                    key: *slot,
                    value: *value,
                    is_private: false,
                });
                (*address, entries)
            });
            provider_rw.insert_storage_for_hashing(storages).unwrap();
            provider_rw.commit().unwrap();
        }

        // Up to five slots from each of the first ten accounts; storage-less accounts are
        // not targeted.
        let mut targets = MultiProofTargets::default();
        for (address, (_, storage)) in state.iter().take(10) {
            let target_slots: B256HashSet = storage.keys().take(5).copied().collect();
            if !target_slots.is_empty() {
                targets.insert(keccak256(*address), target_slots);
            }
        }

        let provider_rw = factory.provider_rw().unwrap();
        let trie_cursor_factory = DatabaseTrieCursorFactory::new(provider_rw.tx_ref());
        let hashed_cursor_factory = DatabaseHashedCursorFactory::new(provider_rw.tx_ref());

        let parallel = ParallelProof::new(
            consistent_view,
            Default::default(),
            Default::default(),
            Default::default(),
        )
        .multiproof(targets.clone())
        .unwrap();
        let sequential =
            Proof::new(trie_cursor_factory, hashed_cursor_factory).multiproof(targets).unwrap();
        assert_eq!(parallel, sequential);
    }
}