reth_engine_tree/tree/
root.rs

1//! State root task related functionality.
2
3use alloy_primitives::map::HashSet;
4use rayon::iter::{ParallelBridge, ParallelIterator};
5use reth_evm::system_calls::OnStateHook;
6use reth_execution_errors::StateProofError;
7use reth_provider::{
8    providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory,
9    StateCommitmentProvider,
10};
11use reth_trie::{
12    proof::Proof, updates::TrieUpdates, HashedPostState, HashedStorage, MultiProof,
13    MultiProofTargets, Nibbles, TrieInput,
14};
15use reth_trie_db::DatabaseProof;
16use reth_trie_parallel::root::ParallelStateRootError;
17use reth_trie_sparse::{
18    blinded::{BlindedProvider, BlindedProviderFactory},
19    errors::{SparseStateTrieError, SparseStateTrieResult, SparseTrieError, SparseTrieErrorKind},
20    SparseStateTrie,
21};
22use revm_primitives::{keccak256, EvmState, B256};
23use std::{
24    collections::BTreeMap,
25    ops::Deref,
26    sync::{
27        mpsc::{self, channel, Receiver, Sender},
28        Arc,
29    },
30    thread::{self},
31    time::{Duration, Instant},
32};
33use tracing::{debug, error, trace};
34
/// The level below which the sparse trie hashes are calculated in [`update_sparse_trie`].
///
/// This value is passed to `calculate_below_level` at the end of every incremental sparse trie
/// update.
const SPARSE_TRIE_INCREMENTAL_LEVEL: usize = 2;
37
/// Result of the state root calculation.
///
/// On success it carries the computed state root together with the trie updates collected while
/// computing it.
pub(crate) type StateRootResult = Result<(B256, TrieUpdates), ParallelStateRootError>;
40
41/// Handle to a spawned state root task.
42#[derive(Debug)]
43#[allow(dead_code)]
44pub struct StateRootHandle {
45    /// Channel for receiving the final result.
46    rx: mpsc::Receiver<StateRootResult>,
47}
48
49#[allow(dead_code)]
50impl StateRootHandle {
51    /// Creates a new handle from a receiver.
52    pub(crate) const fn new(rx: mpsc::Receiver<StateRootResult>) -> Self {
53        Self { rx }
54    }
55
56    /// Waits for the state root calculation to complete.
57    pub fn wait_for_result(self) -> StateRootResult {
58        self.rx.recv().expect("state root task was dropped without sending result")
59    }
60}
61
/// Common configuration for state root tasks.
///
/// Both fields are cloned for every proof job the task dispatches, which is why the trie input
/// is kept behind an [`Arc`].
#[derive(Debug)]
pub struct StateRootConfig<Factory> {
    /// View over the state in the database.
    pub consistent_view: ConsistentDbView<Factory>,
    /// Latest trie input.
    pub input: Arc<TrieInput>,
}
70
/// Messages used internally by the state root task.
///
/// All variants travel over the same channel: state updates and the end-of-stream marker come
/// from the state hook, while proof and root results are sent back by the jobs the task spawns.
#[derive(Debug)]
#[allow(dead_code)]
pub enum StateRootMessage<BPF: BlindedProviderFactory> {
    /// New state update from transaction execution
    StateUpdate(EvmState),
    /// Proof calculation completed for a specific state update
    ProofCalculated(Box<ProofCalculated>),
    /// Error during proof calculation
    ProofCalculationError(StateProofError),
    /// State root calculation completed
    RootCalculated {
        /// The updated sparse trie
        trie: Box<SparseStateTrie<BPF>>,
        /// Time taken to calculate the root
        elapsed: Duration,
    },
    /// Error during state root calculation
    RootCalculationError(SparseStateTrieError),
    /// Signals state update stream end.
    ///
    /// Sent when the `StateHookSender` backing the state hook is dropped.
    FinishedStateUpdates,
}
93
/// Message about completion of proof calculation for a specific state update.
///
/// Carries everything needed to later apply the update to the sparse trie, plus the sequence
/// number used to restore the original ordering of state updates.
#[derive(Debug)]
pub struct ProofCalculated {
    /// The state update that was used to calculate the proof
    state_update: HashedPostState,
    /// The proof targets
    targets: MultiProofTargets,
    /// The calculated proof
    proof: MultiProof,
    /// The index of this proof in the sequence of state updates
    sequence_number: u64,
}
106
107/// Handle to track proof calculation ordering
108#[derive(Debug, Default)]
109pub(crate) struct ProofSequencer {
110    /// The next proof sequence number to be produced.
111    next_sequence: u64,
112    /// The next sequence number expected to be delivered.
113    next_to_deliver: u64,
114    /// Buffer for out-of-order proofs and corresponding state updates
115    pending_proofs: BTreeMap<u64, (HashedPostState, MultiProofTargets, MultiProof)>,
116}
117
118impl ProofSequencer {
119    /// Creates a new proof sequencer
120    pub(crate) fn new() -> Self {
121        Self::default()
122    }
123
124    /// Gets the next sequence number and increments the counter
125    pub(crate) fn next_sequence(&mut self) -> u64 {
126        let seq = self.next_sequence;
127        self.next_sequence += 1;
128        seq
129    }
130
131    /// Adds a proof with the corresponding state update and returns all sequential proofs and state
132    /// updates if we have a continuous sequence
133    pub(crate) fn add_proof(
134        &mut self,
135        sequence: u64,
136        state_update: HashedPostState,
137        targets: MultiProofTargets,
138        proof: MultiProof,
139    ) -> Vec<(HashedPostState, MultiProofTargets, MultiProof)> {
140        if sequence >= self.next_to_deliver {
141            self.pending_proofs.insert(sequence, (state_update, targets, proof));
142        }
143
144        // return early if we don't have the next expected proof
145        if !self.pending_proofs.contains_key(&self.next_to_deliver) {
146            return Vec::new()
147        }
148
149        let mut consecutive_proofs = Vec::with_capacity(self.pending_proofs.len());
150        let mut current_sequence = self.next_to_deliver;
151
152        // keep collecting proofs and state updates as long as we have consecutive sequence numbers
153        while let Some(pending) = self.pending_proofs.remove(&current_sequence) {
154            consecutive_proofs.push(pending);
155            current_sequence += 1;
156
157            // if we don't have the next number, stop collecting
158            if !self.pending_proofs.contains_key(&current_sequence) {
159                break;
160            }
161        }
162
163        self.next_to_deliver += consecutive_proofs.len() as u64;
164
165        consecutive_proofs
166    }
167
168    /// Returns true if we still have pending proofs
169    pub(crate) fn has_pending(&self) -> bool {
170        !self.pending_proofs.is_empty()
171    }
172}
173
174/// A wrapper for the sender that signals completion when dropped
175#[allow(dead_code)]
176pub(crate) struct StateHookSender<BPF: BlindedProviderFactory>(Sender<StateRootMessage<BPF>>);
177
178#[allow(dead_code)]
179impl<BPF: BlindedProviderFactory> StateHookSender<BPF> {
180    pub(crate) const fn new(inner: Sender<StateRootMessage<BPF>>) -> Self {
181        Self(inner)
182    }
183}
184
185impl<BPF: BlindedProviderFactory> Deref for StateHookSender<BPF> {
186    type Target = Sender<StateRootMessage<BPF>>;
187
188    fn deref(&self) -> &Self::Target {
189        &self.0
190    }
191}
192
193impl<BPF: BlindedProviderFactory> Drop for StateHookSender<BPF> {
194    fn drop(&mut self) {
195        // Send completion signal when the sender is dropped
196        let _ = self.0.send(StateRootMessage::FinishedStateUpdates);
197    }
198}
199
200fn evm_state_to_hashed_post_state(update: EvmState) -> HashedPostState {
201    let mut hashed_state = HashedPostState::default();
202
203    for (address, account) in update {
204        if account.is_touched() {
205            let hashed_address = keccak256(address);
206            trace!(target: "engine::root", ?address, ?hashed_address, "Adding account to state update");
207
208            let destroyed = account.is_selfdestructed();
209            let info = if destroyed { None } else { Some(account.info.into()) };
210            hashed_state.accounts.insert(hashed_address, info);
211
212            let mut changed_storage_iter = account
213                .storage
214                .into_iter()
215                .filter_map(|(slot, value)| {
216                    value.is_changed().then(|| (keccak256(B256::from(slot)), value.present_value))
217                })
218                .peekable();
219
220            if destroyed || changed_storage_iter.peek().is_some() {
221                hashed_state.storages.insert(
222                    hashed_address,
223                    HashedStorage::from_iter(destroyed, changed_storage_iter),
224                );
225            }
226        }
227    }
228
229    hashed_state
230}
231
/// Standalone task that receives a transaction state stream and updates relevant
/// data structures to calculate state root.
///
/// It is responsible for initializing a blinded sparse trie and subscribing to the
/// transaction state stream. As it receives transaction execution results, it
/// fetches the proofs for relevant accounts from the database and reveals them
/// to the trie.
/// Then it updates relevant leaves according to the result of the transaction.
#[derive(Debug)]
pub struct StateRootTask<Factory, BPF: BlindedProviderFactory> {
    /// Task configuration.
    config: StateRootConfig<Factory>,
    /// Receiver for state root related messages.
    rx: Receiver<StateRootMessage<BPF>>,
    /// Sender for state root related messages.
    ///
    /// Cloned into the state hook and into every spawned proof / root calculation job.
    tx: Sender<StateRootMessage<BPF>>,
    /// Proof targets that have been already fetched.
    fetched_proof_targets: MultiProofTargets,
    /// Proof sequencing handler.
    proof_sequencer: ProofSequencer,
    /// The sparse trie used for the state root calculation. If [`None`], then update is in
    /// progress.
    sparse_trie: Option<Box<SparseStateTrie<BPF>>>,
}
256
257#[allow(dead_code)]
258impl<'env, Factory, ABP, SBP, BPF> StateRootTask<Factory, BPF>
259where
260    Factory: DatabaseProviderFactory<Provider: BlockReader>
261        + StateCommitmentProvider
262        + Clone
263        + Send
264        + Sync
265        + 'static,
266    ABP: BlindedProvider<Error = SparseTrieError> + Send + Sync + 'env,
267    SBP: BlindedProvider<Error = SparseTrieError> + Send + Sync + 'env,
268    BPF: BlindedProviderFactory<AccountNodeProvider = ABP, StorageNodeProvider = SBP>
269        + Send
270        + Sync
271        + 'env,
272{
273    /// Creates a new state root task with the unified message channel
274    pub fn new(config: StateRootConfig<Factory>, blinded_provider: BPF) -> Self {
275        let (tx, rx) = channel();
276
277        Self {
278            config,
279            rx,
280            tx,
281            fetched_proof_targets: Default::default(),
282            proof_sequencer: ProofSequencer::new(),
283            sparse_trie: Some(Box::new(SparseStateTrie::new(blinded_provider).with_updates(true))),
284        }
285    }
286
287    /// Spawns the state root task and returns a handle to await its result.
288    pub fn spawn<'scope>(self, scope: &'scope thread::Scope<'scope, 'env>) -> StateRootHandle {
289        let (tx, rx) = mpsc::sync_channel(1);
290        std::thread::Builder::new()
291            .name("State Root Task".to_string())
292            .spawn_scoped(scope, move || {
293                debug!(target: "engine::tree", "Starting state root task");
294
295                let result = rayon::scope(|scope| self.run(scope));
296                let _ = tx.send(result);
297            })
298            .expect("failed to spawn state root thread");
299
300        StateRootHandle::new(rx)
301    }
302
303    /// Returns a state hook to be used to send state updates to this task.
304    pub fn state_hook(&self) -> impl OnStateHook {
305        let state_hook = StateHookSender::new(self.tx.clone());
306
307        move |state: &EvmState| {
308            if let Err(error) = state_hook.send(StateRootMessage::StateUpdate(state.clone())) {
309                error!(target: "engine::root", ?error, "Failed to send state update");
310            }
311        }
312    }
313
314    /// Handles state updates.
315    ///
316    /// Returns proof targets derived from the state update.
317    fn on_state_update(
318        scope: &rayon::Scope<'env>,
319        view: ConsistentDbView<Factory>,
320        input: Arc<TrieInput>,
321        update: EvmState,
322        fetched_proof_targets: &mut MultiProofTargets,
323        proof_sequence_number: u64,
324        state_root_message_sender: Sender<StateRootMessage<BPF>>,
325    ) {
326        let hashed_state_update = evm_state_to_hashed_post_state(update);
327
328        let proof_targets = get_proof_targets(&hashed_state_update, fetched_proof_targets);
329        fetched_proof_targets.extend_ref(&proof_targets);
330
331        // Dispatch proof gathering for this state update
332        scope.spawn(move |_| {
333            let provider = match view.provider_ro() {
334                Ok(provider) => provider,
335                Err(error) => {
336                    error!(target: "engine::root", ?error, "Could not get provider");
337                    return;
338                }
339            };
340
341            // TODO: replace with parallel proof
342            let result = Proof::overlay_multiproof(
343                provider.tx_ref(),
344                // TODO(alexey): this clone can be expensive, we should avoid it
345                input.as_ref().clone(),
346                proof_targets.clone(),
347            );
348            match result {
349                Ok(proof) => {
350                    let _ = state_root_message_sender.send(StateRootMessage::ProofCalculated(
351                        Box::new(ProofCalculated {
352                            state_update: hashed_state_update,
353                            targets: proof_targets,
354                            proof,
355                            sequence_number: proof_sequence_number,
356                        }),
357                    ));
358                }
359                Err(e) => {
360                    let _ =
361                        state_root_message_sender.send(StateRootMessage::ProofCalculationError(e));
362                }
363            }
364        });
365    }
366
367    /// Handler for new proof calculated, aggregates all the existing sequential proofs.
368    fn on_proof(
369        &mut self,
370        sequence_number: u64,
371        state_update: HashedPostState,
372        targets: MultiProofTargets,
373        proof: MultiProof,
374    ) -> Option<(HashedPostState, MultiProofTargets, MultiProof)> {
375        let ready_proofs =
376            self.proof_sequencer.add_proof(sequence_number, state_update, targets, proof);
377
378        if ready_proofs.is_empty() {
379            None
380        } else {
381            // Merge all ready proofs and state updates
382            ready_proofs.into_iter().reduce(|mut acc, (state_update, targets, proof)| {
383                acc.0.extend(state_update);
384                acc.1.extend(targets);
385                acc.2.extend(proof);
386                acc
387            })
388        }
389    }
390
391    /// Spawns root calculation with the current state and proofs.
392    fn spawn_root_calculation(
393        &mut self,
394        scope: &rayon::Scope<'env>,
395        state: HashedPostState,
396        targets: MultiProofTargets,
397        multiproof: MultiProof,
398    ) {
399        let Some(trie) = self.sparse_trie.take() else { return };
400
401        trace!(
402            target: "engine::root",
403            account_proofs = multiproof.account_subtree.len(),
404            storage_proofs = multiproof.storages.len(),
405            "Spawning root calculation"
406        );
407
408        // TODO(alexey): store proof targets in `ProofSequecner` to avoid recomputing them
409        let targets = get_proof_targets(&state, &targets);
410
411        let tx = self.tx.clone();
412        scope.spawn(move |_| {
413            let result = update_sparse_trie(trie, multiproof, targets, state);
414            match result {
415                Ok((trie, elapsed)) => {
416                    trace!(
417                        target: "engine::root",
418                        ?elapsed,
419                        "Root calculation completed, sending result"
420                    );
421                    let _ = tx.send(StateRootMessage::RootCalculated { trie, elapsed });
422                }
423                Err(e) => {
424                    let _ = tx.send(StateRootMessage::RootCalculationError(e));
425                }
426            }
427        });
428    }
429
430    fn run(mut self, scope: &rayon::Scope<'env>) -> StateRootResult {
431        let mut current_state_update = HashedPostState::default();
432        let mut current_proof_targets = MultiProofTargets::default();
433        let mut current_multiproof = MultiProof::default();
434        let mut updates_received = 0;
435        let mut proofs_processed = 0;
436        let mut roots_calculated = 0;
437        let mut updates_finished = false;
438
439        loop {
440            match self.rx.recv() {
441                Ok(message) => match message {
442                    StateRootMessage::StateUpdate(update) => {
443                        updates_received += 1;
444                        trace!(
445                            target: "engine::root",
446                            len = update.len(),
447                            total_updates = updates_received,
448                            "Received new state update"
449                        );
450                        Self::on_state_update(
451                            scope,
452                            self.config.consistent_view.clone(),
453                            self.config.input.clone(),
454                            update,
455                            &mut self.fetched_proof_targets,
456                            self.proof_sequencer.next_sequence(),
457                            self.tx.clone(),
458                        );
459                    }
460                    StateRootMessage::FinishedStateUpdates => {
461                        updates_finished = true;
462                    }
463                    StateRootMessage::ProofCalculated(proof_calculated) => {
464                        proofs_processed += 1;
465                        trace!(
466                            target: "engine::root",
467                            sequence = proof_calculated.sequence_number,
468                            total_proofs = proofs_processed,
469                            "Processing calculated proof"
470                        );
471
472                        trace!(target: "engine::root", proof = ?proof_calculated.proof, "Proof calculated");
473
474                        if let Some((
475                            combined_state_update,
476                            combined_proof_targets,
477                            combined_proof,
478                        )) = self.on_proof(
479                            proof_calculated.sequence_number,
480                            proof_calculated.state_update,
481                            proof_calculated.targets,
482                            proof_calculated.proof,
483                        ) {
484                            if self.sparse_trie.is_none() {
485                                current_state_update.extend(combined_state_update);
486                                current_proof_targets.extend(combined_proof_targets);
487                                current_multiproof.extend(combined_proof);
488                            } else {
489                                self.spawn_root_calculation(
490                                    scope,
491                                    combined_state_update,
492                                    combined_proof_targets,
493                                    combined_proof,
494                                );
495                            }
496                        }
497                    }
498                    StateRootMessage::RootCalculated { trie, elapsed } => {
499                        roots_calculated += 1;
500                        trace!(
501                            target: "engine::root",
502                            ?elapsed,
503                            roots_calculated,
504                            proofs = proofs_processed,
505                            updates = updates_received,
506                            "Computed intermediate root"
507                        );
508                        self.sparse_trie = Some(trie);
509
510                        let has_new_proofs = !current_multiproof.account_subtree.is_empty() ||
511                            !current_multiproof.storages.is_empty();
512                        let all_proofs_received = proofs_processed >= updates_received;
513                        let no_pending = !self.proof_sequencer.has_pending();
514
515                        trace!(
516                            target: "engine::root",
517                            has_new_proofs,
518                            all_proofs_received,
519                            no_pending,
520                            "State check"
521                        );
522
523                        // only spawn new calculation if we have accumulated new proofs
524                        if has_new_proofs {
525                            trace!(
526                                target: "engine::root",
527                                account_proofs = current_multiproof.account_subtree.len(),
528                                storage_proofs = current_multiproof.storages.len(),
529                                "Spawning subsequent root calculation"
530                            );
531                            self.spawn_root_calculation(
532                                scope,
533                                std::mem::take(&mut current_state_update),
534                                std::mem::take(&mut current_proof_targets),
535                                std::mem::take(&mut current_multiproof),
536                            );
537                        } else if all_proofs_received && no_pending && updates_finished {
538                            debug!(
539                                target: "engine::root",
540                                total_updates = updates_received,
541                                total_proofs = proofs_processed,
542                                roots_calculated,
543                                "All proofs processed, ending calculation"
544                            );
545                            let mut trie = self
546                                .sparse_trie
547                                .take()
548                                .expect("sparse trie update should not be in progress");
549                            let root = trie.root().expect("sparse trie should be revealed");
550                            let trie_updates = trie
551                                .take_trie_updates()
552                                .expect("sparse trie should have updates retention enabled");
553                            return Ok((root, trie_updates));
554                        }
555                    }
556                    StateRootMessage::ProofCalculationError(e) => {
557                        return Err(ParallelStateRootError::Other(format!(
558                            "could not calculate multiproof: {e:?}"
559                        )))
560                    }
561                    StateRootMessage::RootCalculationError(e) => {
562                        return Err(ParallelStateRootError::Other(format!(
563                            "could not calculate state root: {e:?}"
564                        )))
565                    }
566                },
567                Err(_) => {
568                    // this means our internal message channel is closed, which shouldn't happen
569                    // in normal operation since we hold both ends
570                    error!(
571                        target: "engine::root",
572                        "Internal message channel closed unexpectedly"
573                    );
574                    return Err(ParallelStateRootError::Other(
575                        "Internal message channel closed unexpectedly".into(),
576                    ));
577                }
578            }
579        }
580    }
581}
582
583/// Returns accounts only with those storages that were not already fetched, and
584/// if there are no such storages and the account itself was already fetched, the
585/// account shouldn't be included.
586fn get_proof_targets(
587    state_update: &HashedPostState,
588    fetched_proof_targets: &MultiProofTargets,
589) -> MultiProofTargets {
590    let mut targets = MultiProofTargets::default();
591
592    // first collect all new accounts (not previously fetched)
593    for &hashed_address in state_update.accounts.keys() {
594        if !fetched_proof_targets.contains_key(&hashed_address) {
595            targets.insert(hashed_address, HashSet::default());
596        }
597    }
598
599    // then process storage slots for all accounts in the state update
600    for (hashed_address, storage) in &state_update.storages {
601        let fetched = fetched_proof_targets.get(hashed_address);
602        let mut changed_slots = storage
603            .storage
604            .keys()
605            .filter(|slot| !fetched.is_some_and(|f| f.contains(*slot)))
606            .peekable();
607
608        if changed_slots.peek().is_some() {
609            targets.entry(*hashed_address).or_default().extend(changed_slots);
610        }
611    }
612
613    targets
614}
615
/// Updates the sparse trie with the given proofs and state, and returns the updated trie and the
/// time it took.
///
/// The update runs in four steps: the multiproof is revealed, storage tries are updated in
/// parallel and put back, accounts are updated, and finally hashes below
/// [`SPARSE_TRIE_INCREMENTAL_LEVEL`] are calculated.
///
/// # Errors
///
/// Returns an error if revealing the proof fails, if a storage trie for an updated account is
/// missing ([`SparseTrieErrorKind::Blind`]), or if any leaf / account update fails.
fn update_sparse_trie<
    ABP: BlindedProvider<Error = SparseTrieError> + Send + Sync,
    SBP: BlindedProvider<Error = SparseTrieError> + Send + Sync,
    BPF: BlindedProviderFactory<AccountNodeProvider = ABP, StorageNodeProvider = SBP> + Send + Sync,
