1#[cfg(feature = "metrics")]
2use crate::metrics::ParallelStateRootMetrics;
3use crate::{stats::ParallelTrieTracker, storage_root_targets::StorageRootTargets};
4use alloy_primitives::B256;
5use alloy_rlp::{BufMut, Encodable};
6use itertools::Itertools;
7use reth_execution_errors::StorageRootError;
8use reth_provider::{
9 providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, ProviderError,
10 StateCommitmentProvider,
11};
12use reth_storage_errors::db::DatabaseError;
13use reth_trie::{
14 hashed_cursor::{HashedCursorFactory, HashedPostStateCursorFactory},
15 node_iter::{TrieElement, TrieNodeIter},
16 trie_cursor::{InMemoryTrieCursorFactory, TrieCursorFactory},
17 updates::TrieUpdates,
18 walker::TrieWalker,
19 HashBuilder, Nibbles, StorageRoot, TrieInput, TRIE_ACCOUNT_RLP_MAX_SIZE,
20};
21use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
22use std::{collections::HashMap, sync::Arc};
23use thiserror::Error;
24use tracing::*;
25
/// Parallel incremental state root calculator.
///
/// Storage roots of the accounts targeted by the input's prefix sets are computed concurrently
/// on the rayon thread pool, each task opening its own read-only provider through `view`; the
/// account trie itself is then walked on the calling thread.
#[derive(Debug)]
pub struct ParallelStateRoot<Factory> {
    /// Consistent database view used to open read-only providers (one per task).
    view: ConsistentDbView<Factory>,
    /// In-memory trie nodes, hashed state and prefix sets layered on top of the database.
    input: TrieInput,
    /// Parallel state root metrics.
    #[cfg(feature = "metrics")]
    metrics: ParallelStateRootMetrics,
}
47
48impl<Factory> ParallelStateRoot<Factory> {
49 pub fn new(view: ConsistentDbView<Factory>, input: TrieInput) -> Self {
51 Self {
52 view,
53 input,
54 #[cfg(feature = "metrics")]
55 metrics: ParallelStateRootMetrics::default(),
56 }
57 }
58}
59
60impl<Factory> ParallelStateRoot<Factory>
61where
62 Factory: DatabaseProviderFactory<Provider: BlockReader>
63 + StateCommitmentProvider
64 + Clone
65 + Send
66 + Sync
67 + 'static,
68{
69 pub fn incremental_root(self) -> Result<B256, ParallelStateRootError> {
71 self.calculate(false).map(|(root, _)| root)
72 }
73
74 pub fn incremental_root_with_updates(
76 self,
77 ) -> Result<(B256, TrieUpdates), ParallelStateRootError> {
78 self.calculate(true)
79 }
80
81 fn calculate(
82 self,
83 retain_updates: bool,
84 ) -> Result<(B256, TrieUpdates), ParallelStateRootError> {
85 let mut tracker = ParallelTrieTracker::default();
86 let trie_nodes_sorted = Arc::new(self.input.nodes.into_sorted());
87 let hashed_state_sorted = Arc::new(self.input.state.into_sorted());
88 let prefix_sets = self.input.prefix_sets.freeze();
89 let storage_root_targets = StorageRootTargets::new(
90 prefix_sets.account_prefix_set.iter().map(|nibbles| B256::from_slice(&nibbles.pack())),
91 prefix_sets.storage_prefix_sets,
92 );
93
94 tracker.set_precomputed_storage_roots(storage_root_targets.len() as u64);
96 debug!(target: "trie::parallel_state_root", len = storage_root_targets.len(), "pre-calculating storage roots");
97 let mut storage_roots = HashMap::with_capacity(storage_root_targets.len());
98 for (hashed_address, prefix_set) in
99 storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address)
100 {
101 let view = self.view.clone();
102 let hashed_state_sorted = hashed_state_sorted.clone();
103 let trie_nodes_sorted = trie_nodes_sorted.clone();
104 #[cfg(feature = "metrics")]
105 let metrics = self.metrics.storage_trie.clone();
106
107 let (tx, rx) = std::sync::mpsc::sync_channel(1);
108
109 rayon::spawn_fifo(move || {
110 let result = (|| -> Result<_, ParallelStateRootError> {
111 let provider_ro = view.provider_ro()?;
112 let trie_cursor_factory = InMemoryTrieCursorFactory::new(
113 DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
114 &trie_nodes_sorted,
115 );
116 let hashed_state = HashedPostStateCursorFactory::new(
117 DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
118 &hashed_state_sorted,
119 );
120 Ok(StorageRoot::new_hashed(
121 trie_cursor_factory,
122 hashed_state,
123 hashed_address,
124 prefix_set,
125 #[cfg(feature = "metrics")]
126 metrics,
127 )
128 .calculate(retain_updates)?)
129 })();
130 let _ = tx.send(result);
131 });
132 storage_roots.insert(hashed_address, rx);
133 }
134
135 trace!(target: "trie::parallel_state_root", "calculating state root");
136 let mut trie_updates = TrieUpdates::default();
137
138 let provider_ro = self.view.provider_ro()?;
139 let trie_cursor_factory = InMemoryTrieCursorFactory::new(
140 DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
141 &trie_nodes_sorted,
142 );
143 let hashed_cursor_factory = HashedPostStateCursorFactory::new(
144 DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
145 &hashed_state_sorted,
146 );
147
148 let walker = TrieWalker::state_trie(
149 trie_cursor_factory.account_trie_cursor().map_err(ProviderError::Database)?,
150 prefix_sets.account_prefix_set,
151 )
152 .with_deletions_retained(retain_updates);
153 let mut account_node_iter = TrieNodeIter::state_trie(
154 walker,
155 hashed_cursor_factory.hashed_account_cursor().map_err(ProviderError::Database)?,
156 );
157
158 let mut hash_builder = HashBuilder::default().with_updates(retain_updates);
159 let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
160 while let Some(node) = account_node_iter.try_next().map_err(ProviderError::Database)? {
161 match node {
162 TrieElement::Branch(node) => {
163 hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
164 }
165 TrieElement::Leaf(hashed_address, account) => {
166 let (storage_root, _, updates) = match storage_roots.remove(&hashed_address) {
167 Some(rx) => rx.recv().map_err(|_| {
168 ParallelStateRootError::StorageRoot(StorageRootError::Database(
169 DatabaseError::Other(format!(
170 "channel closed for {hashed_address}"
171 )),
172 ))
173 })??,
174 None => {
177 tracker.inc_missed_leaves();
178 StorageRoot::new_hashed(
179 trie_cursor_factory.clone(),
180 hashed_cursor_factory.clone(),
181 hashed_address,
182 Default::default(),
183 #[cfg(feature = "metrics")]
184 self.metrics.storage_trie.clone(),
185 )
186 .calculate(retain_updates)?
187 }
188 };
189
190 if retain_updates {
191 trie_updates.insert_storage_updates(hashed_address, updates);
192 }
193
194 account_rlp.clear();
195 let account = account.into_trie_account(storage_root);
196 account.encode(&mut account_rlp as &mut dyn BufMut);
197 let is_private = false; hash_builder.add_leaf(
199 Nibbles::unpack(hashed_address),
200 &account_rlp,
201 is_private,
202 );
203 }
204 }
205 }
206
207 let root = hash_builder.root();
208
209 let removed_keys = account_node_iter.walker.take_removed_keys();
210 trie_updates.finalize(hash_builder, removed_keys, prefix_sets.destroyed_accounts);
211
212 let stats = tracker.finish();
213
214 #[cfg(feature = "metrics")]
215 self.metrics.record_state_trie(stats);
216
217 trace!(
218 target: "trie::parallel_state_root",
219 %root,
220 duration = ?stats.duration(),
221 branches_added = stats.branches_added(),
222 leaves_added = stats.leaves_added(),
223 missed_leaves = stats.missed_leaves(),
224 precomputed_storage_roots = stats.precomputed_storage_roots(),
225 "Calculated state root"
226 );
227
228 Ok((root, trie_updates))
229 }
230}
231
232#[derive(Error, Debug)]
234pub enum ParallelStateRootError {
235 #[error(transparent)]
237 StorageRoot(#[from] StorageRootError),
238 #[error(transparent)]
240 Provider(#[from] ProviderError),
241 #[error("{_0}")]
243 Other(String),
244}
245
246impl From<ParallelStateRootError> for ProviderError {
247 fn from(error: ParallelStateRootError) -> Self {
248 match error {
249 ParallelStateRootError::Provider(error) => error,
250 ParallelStateRootError::StorageRoot(StorageRootError::Database(error)) => {
251 Self::Database(error)
252 }
253 ParallelStateRootError::Other(other) => Self::Database(DatabaseError::Other(other)),
254 }
255 }
256}
257
258impl From<alloy_rlp::Error> for ParallelStateRootError {
259 fn from(error: alloy_rlp::Error) -> Self {
260 Self::Provider(ProviderError::Rlp(error))
261 }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::{keccak256, Address, U256};
    use rand::Rng;
    use reth_primitives_traits::{Account, StorageEntry};
    use reth_provider::{test_utils::create_test_provider_factory, HashingWriter};
    use reth_trie::{test_utils, HashedPostState, HashedStorage};
    use revm_state::FlaggedStorage;

    /// Compares [`ParallelStateRoot`] against the reference `test_utils::state_root`. The
    /// comparison is made first for a freshly committed database, then with a random
    /// in-memory overlay of account and storage updates on top of it.
    #[test]
    fn random_parallel_root() {
        let factory = create_test_provider_factory();
        let consistent_view = ConsistentDbView::new(factory.clone(), None);
        let mut rng = rand::rng();

        // 100 accounts with random balances; with probability 0.7 each also gets 100 slots.
        let mut state = HashMap::new();
        for _ in 0..100 {
            let address = Address::random();
            let account =
                Account { balance: U256::from(rng.random::<u64>()), ..Default::default() };
            let mut storage = HashMap::<B256, U256>::default();
            if rng.random_bool(0.7) {
                for _ in 0..100 {
                    let slot = B256::from(U256::from(rng.random::<u64>()));
                    let value = U256::from(rng.random::<u64>());
                    storage.insert(slot, value);
                }
            }
            state.insert(address, (account, storage));
        }

        // Persist the generated accounts and storage through the hashing writer.
        {
            let provider_rw = factory.provider_rw().unwrap();
            let accounts =
                state.iter().map(|(address, (account, _))| (*address, Some(*account)));
            provider_rw.insert_account_for_hashing(accounts).unwrap();
            let storages = state.iter().map(|(address, (_, slots))| {
                let entries = slots.iter().map(|(slot, value)| StorageEntry {
                    key: *slot,
                    value: *value,
                    is_private: false,
                });
                (*address, entries)
            });
            provider_rw.insert_storage_for_hashing(storages).unwrap();
            provider_rw.commit().unwrap();
        }

        let db_only_root = ParallelStateRoot::new(consistent_view.clone(), Default::default())
            .incremental_root()
            .unwrap();
        assert_eq!(db_only_root, test_utils::state_root(state.clone()));

        // Bump balances with probability 0.5 and rewrite every slot with probability 0.3,
        // mirroring each change into the in-memory hashed post state.
        let mut hashed_state = HashedPostState::default();
        for (address, (account, storage)) in state.iter_mut() {
            let hashed_address = keccak256(address);

            if rng.random_bool(0.5) {
                account.balance = U256::from(rng.random::<u64>());
                hashed_state.accounts.insert(hashed_address, Some(*account));
            }

            if rng.random_bool(0.3) {
                for (slot, value) in storage.iter_mut() {
                    *value = U256::from(rng.random::<u64>());
                    hashed_state
                        .storages
                        .entry(hashed_address)
                        .or_insert_with(HashedStorage::default)
                        .storage
                        .insert(keccak256(slot), FlaggedStorage::from(*value));
                }
            }
        }

        let overlay_root =
            ParallelStateRoot::new(consistent_view, TrieInput::from_state(hashed_state))
                .incremental_root()
                .unwrap();
        assert_eq!(overlay_root, test_utils::state_root(state));
    }
}