1use alloy_primitives::{
3 map::{Entry, HashMap},
4 Address, B256, U256,
5};
6use core::cell::RefCell;
7use revm::primitives::{
8 db::{Database, DatabaseRef},
9 AccountInfo, Bytecode, FlaggedStorage,
10};
11
/// A cache of state reads: accounts (with their storage slots), contract bytecode, and block
/// hashes.
///
/// Wrap a database with [`CachedReads::as_db`] or [`CachedReads::as_db_mut`] so that reads are
/// served from this cache first, and every value fetched from the wrapped database on a miss is
/// stored here for subsequent reads.
#[derive(Debug, Clone, Default)]
pub struct CachedReads {
    /// Cached account info and storage slots, keyed by account address.
    accounts: HashMap<Address, CachedAccount>,
    /// Cached contract bytecode, keyed by code hash.
    contracts: HashMap<B256, Bytecode>,
    /// Cached block hashes, keyed by block number.
    block_hashes: HashMap<u64, B256>,
}
37
38impl CachedReads {
41 pub fn as_db<DB>(&mut self, db: DB) -> CachedReadsDBRef<'_, DB> {
43 self.as_db_mut(db).into_db()
44 }
45
46 pub fn as_db_mut<DB>(&mut self, db: DB) -> CachedReadsDbMut<'_, DB> {
48 CachedReadsDbMut { cached: self, db }
49 }
50
51 pub fn insert_account(
53 &mut self,
54 address: Address,
55 info: AccountInfo,
56 storage: HashMap<U256, FlaggedStorage>,
57 ) {
58 self.accounts.insert(address, CachedAccount { info: Some(info), storage });
59 }
60
61 pub fn extend(&mut self, other: Self) {
65 self.accounts.extend(other.accounts);
66 self.contracts.extend(other.contracts);
67 self.block_hashes.extend(other.block_hashes);
68 }
69}
70
/// A database wrapper that serves reads from a borrowed [`CachedReads`] and fills it on misses.
///
/// When `DB: DatabaseRef`, this type implements [`Database`].
#[derive(Debug)]
pub struct CachedReadsDbMut<'a, DB> {
    /// The cache that reads are served from and that misses are written into.
    pub cached: &'a mut CachedReads,
    /// The underlying database queried when a value is not yet cached.
    pub db: DB,
}
79
80impl<'a, DB> CachedReadsDbMut<'a, DB> {
81 pub const fn into_db(self) -> CachedReadsDBRef<'a, DB> {
84 CachedReadsDBRef { inner: RefCell::new(self) }
85 }
86
87 pub const fn inner(&self) -> &DB {
89 &self.db
90 }
91}
92
93impl<DB, T> AsRef<T> for CachedReadsDbMut<'_, DB>
94where
95 DB: AsRef<T>,
96{
97 fn as_ref(&self) -> &T {
98 self.inner().as_ref()
99 }
100}
101
102impl<DB: DatabaseRef> Database for CachedReadsDbMut<'_, DB> {
103 type Error = <DB as DatabaseRef>::Error;
104
105 fn basic(&mut self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
106 let basic = match self.cached.accounts.entry(address) {
107 Entry::Occupied(entry) => entry.get().info.clone(),
108 Entry::Vacant(entry) => {
109 entry.insert(CachedAccount::new(self.db.basic_ref(address)?)).info.clone()
110 }
111 };
112 Ok(basic)
113 }
114
115 fn code_by_hash(&mut self, code_hash: B256) -> Result<Bytecode, Self::Error> {
116 let code = match self.cached.contracts.entry(code_hash) {
117 Entry::Occupied(entry) => entry.get().clone(),
118 Entry::Vacant(entry) => entry.insert(self.db.code_by_hash_ref(code_hash)?).clone(),
119 };
120 Ok(code)
121 }
122
123 fn storage(&mut self, address: Address, index: U256) -> Result<FlaggedStorage, Self::Error> {
124 match self.cached.accounts.entry(address) {
125 Entry::Occupied(mut acc_entry) => match acc_entry.get_mut().storage.entry(index) {
126 Entry::Occupied(entry) => Ok(*entry.get()),
127 Entry::Vacant(entry) => Ok(*entry.insert(self.db.storage_ref(address, index)?)),
128 },
129 Entry::Vacant(acc_entry) => {
130 let info = self.db.basic_ref(address)?;
132 let (account, value) = if info.is_some() {
133 let value = self.db.storage_ref(address, index)?;
134 let mut account = CachedAccount::new(info);
135 account.storage.insert(index, value);
136 (account, value)
137 } else {
138 (CachedAccount::new(info), FlaggedStorage::ZERO)
139 };
140 acc_entry.insert(account);
141 Ok(value)
142 }
143 }
144 }
145
146 fn block_hash(&mut self, number: u64) -> Result<B256, Self::Error> {
147 let code = match self.cached.block_hashes.entry(number) {
148 Entry::Occupied(entry) => *entry.get(),
149 Entry::Vacant(entry) => *entry.insert(self.db.block_hash_ref(number)?),
150 };
151 Ok(code)
152 }
153}
154
/// A [`DatabaseRef`] wrapper around [`CachedReadsDbMut`].
///
/// `DatabaseRef` methods take `&self`, but filling the cache needs mutable access. The inner
/// wrapper is therefore kept in a [`RefCell`] and mutably borrowed for the duration of each call.
#[derive(Debug)]
pub struct CachedReadsDBRef<'a, DB> {
    /// The wrapped cache database, mutably borrowed on every `DatabaseRef` call.
    pub inner: RefCell<CachedReadsDbMut<'a, DB>>,
}
164
165impl<DB: DatabaseRef> DatabaseRef for CachedReadsDBRef<'_, DB> {
166 type Error = <DB as DatabaseRef>::Error;
167
168 fn basic_ref(&self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
169 self.inner.borrow_mut().basic(address)
170 }
171
172 fn code_by_hash_ref(&self, code_hash: B256) -> Result<Bytecode, Self::Error> {
173 self.inner.borrow_mut().code_by_hash(code_hash)
174 }
175
176 fn storage_ref(&self, address: Address, index: U256) -> Result<FlaggedStorage, Self::Error> {
177 self.inner.borrow_mut().storage(address, index)
178 }
179
180 fn block_hash_ref(&self, number: u64) -> Result<B256, Self::Error> {
181 self.inner.borrow_mut().block_hash(number)
182 }
183}
184
/// A cached account entry: its info plus the storage slots read so far.
#[derive(Debug, Clone)]
struct CachedAccount {
    /// Account info as returned by the database; `None` when the database reported no account.
    info: Option<AccountInfo>,
    /// Storage slots fetched so far for this account, keyed by slot index.
    storage: HashMap<U256, FlaggedStorage>,
}
190
191impl CachedAccount {
192 fn new(info: Option<AccountInfo>) -> Self {
193 Self { info, storage: HashMap::default() }
194 }
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extend_with_two_cached_reads() {
        let hash1 = B256::from_slice(&[1u8; 32]);
        let hash2 = B256::from_slice(&[2u8; 32]);
        let address1 = Address::from_slice(&[1u8; 20]);
        let address2 = Address::from_slice(&[2u8; 20]);

        // Primary cache: one entry of each kind.
        let mut primary = {
            let mut cache = CachedReads::default();
            cache.accounts.insert(address1, CachedAccount::new(Some(AccountInfo::default())));
            cache.contracts.insert(hash1, Bytecode::default());
            cache.block_hashes.insert(1, hash1);
            cache
        };

        // Additional cache: one entry of each kind, disjoint from the primary cache.
        let additional = {
            let mut cache = CachedReads::default();
            cache.accounts.insert(address2, CachedAccount::new(Some(AccountInfo::default())));
            cache.contracts.insert(hash2, Bytecode::default());
            cache.block_hashes.insert(2, hash2);
            cache
        };

        primary.extend(additional);

        // The inputs are disjoint, so every map should hold the entries of both caches.
        // Separate assertions pinpoint which map is wrong on failure.
        assert_eq!(primary.accounts.len(), 2, "accounts should contain 2 entries");
        assert_eq!(primary.contracts.len(), 2, "contracts should contain 2 entries");
        assert_eq!(primary.block_hashes.len(), 2, "block_hashes should contain 2 entries");

        assert!(primary.accounts.contains_key(&address1), "address1 should be cached");
        assert!(primary.accounts.contains_key(&address2), "address2 should be cached");
        assert!(primary.contracts.contains_key(&hash1), "hash1 bytecode should be cached");
        assert!(primary.contracts.contains_key(&hash2), "hash2 bytecode should be cached");
        assert_eq!(primary.block_hashes.get(&1), Some(&hash1));
        assert_eq!(primary.block_hashes.get(&2), Some(&hash2));
    }

    #[test]
    fn test_extend_prefers_other_on_conflict() {
        let address = Address::from_slice(&[1u8; 20]);
        let old_hash = B256::from_slice(&[1u8; 32]);
        let new_hash = B256::from_slice(&[2u8; 32]);
        let mut new_info = AccountInfo::default();
        new_info.balance = U256::from(5);

        let mut primary = CachedReads::default();
        primary.accounts.insert(address, CachedAccount::new(Some(AccountInfo::default())));
        primary.block_hashes.insert(1, old_hash);

        let mut other = CachedReads::default();
        other.accounts.insert(address, CachedAccount::new(Some(new_info.clone())));
        other.block_hashes.insert(1, new_hash);

        primary.extend(other);

        // Conflicting keys collapse to a single entry holding `other`'s value.
        assert_eq!(primary.accounts.len(), 1);
        assert_eq!(primary.accounts[&address].info, Some(new_info));
        assert_eq!(primary.block_hashes.len(), 1);
        assert_eq!(primary.block_hashes.get(&1), Some(&new_hash));
    }
}