1use super::ExecutedBlockWithTrieUpdates;
2use alloy_consensus::BlockHeader;
3use alloy_primitives::{keccak256, Address, BlockNumber, Bytes, StorageKey, B256};
4use reth_errors::ProviderResult;
5use reth_primitives_traits::{Account, Bytecode, NodePrimitives};
6use reth_storage_api::{
7 AccountReader, BlockHashReader, HashedPostStateProvider, StateProofProvider, StateProvider,
8 StateRootProvider, StorageRootProvider,
9};
10use reth_trie::{
11 updates::TrieUpdates, AccountProof, HashedPostState, HashedStorage, MultiProof,
12 MultiProofTargets, StorageMultiProof, TrieInput,
13};
14use revm_database::BundleState;
15use revm_state::FlaggedStorage;
16use std::sync::OnceLock;
17
#[expect(missing_debug_implementations)]
/// A state provider that layers the state of in-memory executed blocks on top of a
/// `historical` state provider.
///
/// Point lookups (block hashes, accounts, storage, bytecode) consult `in_memory` first and
/// fall back to `historical`; trie-related queries prepend a merged overlay of the in-memory
/// blocks to the request before delegating to `historical`.
pub struct MemoryOverlayStateProviderRef<
    'a,
    N: NodePrimitives = reth_ethereum_primitives::EthPrimitives,
