reth_stages/stages/utils.rs

//! Utils for `stages`.
use alloy_primitives::{BlockNumber, TxNumber};
use reth_config::config::EtlConfig;
use reth_db::BlockNumberList;
use reth_db_api::{
    cursor::{DbCursorRO, DbCursorRW},
    models::sharded_key::NUM_OF_INDICES_IN_SHARD,
    table::{Decompress, Table},
    transaction::{DbTx, DbTxMut},
    DatabaseError,
};
use reth_etl::Collector;
use reth_primitives::StaticFileSegment;
use reth_provider::{
    providers::StaticFileProvider, BlockReader, DBProvider, ProviderError,
    StaticFileProviderFactory,
};
use reth_stages_api::StageError;
use std::{collections::HashMap, hash::Hash, ops::RangeBounds};
use tracing::info;

/// Number of blocks to process before pushing indices from the cache to the [`Collector`].
const DEFAULT_CACHE_THRESHOLD: u64 = 100_000;

/// Collects all history (`H`) indices for a range of changesets (`CS`) and stores them in a
/// [`Collector`].
///
/// ## Process
/// The function maintains a `HashMap` cache mapping each `PartialKey` (`P`, an `Address` or an
/// `Address.StorageKey`) to its list of block numbers. Once the cache has accumulated indices for
/// `DEFAULT_CACHE_THRESHOLD` blocks, its contents are moved to a [`Collector`]. There, each
/// entry's key is the concatenation of the `PartialKey` and the highest block number in its list.
///
/// ## Example
/// 1. Initial cache state: `{ Address1: [1,2,3], ... }`
/// 2. The cache is flushed to the `Collector`.
/// 3. Updated cache state: `{ Address1: [100,300], ... }`
/// 4. The cache is flushed again.
///
/// As a result, the `Collector` will contain entries such as `(Address1.3, [1,2,3])` and
/// `(Address1.300, [100,300])`. The entries may be stored across one or more files.
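///
/// ## Usage
/// A minimal sketch of how an account-history stage might wire this up, assuming changesets live
/// in `tables::AccountChangeSets` (keyed by block number, with `AccountBeforeTx` values) and that
/// `ShardedKey::new` builds the history key:
/// ```ignore
/// let collector = collect_history_indices::<_, tables::AccountChangeSets, tables::AccountsHistory, _>(
///     provider,
///     block_range,
///     ShardedKey::new,
///     |(block_number, account)| (block_number, account.address),
///     &etl_config,
/// )?;
/// ```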
pub(crate) fn collect_history_indices<Provider, CS, H, P>(
    provider: &Provider,
    range: impl RangeBounds<CS::Key>,
    sharded_key_factory: impl Fn(P, BlockNumber) -> H::Key,
    partial_key_factory: impl Fn((CS::Key, CS::Value)) -> (u64, P),
    etl_config: &EtlConfig,
) -> Result<Collector<H::Key, H::Value>, StageError>
where
    Provider: DBProvider,
    CS: Table,
    H: Table<Value = BlockNumberList>,
    P: Copy + Eq + Hash,
{
    let mut changeset_cursor = provider.tx_ref().cursor_read::<CS>()?;

    let mut collector = Collector::new(etl_config.file_size, etl_config.dir.clone());
    let mut cache: HashMap<P, Vec<u64>> = HashMap::default();

    let mut collect = |cache: &HashMap<P, Vec<u64>>| {
        for (key, indices) in cache {
            let last = indices.last().expect("cache entries are never empty");
            collector.insert(
                sharded_key_factory(*key, *last),
                BlockNumberList::new_pre_sorted(indices.iter().copied()),
            )?;
        }
        Ok::<(), StageError>(())
    };

    // observability
    let total_changesets = provider.tx_ref().entries::<CS>()?;
    let interval = (total_changesets / 1000).max(1);

    let mut flush_counter = 0;
    let mut current_block_number = u64::MAX;
    for (idx, entry) in changeset_cursor.walk_range(range)?.enumerate() {
        let (block_number, key) = partial_key_factory(entry?);
        cache.entry(key).or_default().push(block_number);

        if idx > 0 && idx % interval == 0 && total_changesets > 1000 {
            info!(target: "sync::stages::index_history", progress = %format!("{:.4}%", (idx as f64 / total_changesets as f64) * 100.0), "Collecting indices");
        }

        // Only flush the cache to the collector once indices for DEFAULT_CACHE_THRESHOLD
        // distinct blocks have accumulated.
        if current_block_number != block_number {
            current_block_number = block_number;
            flush_counter += 1;
            if flush_counter > DEFAULT_CACHE_THRESHOLD {
                collect(&cache)?;
                cache.clear();
                flush_counter = 0;
            }
        }
    }
    collect(&cache)?;

    Ok(collector)
}

/// Given a [`Collector`] created by [`collect_history_indices`], iterates over all entries and
/// loads the indices into the database in shards.
///
/// ## Process
/// Iterates over elements, grouping indices by their partial keys (e.g., `Address` or
/// `Address.StorageKey`). It flushes indices to disk whenever a shard reaches its maximum length
/// (`NUM_OF_INDICES_IN_SHARD`) or the partial key changes, ensuring that the final shard of the
/// previous partial key is stored.
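///
/// ## Usage
/// A minimal sketch of draining a [`collect_history_indices`] collector into the account history
/// table, assuming `ShardedKey<Address>` keys that implement `Decode`:
/// ```ignore
/// load_history_indices::<_, tables::AccountsHistory, _>(
///     provider,
///     collector,
///     append_only,
///     ShardedKey::new,
///     ShardedKey::<Address>::decode_owned,
///     |key| key.key,
/// )?;
/// ```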
pub(crate) fn load_history_indices<Provider, H, P>(
    provider: &Provider,
    mut collector: Collector<H::Key, H::Value>,
    append_only: bool,
    sharded_key_factory: impl Clone + Fn(P, u64) -> <H as Table>::Key,
    decode_key: impl Fn(Vec<u8>) -> Result<<H as Table>::Key, DatabaseError>,
    get_partial: impl Fn(<H as Table>::Key) -> P,
) -> Result<(), StageError>
where
    Provider: DBProvider<Tx: DbTxMut>,
    H: Table<Value = BlockNumberList>,
    P: Copy + Default + Eq,
{
    let mut write_cursor = provider.tx_ref().cursor_write::<H>()?;
    let mut current_partial = P::default();
    let mut current_list = Vec::<u64>::new();

    // observability
    let total_entries = collector.len();
    let interval = (total_entries / 100).max(1);

    for (index, element) in collector.iter()?.enumerate() {
        let (k, v) = element?;
        let sharded_key = decode_key(k)?;
        let new_list = BlockNumberList::decompress_owned(v)?;

        if index > 0 && index % interval == 0 && total_entries > 100 {
            info!(target: "sync::stages::index_history", progress = %format!("{:.2}%", (index as f64 / total_entries as f64) * 100.0), "Writing indices");
        }

        // AccountsHistory: `Address`.
        // StorageHistory: `Address.StorageKey`.
        let partial_key = get_partial(sharded_key);

        if current_partial != partial_key {
            // We have reached the end of this subset of keys, so
            // we need to flush its final index shard.
            load_indices(
                &mut write_cursor,
                current_partial,
                &mut current_list,
                &sharded_key_factory,
                append_only,
                LoadMode::Flush,
            )?;

            current_partial = partial_key;
            current_list.clear();

            // If this isn't the first sync, there might already be an existing shard in the
            // database, so we need to merge it with the one coming from the collector.
            if !append_only {
                if let Some((_, last_database_shard)) =
                    write_cursor.seek_exact(sharded_key_factory(current_partial, u64::MAX))?
                {
                    current_list.extend(last_database_shard.iter());
                }
            }
        }

        current_list.extend(new_list.iter());
        load_indices(
            &mut write_cursor,
            current_partial,
            &mut current_list,
            &sharded_key_factory,
            append_only,
            LoadMode::KeepLast,
        )?;
    }

    // There will be one remaining shard that needs to be flushed to DB.
    load_indices(
        &mut write_cursor,
        current_partial,
        &mut current_list,
        &sharded_key_factory,
        append_only,
        LoadMode::Flush,
    )?;

    Ok(())
}

/// Shard and insert the indices list according to [`LoadMode`] and its length.
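///
/// For example, in `Flush` mode with a toy shard size of 2 (the real limit is
/// `NUM_OF_INDICES_IN_SHARD`), the list `[1, 2, 3, 4, 5]` would be written as the shards
/// `(partial_key.2, [1, 2])`, `(partial_key.4, [3, 4])` and `(partial_key.MAX, [5])`: the final
/// shard of a key is always stored under `u64::MAX`. In `KeepLast` mode the trailing chunk `[5]`
/// would instead remain in `list`, awaiting more indices.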
pub(crate) fn load_indices<H, C, P>(
    cursor: &mut C,
    partial_key: P,
    list: &mut Vec<BlockNumber>,
    sharded_key_factory: &impl Fn(P, BlockNumber) -> <H as Table>::Key,
    append_only: bool,
    mode: LoadMode,
) -> Result<(), StageError>
where
    C: DbCursorRO<H> + DbCursorRW<H>,
    H: Table<Value = BlockNumberList>,
    P: Copy,
{
    if list.len() > NUM_OF_INDICES_IN_SHARD || mode.is_flush() {
        let chunks = list
            .chunks(NUM_OF_INDICES_IN_SHARD)
            .map(|chunk| chunk.to_vec())
            .collect::<Vec<Vec<u64>>>();

        let mut iter = chunks.into_iter().peekable();
        while let Some(chunk) = iter.next() {
            let mut highest = *chunk.last().expect("at least one index");

            if !mode.is_flush() && iter.peek().is_none() {
                // Keep the last (possibly partial) chunk in memory so it can keep growing on
                // subsequent calls.
                *list = chunk;
            } else {
                if iter.peek().is_none() {
                    // The last shard of a key is always stored under `u64::MAX`, so lookups for
                    // any block number beyond the closed shards land on it.
                    highest = u64::MAX;
                }
                let key = sharded_key_factory(partial_key, highest);
                let value = BlockNumberList::new_pre_sorted(chunk);

                if append_only {
                    cursor.append(key, value)?;
                } else {
                    cursor.upsert(key, value)?;
                }
            }
        }
    }

    Ok(())
}

/// Mode for loading index shards into the database.
pub(crate) enum LoadMode {
    /// Keep the last shard in memory and don't flush it to the database.
    KeepLast,
    /// Flush all shards into the database.
    Flush,
}

impl LoadMode {
    const fn is_flush(&self) -> bool {
        matches!(self, Self::Flush)
    }
}

/// Called when the database is ahead of the static files. Attempts to find the first block for
/// which we are missing transactions.
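///
/// ## Usage
/// A hedged sketch of how a stage might surface this error after noticing that the static files
/// lag behind the database (the surrounding variables are assumptions for illustration):
/// ```ignore
/// if highest_static_file_tx_num < highest_db_tx_num {
///     return Err(missing_static_data_error(
///         highest_static_file_tx_num,
///         &static_file_provider,
///         provider,
///         StaticFileSegment::Transactions,
///     )?)
/// }
/// ```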
pub(crate) fn missing_static_data_error<Provider>(
    last_tx_num: TxNumber,
    static_file_provider: &StaticFileProvider<Provider::Primitives>,
    provider: &Provider,
    segment: StaticFileSegment,
) -> Result<StageError, ProviderError>
where
    Provider: BlockReader + StaticFileProviderFactory,
{
    let mut last_block =
        static_file_provider.get_highest_static_file_block(segment).unwrap_or_default();

    // To be extra safe, make sure that the last tx number is covered by the last block's
    // indices. If it isn't, walk back until we find the last block that does cover it.
    loop {
        if let Some(indices) = provider.block_body_indices(last_block)? {
            if indices.last_tx_num() <= last_tx_num {
                break
            }
        }
        if last_block == 0 {
            break
        }
        last_block -= 1;
    }

    let missing_block = provider.sealed_header(last_block + 1)?.unwrap_or_default();

    Ok(StageError::MissingStaticFileData {
        block: Box::new(missing_block.block_with_parent()),
        segment,
    })
}