>(
    mut trie: Box<SparseStateTrie<BPF>>,
    multiproof: MultiProof,
    targets: MultiProofTargets,
    state: HashedPostState,
) -> SparseStateTrieResult<(Box<SparseStateTrie<BPF>>, Duration)> {
    trace!(target: "engine::root::sparse", "Updating sparse trie");
    let started_at = Instant::now();

    // Reveal new accounts and storage slots.
    trie.reveal_multiproof(targets, multiproof)?;

    // Update storage slots with new values and calculate storage roots.
    // Results travel over a channel so that each storage trie can be handed back to `trie`
    // once the parallel section no longer borrows it.
    let (tx, rx) = mpsc::channel();
    state
        .storages
        .into_iter()
        // Storage tries are taken out of the state trie in the serial iterator that feeds
        // `par_bridge`, so the parallel closure below only works on owned values.
        .map(|(address, storage)| (address, storage, trie.take_storage_trie(&address)))
        .par_bridge()
        .map(|(address, storage, storage_trie)| {
            trace!(target: "engine::root::sparse", ?address, "Updating storage");
            // No storage trie for this account is reported as a blind trie error.
            let mut storage_trie = storage_trie.ok_or(SparseTrieErrorKind::Blind)?;

            if storage.wiped {
                trace!(target: "engine::root::sparse", ?address, "Wiping storage");
                storage_trie.wipe()?;
            }
            for (slot, value) in storage.storage {
                let slot_nibbles = Nibbles::unpack(slot);
                // A zero value removes the slot's leaf; any other value is stored RLP-encoded.
                if value.is_zero() {
                    trace!(target: "engine::root::sparse", ?address, ?slot, "Removing storage slot");
                    storage_trie.remove_leaf(&slot_nibbles)?;
                } else {
                    trace!(target: "engine::root::sparse", ?address, ?slot, "Updating storage slot");
                    storage_trie
                        .update_leaf(slot_nibbles, alloy_rlp::encode_fixed_size(&value.value).to_vec())?;
                }
            }

            // The returned root is discarded; presumably this is called so the storage root is
            // already calculated inside the parallel section — TODO confirm.
            storage_trie.root();

            SparseStateTrieResult::Ok((address, storage_trie))
        })
        .for_each_init(|| tx.clone(), |tx, result| {
            tx.send(result).unwrap()
        });
    // Drop the original sender so that iterating `rx` below ends once all results are read.
    drop(tx);
    for result in rx {
        let (address, storage_trie) = result?;
        trie.insert_storage_trie(address, storage_trie);
    }

    // Update accounts with new values
    for (address, account) in state.accounts {
        trace!(target: "engine::root::sparse", ?address, "Updating account");
        // A `None` account is passed on as the default account value.
        trie.update_account(address, account.unwrap_or_default())?;
    }

    trie.calculate_below_level(SPARSE_TRIE_INCREMENTAL_LEVEL);
    let elapsed = started_at.elapsed();

    Ok((trie, elapsed))
}
685
686#[cfg(test)]
687mod tests {
688    use super::*;
689    use reth_primitives::{Account as RethAccount, StorageEntry};
690    use reth_provider::{
691        providers::ConsistentDbView, test_utils::create_test_provider_factory, HashingWriter,
692    };
693    use reth_testing_utils::generators::{self, Rng};
694    use reth_trie::{
695        hashed_cursor::HashedPostStateCursorFactory, proof::ProofBlindedProviderFactory,
696        test_utils::state_root, trie_cursor::InMemoryTrieCursorFactory, TrieInput,
697    };
698    use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
699    use revm_primitives::{
700        Account as RevmAccount, AccountInfo, AccountStatus, Address, EvmState, EvmStorageSlot,
701        FlaggedStorage, HashMap, B256, KECCAK_EMPTY, U256,
702    };
703    use std::sync::Arc;
704
705    fn convert_revm_to_reth_account(revm_account: &RevmAccount) -> RethAccount {
706        RethAccount {
707            balance: revm_account.info.balance,
708            nonce: revm_account.info.nonce,
709            bytecode_hash: if revm_account.info.code_hash == KECCAK_EMPTY {
710                None
711            } else {
712                Some(revm_account.info.code_hash)
713            },
714        }
715    }
716
717    fn create_mock_state_updates(num_accounts: usize, updates_per_account: usize) -> Vec<EvmState> {
718        let mut rng = generators::rng();
719        let all_addresses: Vec<Address> = (0..num_accounts).map(|_| rng.gen()).collect();
720        let mut updates = Vec::new();
721
722        for _ in 0..updates_per_account {
723            let num_accounts_in_update = rng.gen_range(1..=num_accounts);
724            let mut state_update = EvmState::default();
725
726            let selected_addresses = &all_addresses[0..num_accounts_in_update];
727
728            for &address in selected_addresses {
729                let mut storage = HashMap::default();
730                if rng.gen_bool(0.7) {
731                    for _ in 0..rng.gen_range(1..10) {
732                        let slot = U256::from(rng.gen::<u64>());
733                        storage.insert(
734                            slot,
735                            EvmStorageSlot::new_changed(
736                                FlaggedStorage::ZERO,
737                                FlaggedStorage::new_from_value(rng.gen::<u64>()),
738                            ),
739                        );
740                    }
741                }
742
743                let account = RevmAccount {
744                    info: AccountInfo {
745                        balance: U256::from(rng.gen::<u64>()),
746                        nonce: rng.gen::<u64>(),
747                        code_hash: KECCAK_EMPTY,
748                        code: Some(Default::default()),
749                    },
750                    storage,
751                    status: AccountStatus::Touched,
752                };
753
754                state_update.insert(address, account);
755            }
756
757            updates.push(state_update);
758        }
759
760        updates
761    }
762
763    #[test]
764    fn test_state_root_task() {
765        reth_tracing::init_test_tracing();
766
767        let factory = create_test_provider_factory();
768
769        let state_updates = create_mock_state_updates(10, 10);
770        let mut hashed_state = HashedPostState::default();
771        let mut accumulated_state: HashMap<Address, (RethAccount, HashMap<B256, U256>)> =
772            HashMap::default();
773
774        {
775            let provider_rw = factory.provider_rw().expect("failed to get provider");
776
777            for update in &state_updates {
778                let account_updates = update.iter().map(|(address, account)| {
779                    (*address, Some(convert_revm_to_reth_account(account)))
780                });
781                provider_rw
782                    .insert_account_for_hashing(account_updates)
783                    .expect("failed to insert accounts");
784
785                let storage_updates = update.iter().map(|(address, account)| {
786                    let storage_entries =
787                        account.storage.iter().map(|(slot, value)| StorageEntry {
788                            key: B256::from(*slot),
789                            value: value.present_value.value,
790                            is_private: value.present_value.is_private,
791                        });
792                    (*address, storage_entries)
793                });
794                provider_rw
795                    .insert_storage_for_hashing(storage_updates)
796                    .expect("failed to insert storage");
797            }
798            provider_rw.commit().expect("failed to commit changes");
799        }
800
801        for update in &state_updates {
802            hashed_state.extend(evm_state_to_hashed_post_state(update.clone()));
803
804            for (address, account) in update {
805                let storage: HashMap<B256, U256> = account
806                    .storage
807                    .iter()
808                    .map(|(k, v)| (B256::from(*k), v.present_value.value))
809                    .collect();
810
811                let entry = accumulated_state.entry(*address).or_default();
812                entry.0 = convert_revm_to_reth_account(account);
813                entry.1.extend(storage);
814            }
815        }
816
817        let config = StateRootConfig {
818            consistent_view: ConsistentDbView::new(factory, None),
819            input: Arc::new(TrieInput::from_state(hashed_state)),
820        };
821        let provider = config.consistent_view.provider_ro().unwrap();
822        let nodes_sorted = config.input.nodes.clone().into_sorted();
823        let state_sorted = config.input.state.clone().into_sorted();
824        let blinded_provider_factory = ProofBlindedProviderFactory::new(
825            InMemoryTrieCursorFactory::new(
826                DatabaseTrieCursorFactory::new(provider.tx_ref()),
827                &nodes_sorted,
828            ),
829            HashedPostStateCursorFactory::new(
830                DatabaseHashedCursorFactory::new(provider.tx_ref()),
831                &state_sorted,
832            ),
833            Arc::new(config.input.prefix_sets.clone()),
834        );
835        let (root_from_task, _) = std::thread::scope(|std_scope| {
836            let task = StateRootTask::new(config, blinded_provider_factory);
837            let mut state_hook = task.state_hook();
838            let handle = task.spawn(std_scope);
839
840            for update in state_updates {
841                state_hook.on_state(&update);
842            }
843            drop(state_hook);
844
845            handle.wait_for_result().expect("task failed")
846        });
847        let root_from_base = state_root(accumulated_state);
848
849        assert_eq!(
850            root_from_task, root_from_base,
851            "State root mismatch: task={root_from_task:?}, base={root_from_base:?}"
852        );
853    }
854
855    #[test]
856    fn test_add_proof_in_sequence() {
857        let mut sequencer = ProofSequencer::new();
858        let proof1 = MultiProof::default();
859        let proof2 = MultiProof::default();
860        sequencer.next_sequence = 2;
861
862        let ready = sequencer.add_proof(
863            0,
864            HashedPostState::default(),
865            MultiProofTargets::default(),
866            proof1,
867        );
868        assert_eq!(ready.len(), 1);
869        assert!(!sequencer.has_pending());
870
871        let ready = sequencer.add_proof(
872            1,
873            HashedPostState::default(),
874            MultiProofTargets::default(),
875            proof2,
876        );
877        assert_eq!(ready.len(), 1);
878        assert!(!sequencer.has_pending());
879    }
880
881    #[test]
882    fn test_add_proof_out_of_order() {
883        let mut sequencer = ProofSequencer::new();
884        let proof1 = MultiProof::default();
885        let proof2 = MultiProof::default();
886        let proof3 = MultiProof::default();
887        sequencer.next_sequence = 3;
888
889        let ready = sequencer.add_proof(
890            2,
891            HashedPostState::default(),
892            MultiProofTargets::default(),
893            proof3,
894        );
895        assert_eq!(ready.len(), 0);
896        assert!(sequencer.has_pending());
897
898        let ready = sequencer.add_proof(
899            0,
900            HashedPostState::default(),
901            MultiProofTargets::default(),
902            proof1,
903        );
904        assert_eq!(ready.len(), 1);
905        assert!(sequencer.has_pending());
906
907        let ready = sequencer.add_proof(
908            1,
909            HashedPostState::default(),
910            MultiProofTargets::default(),
911            proof2,
912        );
913        assert_eq!(ready.len(), 2);
914        assert!(!sequencer.has_pending());
915    }
916
917    #[test]
918    fn test_add_proof_with_gaps() {
919        let mut sequencer = ProofSequencer::new();
920        let proof1 = MultiProof::default();
921        let proof3 = MultiProof::default();
922        sequencer.next_sequence = 3;
923
924        let ready = sequencer.add_proof(
925            0,
926            HashedPostState::default(),
927            MultiProofTargets::default(),
928            proof1,
929        );
930        assert_eq!(ready.len(), 1);
931
932        let ready = sequencer.add_proof(
933            2,
934            HashedPostState::default(),
935            MultiProofTargets::default(),
936            proof3,
937        );
938        assert_eq!(ready.len(), 0);
939        assert!(sequencer.has_pending());
940    }
941
942    #[test]
943    fn test_add_proof_duplicate_sequence() {
944        let mut sequencer = ProofSequencer::new();
945        let proof1 = MultiProof::default();
946        let proof2 = MultiProof::default();
947
948        let ready = sequencer.add_proof(
949            0,
950            HashedPostState::default(),
951            MultiProofTargets::default(),
952            proof1,
953        );
954        assert_eq!(ready.len(), 1);
955
956        let ready = sequencer.add_proof(
957            0,
958            HashedPostState::default(),
959            MultiProofTargets::default(),
960            proof2,
961        );
962        assert_eq!(ready.len(), 0);
963        assert!(!sequencer.has_pending());
964    }
965
966    #[test]
967    fn test_add_proof_batch_processing() {
968        let mut sequencer = ProofSequencer::new();
969        let proofs: Vec<_> = (0..5).map(|_| MultiProof::default()).collect();
970        sequencer.next_sequence = 5;
971
972        sequencer.add_proof(
973            4,
974            HashedPostState::default(),
975            MultiProofTargets::default(),
976            proofs[4].clone(),
977        );
978        sequencer.add_proof(
979            2,
980            HashedPostState::default(),
981            MultiProofTargets::default(),
982            proofs[2].clone(),
983        );
984        sequencer.add_proof(
985            1,
986            HashedPostState::default(),
987            MultiProofTargets::default(),
988            proofs[1].clone(),
989        );
990        sequencer.add_proof(
991            3,
992            HashedPostState::default(),
993            MultiProofTargets::default(),
994            proofs[3].clone(),
995        );
996
997        let ready = sequencer.add_proof(
998            0,
999            HashedPostState::default(),
1000            MultiProofTargets::default(),
1001            proofs[0].clone(),
1002        );
1003        assert_eq!(ready.len(), 5);
1004        assert!(!sequencer.has_pending());
1005    }
1006
1007    fn create_get_proof_targets_state() -> HashedPostState {
1008        let mut state = HashedPostState::default();
1009
1010        let addr1 = B256::random();
1011        let addr2 = B256::random();
1012        state.accounts.insert(addr1, Some(Default::default()));
1013        state.accounts.insert(addr2, Some(Default::default()));
1014
1015        let mut storage = HashedStorage::default();
1016        let slot1 = B256::random();
1017        let slot2 = B256::random();
1018        storage.storage.insert(slot1, FlaggedStorage::ZERO);
1019        storage.storage.insert(slot2, FlaggedStorage::new_from_value(1));
1020        state.storages.insert(addr1, storage);
1021
1022        state
1023    }
1024
1025    #[test]
1026    fn test_get_proof_targets_new_account_targets() {
1027        let state = create_get_proof_targets_state();
1028        let fetched = MultiProofTargets::default();
1029
1030        let targets = get_proof_targets(&state, &fetched);
1031
1032        // should return all accounts as targets since nothing was fetched before
1033        assert_eq!(targets.len(), state.accounts.len());
1034        for addr in state.accounts.keys() {
1035            assert!(targets.contains_key(addr));
1036        }
1037    }
1038
1039    #[test]
1040    fn test_get_proof_targets_new_storage_targets() {
1041        let state = create_get_proof_targets_state();
1042        let fetched = MultiProofTargets::default();
1043
1044        let targets = get_proof_targets(&state, &fetched);
1045
1046        // verify storage slots are included for accounts with storage
1047        for (addr, storage) in &state.storages {
1048            assert!(targets.contains_key(addr));
1049            let target_slots = &targets[addr];
1050            assert_eq!(target_slots.len(), storage.storage.len());
1051            for slot in storage.storage.keys() {
1052                assert!(target_slots.contains(slot));
1053            }
1054        }
1055    }
1056
1057    #[test]
1058    fn test_get_proof_targets_filter_already_fetched_accounts() {
1059        let state = create_get_proof_targets_state();
1060        let mut fetched = MultiProofTargets::default();
1061
1062        // select an account that has no storage updates
1063        let fetched_addr = state
1064            .accounts
1065            .keys()
1066            .find(|&&addr| !state.storages.contains_key(&addr))
1067            .expect("Should have an account without storage");
1068
1069        // mark the account as already fetched
1070        fetched.insert(*fetched_addr, HashSet::default());
1071
1072        let targets = get_proof_targets(&state, &fetched);
1073
1074        // should not include the already fetched account since it has no storage updates
1075        assert!(!targets.contains_key(fetched_addr));
1076        // other accounts should still be included
1077        assert_eq!(targets.len(), state.accounts.len() - 1);
1078    }
1079
1080    #[test]
1081    fn test_get_proof_targets_filter_already_fetched_storage() {
1082        let state = create_get_proof_targets_state();
1083        let mut fetched = MultiProofTargets::default();
1084
1085        // mark one storage slot as already fetched
1086        let (addr, storage) = state.storages.iter().next().unwrap();
1087        let mut fetched_slots = HashSet::default();
1088        let fetched_slot = *storage.storage.keys().next().unwrap();
1089        fetched_slots.insert(fetched_slot);
1090        fetched.insert(*addr, fetched_slots);
1091
1092        let targets = get_proof_targets(&state, &fetched);
1093
1094        // should not include the already fetched storage slot
1095        let target_slots = &targets[addr];
1096        assert!(!target_slots.contains(&fetched_slot));
1097        assert_eq!(target_slots.len(), storage.storage.len() - 1);
1098    }
1099
1100    #[test]
1101    fn test_get_proof_targets_empty_state() {
1102        let state = HashedPostState::default();
1103        let fetched = MultiProofTargets::default();
1104
1105        let targets = get_proof_targets(&state, &fetched);
1106
1107        assert!(targets.is_empty());
1108    }
1109
1110    #[test]
1111    fn test_get_proof_targets_mixed_fetched_state() {
1112        let mut state = HashedPostState::default();
1113        let mut fetched = MultiProofTargets::default();
1114
1115        let addr1 = B256::random();
1116        let addr2 = B256::random();
1117        let slot1 = B256::random();
1118        let slot2 = B256::random();
1119
1120        state.accounts.insert(addr1, Some(Default::default()));
1121        state.accounts.insert(addr2, Some(Default::default()));
1122
1123        let mut storage = HashedStorage::default();
1124        storage.storage.insert(slot1, FlaggedStorage::ZERO);
1125        storage.storage.insert(slot2, FlaggedStorage::new_from_value(1));
1126        state.storages.insert(addr1, storage);
1127
1128        let mut fetched_slots = HashSet::default();
1129        fetched_slots.insert(slot1);
1130        fetched.insert(addr1, fetched_slots);
1131
1132        let targets = get_proof_targets(&state, &fetched);
1133
1134        assert!(targets.contains_key(&addr2));
1135        assert!(!targets[&addr1].contains(&slot1));
1136        assert!(targets[&addr1].contains(&slot2));
1137    }
1138
1139    #[test]
1140    fn test_get_proof_targets_unmodified_account_with_storage() {
1141        let mut state = HashedPostState::default();
1142        let fetched = MultiProofTargets::default();
1143
1144        let addr = B256::random();
1145        let slot1 = B256::random();
1146        let slot2 = B256::random();
1147
1148        // don't add the account to state.accounts (simulating unmodified account)
1149        // but add storage updates for this account
1150        let mut storage = HashedStorage::default();
1151        storage.storage.insert(slot1, FlaggedStorage::new_from_value(1));
1152        storage.storage.insert(slot2, FlaggedStorage::new_from_value(2));
1153        state.storages.insert(addr, storage);
1154
1155        assert!(!state.accounts.contains_key(&addr));
1156        assert!(!fetched.contains_key(&addr));
1157
1158        let targets = get_proof_targets(&state, &fetched);
1159
1160        // verify that we still get the storage slots for the unmodified account
1161        assert!(targets.contains_key(&addr));
1162
1163        let target_slots = &targets[&addr];
1164        assert_eq!(target_slots.len(), 2);
1165        assert!(target_slots.contains(&slot1));
1166        assert!(target_slots.contains(&slot2));
1167    }
1168}