> {
    /// Provider used for anything not found in the in-memory blocks.
    pub(crate) historical: Box<dyn StateProvider + 'a>,
    /// Executed blocks overlaid on `historical`. Point lookups return the first block in this
    /// vector that has a value. NOTE(review): presumably ordered newest-first — confirm against
    /// callers.
    pub(crate) in_memory: Vec<ExecutedBlockWithTrieUpdates<N>>,
    /// Lazily computed, cached merge of the hashed state and trie updates of `in_memory`.
    pub(crate) trie_state: OnceLock<MemoryOverlayTrieState>,
}
32
33pub type MemoryOverlayStateProvider<N> = MemoryOverlayStateProviderRef<'static, N>;
36
37impl<'a, N: NodePrimitives> MemoryOverlayStateProviderRef<'a, N> {
38 pub fn new(
46 historical: Box<dyn StateProvider + 'a>,
47 in_memory: Vec<ExecutedBlockWithTrieUpdates<N>>,
48 ) -> Self {
49 Self { historical, in_memory, trie_state: OnceLock::new() }
50 }
51
52 pub fn boxed(self) -> Box<dyn StateProvider + 'a> {
54 Box::new(self)
55 }
56
57 fn trie_state(&self) -> &MemoryOverlayTrieState {
59 self.trie_state.get_or_init(|| {
60 let mut trie_state = MemoryOverlayTrieState::default();
61 for block in self.in_memory.iter().rev() {
62 trie_state.state.extend_ref(block.hashed_state.as_ref());
63 trie_state.nodes.extend_ref(block.trie.as_ref());
64 }
65 trie_state
66 })
67 }
68}
69
70impl<N: NodePrimitives> BlockHashReader for MemoryOverlayStateProviderRef<'_, N> {
71 fn block_hash(&self, number: BlockNumber) -> ProviderResult<Option<B256>> {
72 for block in &self.in_memory {
73 if block.recovered_block().number() == number {
74 return Ok(Some(block.recovered_block().hash()));
75 }
76 }
77
78 self.historical.block_hash(number)
79 }
80
81 fn canonical_hashes_range(
82 &self,
83 start: BlockNumber,
84 end: BlockNumber,
85 ) -> ProviderResult<Vec<B256>> {
86 let range = start..end;
87 let mut earliest_block_number = None;
88 let mut in_memory_hashes = Vec::new();
89 for block in &self.in_memory {
90 if range.contains(&block.recovered_block().number()) {
91 in_memory_hashes.insert(0, block.recovered_block().hash());
92 earliest_block_number = Some(block.recovered_block().number());
93 }
94 }
95
96 let mut hashes =
97 self.historical.canonical_hashes_range(start, earliest_block_number.unwrap_or(end))?;
98 hashes.append(&mut in_memory_hashes);
99 Ok(hashes)
100 }
101}
102
103impl<N: NodePrimitives> AccountReader for MemoryOverlayStateProviderRef<'_, N> {
104 fn basic_account(&self, address: &Address) -> ProviderResult<Option<Account>> {
105 for block in &self.in_memory {
106 if let Some(account) = block.execution_output.account(address) {
107 return Ok(account);
108 }
109 }
110
111 self.historical.basic_account(address)
112 }
113}
114
115impl<N: NodePrimitives> StateRootProvider for MemoryOverlayStateProviderRef<'_, N> {
116 fn state_root(&self, state: HashedPostState) -> ProviderResult<B256> {
117 self.state_root_from_nodes(TrieInput::from_state(state))
118 }
119
120 fn state_root_from_nodes(&self, mut input: TrieInput) -> ProviderResult<B256> {
121 let MemoryOverlayTrieState { nodes, state } = self.trie_state().clone();
122 input.prepend_cached(nodes, state);
123 self.historical.state_root_from_nodes(input)
124 }
125
126 fn state_root_with_updates(
127 &self,
128 state: HashedPostState,
129 ) -> ProviderResult<(B256, TrieUpdates)> {
130 self.state_root_from_nodes_with_updates(TrieInput::from_state(state))
131 }
132
133 fn state_root_from_nodes_with_updates(
134 &self,
135 mut input: TrieInput,
136 ) -> ProviderResult<(B256, TrieUpdates)> {
137 let MemoryOverlayTrieState { nodes, state } = self.trie_state().clone();
138 input.prepend_cached(nodes, state);
139 self.historical.state_root_from_nodes_with_updates(input)
140 }
141}
142
143impl<N: NodePrimitives> StorageRootProvider for MemoryOverlayStateProviderRef<'_, N> {
144 fn storage_root(&self, address: Address, storage: HashedStorage) -> ProviderResult<B256> {
146 let state = &self.trie_state().state;
147 let mut hashed_storage =
148 state.storages.get(&keccak256(address)).cloned().unwrap_or_default();
149 hashed_storage.extend(&storage);
150 self.historical.storage_root(address, hashed_storage)
151 }
152
153 fn storage_proof(
155 &self,
156 address: Address,
157 slot: B256,
158 storage: HashedStorage,
159 ) -> ProviderResult<reth_trie::StorageProof> {
160 let state = &self.trie_state().state;
161 let mut hashed_storage =
162 state.storages.get(&keccak256(address)).cloned().unwrap_or_default();
163 hashed_storage.extend(&storage);
164 self.historical.storage_proof(address, slot, hashed_storage)
165 }
166
167 fn storage_multiproof(
169 &self,
170 address: Address,
171 slots: &[B256],
172 storage: HashedStorage,
173 ) -> ProviderResult<StorageMultiProof> {
174 let state = &self.trie_state().state;
175 let mut hashed_storage =
176 state.storages.get(&keccak256(address)).cloned().unwrap_or_default();
177 hashed_storage.extend(&storage);
178 self.historical.storage_multiproof(address, slots, hashed_storage)
179 }
180}
181
182impl<N: NodePrimitives> StateProofProvider for MemoryOverlayStateProviderRef<'_, N> {
183 fn proof(
184 &self,
185 mut input: TrieInput,
186 address: Address,
187 slots: &[B256],
188 ) -> ProviderResult<AccountProof> {
189 let MemoryOverlayTrieState { nodes, state } = self.trie_state().clone();
190 input.prepend_cached(nodes, state);
191 self.historical.proof(input, address, slots)
192 }
193
194 fn multiproof(
195 &self,
196 mut input: TrieInput,
197 targets: MultiProofTargets,
198 ) -> ProviderResult<MultiProof> {
199 let MemoryOverlayTrieState { nodes, state } = self.trie_state().clone();
200 input.prepend_cached(nodes, state);
201 self.historical.multiproof(input, targets)
202 }
203
204 fn witness(&self, mut input: TrieInput, target: HashedPostState) -> ProviderResult<Vec<Bytes>> {
205 let MemoryOverlayTrieState { nodes, state } = self.trie_state().clone();
206 input.prepend_cached(nodes, state);
207 self.historical.witness(input, target)
208 }
209}
210
211impl<N: NodePrimitives> HashedPostStateProvider for MemoryOverlayStateProviderRef<'_, N> {
212 fn hashed_post_state(&self, bundle_state: &BundleState) -> HashedPostState {
213 self.historical.hashed_post_state(bundle_state)
214 }
215}
216
217impl<N: NodePrimitives> StateProvider for MemoryOverlayStateProviderRef<'_, N> {
218 fn storage(
219 &self,
220 address: Address,
221 storage_key: StorageKey,
222 ) -> ProviderResult<Option<FlaggedStorage>> {
223 for block in &self.in_memory {
224 if let Some(value) = block.execution_output.storage(&address, storage_key.into()) {
225 return Ok(Some(value));
226 }
227 }
228
229 self.historical.storage(address, storage_key)
230 }
231
232 fn bytecode_by_hash(&self, code_hash: &B256) -> ProviderResult<Option<Bytecode>> {
233 for block in &self.in_memory {
234 if let Some(contract) = block.execution_output.bytecode(code_hash) {
235 return Ok(Some(contract));
236 }
237 }
238
239 self.historical.bytecode_by_hash(code_hash)
240 }
241}
242
#[derive(Clone, Default, Debug)]
/// Merged trie data of a set of in-memory blocks, prepended to trie queries by
/// `MemoryOverlayStateProviderRef`.
pub(crate) struct MemoryOverlayTrieState {
    /// Trie node updates merged across the in-memory blocks.
    pub(crate) nodes: TrieUpdates,
    /// Hashed post-state merged across the in-memory blocks.
    pub(crate) state: HashedPostState,
}