1use crate::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory, PrefixSetLoader};
2use alloy_primitives::{
3    map::{AddressMap, B256Map},
4    Address, BlockNumber, B256,
5};
6use reth_db_api::{
7    cursor::DbCursorRO,
8    models::{AccountBeforeTx, BlockNumberAddress},
9    tables,
10    transaction::DbTx,
11    DatabaseError,
12};
13use reth_execution_errors::StateRootError;
14use reth_trie::{
15    hashed_cursor::HashedPostStateCursorFactory, trie_cursor::InMemoryTrieCursorFactory,
16    updates::TrieUpdates, HashedPostState, HashedStorage, KeccakKeyHasher, KeyHasher, StateRoot,
17    StateRootProgress, TrieInput,
18};
19use revm::state::FlaggedStorage;
20use std::{collections::HashMap, ops::RangeInclusive};
21use tracing::debug;
22
/// Constructors and one-shot helpers for computing a state root on top of a database
/// transaction.
///
/// Every associated function borrows the transaction `tx` for the lifetime `'a`.
pub trait DatabaseStateRoot<'a, TX>: Sized {
    /// Creates a state root calculator that reads from the given transaction.
    fn from_tx(tx: &'a TX) -> Self;

    /// Creates a calculator prepared for an incremental computation over the inclusive block
    /// `range`.
    ///
    /// The implementation in this file loads prefix sets for `range` through `PrefixSetLoader`
    /// and attaches them to a calculator built with [`Self::from_tx`].
    ///
    /// # Errors
    ///
    /// Returns a [`StateRootError`] if preparing the calculator fails (in this file: if loading
    /// the prefix sets fails).
    fn incremental_root_calculator(
        tx: &'a TX,
        range: RangeInclusive<BlockNumber>,
    ) -> Result<Self, StateRootError>;

    /// Computes the state root for the inclusive block `range` and returns only the root hash.
    ///
    /// # Errors
    ///
    /// Returns a [`StateRootError`] if building the calculator or computing the root fails.
    fn incremental_root(
        tx: &'a TX,
        range: RangeInclusive<BlockNumber>,
    ) -> Result<B256, StateRootError>;

    /// Computes the state root for the inclusive block `range` and returns it together with the
    /// [`TrieUpdates`] produced by the computation.
    ///
    /// NOTE(review): presumably the caller is expected to persist the returned updates; nothing
    /// in this trait writes them — confirm against callers.
    ///
    /// # Errors
    ///
    /// Returns a [`StateRootError`] if building the calculator or computing the root fails.
    fn incremental_root_with_updates(
        tx: &'a TX,
        range: RangeInclusive<BlockNumber>,
    ) -> Result<(B256, TrieUpdates), StateRootError>;

    /// Runs the computation for the inclusive block `range` and reports the outcome as a
    /// [`StateRootProgress`] instead of a bare root.
    ///
    /// # Errors
    ///
    /// Returns a [`StateRootError`] if building the calculator or running the computation fails.
    fn incremental_root_with_progress(
        tx: &'a TX,
        range: RangeInclusive<BlockNumber>,
    ) -> Result<StateRootProgress, StateRootError>;

    /// Computes the state root of the database state with `post_state` layered on top of it.
    ///
    /// The implementation in this file derives the prefix sets from `post_state` itself and
    /// reads hashed state through a cursor factory that overlays the sorted post state on the
    /// database cursors.
    ///
    /// # Errors
    ///
    /// Returns a [`StateRootError`] if computing the root fails.
    fn overlay_root(tx: &'a TX, post_state: HashedPostState) -> Result<B256, StateRootError>;

    /// Same as [`Self::overlay_root`], but also returns the [`TrieUpdates`] produced by the
    /// computation.
    ///
    /// # Errors
    ///
    /// Returns a [`StateRootError`] if computing the root fails.
    fn overlay_root_with_updates(
        tx: &'a TX,
        post_state: HashedPostState,
    ) -> Result<(B256, TrieUpdates), StateRootError>;

    /// Computes the state root from a [`TrieInput`], layering both its trie nodes and its hashed
    /// state on top of the database and using the prefix sets carried by the input.
    ///
    /// # Errors
    ///
    /// Returns a [`StateRootError`] if computing the root fails.
    fn overlay_root_from_nodes(tx: &'a TX, input: TrieInput) -> Result<B256, StateRootError>;

    /// Same as [`Self::overlay_root_from_nodes`], but also returns the [`TrieUpdates`] produced
    /// by the computation.
    ///
    /// # Errors
    ///
    /// Returns a [`StateRootError`] if computing the root fails.
    fn overlay_root_from_nodes_with_updates(
        tx: &'a TX,
        input: TrieInput,
    ) -> Result<(B256, TrieUpdates), StateRootError>;
}
125
/// Construction of a hashed post state from data stored in the database.
pub trait DatabaseHashedPostState<TX>: Sized {
    /// Builds `Self` from the account and storage changesets recorded at block `from` and later.
    ///
    /// Keys are hashed with the key hasher `KH`. In the implementation in this file, only the
    /// first changeset entry encountered for each account (and for each storage slot of an
    /// account) is kept; later entries for the same key are ignored.
    ///
    /// # Errors
    ///
    /// Returns a [`DatabaseError`] if opening or walking a changeset cursor fails.
    fn from_reverts<KH: KeyHasher>(tx: &TX, from: BlockNumber) -> Result<Self, DatabaseError>;
}
132
133impl<'a, TX: DbTx> DatabaseStateRoot<'a, TX>
134    for StateRoot<DatabaseTrieCursorFactory<'a, TX>, DatabaseHashedCursorFactory<'a, TX>>
135{
136    fn from_tx(tx: &'a TX) -> Self {
137        Self::new(DatabaseTrieCursorFactory::new(tx), DatabaseHashedCursorFactory::new(tx))
138    }
139
140    fn incremental_root_calculator(
141        tx: &'a TX,
142        range: RangeInclusive<BlockNumber>,
143    ) -> Result<Self, StateRootError> {
144        let loaded_prefix_sets = PrefixSetLoader::<_, KeccakKeyHasher>::new(tx).load(range)?;
145        Ok(Self::from_tx(tx).with_prefix_sets(loaded_prefix_sets))
146    }
147
148    fn incremental_root(
149        tx: &'a TX,
150        range: RangeInclusive<BlockNumber>,
151    ) -> Result<B256, StateRootError> {
152        debug!(target: "trie::loader", ?range, "incremental state root");
153        Self::incremental_root_calculator(tx, range)?.root()
154    }
155
156    fn incremental_root_with_updates(
157        tx: &'a TX,
158        range: RangeInclusive<BlockNumber>,
159    ) -> Result<(B256, TrieUpdates), StateRootError> {
160        debug!(target: "trie::loader", ?range, "incremental state root");
161        Self::incremental_root_calculator(tx, range)?.root_with_updates()
162    }
163
164    fn incremental_root_with_progress(
165        tx: &'a TX,
166        range: RangeInclusive<BlockNumber>,
167    ) -> Result<StateRootProgress, StateRootError> {
168        debug!(target: "trie::loader", ?range, "incremental state root with progress");
169        Self::incremental_root_calculator(tx, range)?.root_with_progress()
170    }
171
172    fn overlay_root(tx: &'a TX, post_state: HashedPostState) -> Result<B256, StateRootError> {
173        let prefix_sets = post_state.construct_prefix_sets().freeze();
174        let state_sorted = post_state.into_sorted();
175        StateRoot::new(
176            DatabaseTrieCursorFactory::new(tx),
177            HashedPostStateCursorFactory::new(DatabaseHashedCursorFactory::new(tx), &state_sorted),
178        )
179        .with_prefix_sets(prefix_sets)
180        .root()
181    }
182
183    fn overlay_root_with_updates(
184        tx: &'a TX,
185        post_state: HashedPostState,
186    ) -> Result<(B256, TrieUpdates), StateRootError> {
187        let prefix_sets = post_state.construct_prefix_sets().freeze();
188        let state_sorted = post_state.into_sorted();
189        StateRoot::new(
190            DatabaseTrieCursorFactory::new(tx),
191            HashedPostStateCursorFactory::new(DatabaseHashedCursorFactory::new(tx), &state_sorted),
192        )
193        .with_prefix_sets(prefix_sets)
194        .root_with_updates()
195    }
196
197    fn overlay_root_from_nodes(tx: &'a TX, input: TrieInput) -> Result<B256, StateRootError> {
198        let state_sorted = input.state.into_sorted();
199        let nodes_sorted = input.nodes.into_sorted();
200        StateRoot::new(
201            InMemoryTrieCursorFactory::new(DatabaseTrieCursorFactory::new(tx), &nodes_sorted),
202            HashedPostStateCursorFactory::new(DatabaseHashedCursorFactory::new(tx), &state_sorted),
203        )
204        .with_prefix_sets(input.prefix_sets.freeze())
205        .root()
206    }
207
208    fn overlay_root_from_nodes_with_updates(
209        tx: &'a TX,
210        input: TrieInput,
211    ) -> Result<(B256, TrieUpdates), StateRootError> {
212        let state_sorted = input.state.into_sorted();
213        let nodes_sorted = input.nodes.into_sorted();
214        StateRoot::new(
215            InMemoryTrieCursorFactory::new(DatabaseTrieCursorFactory::new(tx), &nodes_sorted),
216            HashedPostStateCursorFactory::new(DatabaseHashedCursorFactory::new(tx), &state_sorted),
217        )
218        .with_prefix_sets(input.prefix_sets.freeze())
219        .root_with_updates()
220    }
221}
222
223impl<TX: DbTx> DatabaseHashedPostState<TX> for HashedPostState {
224    fn from_reverts<KH: KeyHasher>(tx: &TX, from: BlockNumber) -> Result<Self, DatabaseError> {
225        let mut accounts = HashMap::new();
227        let mut account_changesets_cursor = tx.cursor_read::<tables::AccountChangeSets>()?;
228        for entry in account_changesets_cursor.walk_range(from..)? {
229            let (_, AccountBeforeTx { address, info }) = entry?;
230            accounts.entry(address).or_insert(info);
231        }
232
233        let mut storages = AddressMap::<B256Map<FlaggedStorage>>::default();
235        let mut storage_changesets_cursor = tx.cursor_read::<tables::StorageChangeSets>()?;
236        for entry in
237            storage_changesets_cursor.walk_range(BlockNumberAddress((from, Address::ZERO))..)?
238        {
239            let (BlockNumberAddress((_, address)), storage) = entry?;
240            let account_storage = storages.entry(address).or_default();
241            account_storage.entry(storage.key).or_insert(storage.into());
242        }
243
244        let hashed_accounts =
245            accounts.into_iter().map(|(address, info)| (KH::hash_key(address), info)).collect();
246
247        let hashed_storages = storages
248            .into_iter()
249            .map(|(address, storage)| {
250                (
251                    KH::hash_key(address),
252                    HashedStorage::from_iter(
253                        false,
257                        storage.into_iter().map(|(slot, value)| (KH::hash_key(slot), value)),
258                    ),
259                )
260            })
261            .collect();
262
263        Ok(Self { accounts: hashed_accounts, storages: hashed_storages })
264    }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::{hex, map::HashMap, Address, U256};
    use reth_db::test_utils::create_test_rw_db;
    use reth_db_api::database::Database;
    use reth_trie::KeccakKeyHasher;
    use revm::state::AccountInfo;
    use revm_database::BundleState;

    /// Builds a two-account bundle state, hashes it into a post state and checks the overlay
    /// root computed on top of an empty test database against a fixed expected hash.
    #[test]
    fn from_bundle_state_with_rayon() {
        let first_address = Address::with_last_byte(1);
        let second_address = Address::with_last_byte(2);
        let first_slot = U256::from(1015);
        let second_slot = U256::from(2015);

        let first_account = AccountInfo { nonce: 1, ..Default::default() };
        let second_account = AccountInfo { nonce: 2, ..Default::default() };

        // Each account gets a single slot that goes from zero to a non-zero value.
        let first_storage = HashMap::from_iter([(
            first_slot,
            (FlaggedStorage::ZERO, FlaggedStorage::new_from_value(10)),
        )]);
        let second_storage = HashMap::from_iter([(
            second_slot,
            (FlaggedStorage::ZERO, FlaggedStorage::new_from_value(20)),
        )]);

        let bundle_state = BundleState::builder(2..=2)
            .state_present_account_info(first_address, first_account)
            .state_present_account_info(second_address, second_account)
            .state_storage(first_address, first_storage)
            .state_storage(second_address, second_storage)
            .build();
        assert_eq!(bundle_state.reverts.len(), 1);

        let post_state = HashedPostState::from_bundle_state::<KeccakKeyHasher>(&bundle_state.state);
        assert_eq!(post_state.accounts.len(), 2);
        assert_eq!(post_state.storages.len(), 2);

        let db = create_test_rw_db();
        let tx = db.tx().expect("failed to create transaction");
        let computed_root = StateRoot::overlay_root(&tx, post_state).unwrap();
        assert_eq!(
            computed_root,
            hex!("b464525710cafcf5d4044ac85b72c08b1e76231b8d91f288fe438cc41d8eaafd")
        );
    }
}