// reth_stages/stages/headers.rs

1use alloy_consensus::BlockHeader;
2use alloy_primitives::{BlockHash, BlockNumber, Bytes, B256};
3use futures_util::StreamExt;
4use reth_config::config::EtlConfig;
5use reth_db_api::{
6    cursor::{DbCursorRO, DbCursorRW},
7    table::Value,
8    tables,
9    transaction::{DbTx, DbTxMut},
10    DbTxUnwindExt, RawKey, RawTable, RawValue,
11};
12use reth_etl::Collector;
13use reth_network_p2p::headers::{
14    downloader::{HeaderDownloader, HeaderSyncGap, SyncTarget},
15    error::HeadersDownloaderError,
16};
17use reth_primitives_traits::{serde_bincode_compat, FullBlockHeader, NodePrimitives, SealedHeader};
18use reth_provider::{
19    providers::StaticFileWriter, BlockHashReader, DBProvider, HeaderProvider,
20    HeaderSyncGapProvider, StaticFileProviderFactory,
21};
22use reth_stages_api::{
23    CheckpointBlockRange, EntitiesCheckpoint, ExecInput, ExecOutput, HeadersCheckpoint, Stage,
24    StageCheckpoint, StageError, StageId, UnwindInput, UnwindOutput,
25};
26use reth_static_file_types::StaticFileSegment;
27use reth_storage_errors::provider::ProviderError;
28use std::task::{ready, Context, Poll};
29
30use tokio::sync::watch;
31use tracing::*;
32
/// The headers stage.
///
/// The headers stage downloads all block headers from the highest block in storage to
/// the perceived highest block on the network.
///
/// The headers are processed and data is inserted into static files, as well as into the
/// [`HeaderNumbers`][reth_db_api::tables::HeaderNumbers] table.
///
/// NOTE: This stage downloads headers in reverse and pushes them to the ETL [`Collector`]. It then
/// proceeds to push them sequentially to static files. The stage checkpoint is not updated until
/// this stage is done.
#[derive(Debug)]
pub struct HeaderStage<Provider, Downloader: HeaderDownloader> {
    /// Database handle, used to look up the local tip when computing the sync gap.
    provider: Provider,
    /// Strategy for downloading the headers
    downloader: Downloader,
    /// The tip for the stage.
    ///
    /// This determines the sync target of the stage (set by the pipeline).
    tip: watch::Receiver<B256>,
    /// Current sync gap.
    ///
    /// Computed in `poll_execute_ready` and consumed by `execute`; cleared on unwind.
    sync_gap: Option<HeaderSyncGap<Downloader::Header>>,
    /// ETL collector with `HeaderHash` -> `BlockNumber`
    hash_collector: Collector<BlockHash, BlockNumber>,
    /// ETL collector with `BlockNumber` -> `BincodeSealedHeader` (bincode-encoded bytes)
    header_collector: Collector<BlockNumber, Bytes>,
    /// Returns true if the ETL collector has all necessary headers to fill the gap.
    is_etl_ready: bool,
}
63
64// === impl HeaderStage ===
65
impl<Provider, Downloader> HeaderStage<Provider, Downloader>
where
    Downloader: HeaderDownloader,
{
    /// Create a new header stage
    ///
    /// The configured ETL file budget is split evenly between the hash->number and
    /// number->header collectors.
    pub fn new(
        database: Provider,
        downloader: Downloader,
        tip: watch::Receiver<B256>,
        etl_config: EtlConfig,
    ) -> Self {
        Self {
            provider: database,
            downloader,
            tip,
            sync_gap: None,
            // Each collector gets half of the configured ETL file size.
            hash_collector: Collector::new(etl_config.file_size / 2, etl_config.dir.clone()),
            header_collector: Collector::new(etl_config.file_size / 2, etl_config.dir),
            is_etl_ready: false,
        }
    }

    /// Write downloaded headers to storage from ETL.
    ///
    /// Writes to static files ( `Header | HeaderTD | HeaderHash` ) and [`tables::HeaderNumbers`]
    /// database table.
    ///
    /// Returns the number of the last header written to the static-file segment.
    fn write_headers<P>(&mut self, provider: &P) -> Result<BlockNumber, StageError>
    where
        P: DBProvider<Tx: DbTxMut> + StaticFileProviderFactory,
        Downloader: HeaderDownloader<Header = <P::Primitives as NodePrimitives>::BlockHeader>,
        <P::Primitives as NodePrimitives>::BlockHeader: Value + FullBlockHeader,
    {
        let total_headers = self.header_collector.len();

        info!(target: "sync::stages::headers", total = total_headers, "Writing headers");

        let static_file_provider = provider.static_file_provider();

        // Consistency check of expected headers in static files vs DB is done on provider::sync_gap
        // when poll_execute_ready is polled.
        let mut last_header_number = static_file_provider
            .get_highest_static_file_block(StaticFileSegment::Headers)
            .unwrap_or_default();

        // Find the latest total difficulty
        let mut td = static_file_provider
            .header_td_by_number(last_header_number)?
            .ok_or(ProviderError::TotalDifficultyNotFound(last_header_number))?;

        // Although headers were downloaded in reverse order, the collector iterates it in ascending
        // order
        let mut writer = static_file_provider.latest_writer(StaticFileSegment::Headers)?;
        // Log progress roughly every 10% of the work (only for non-trivial batches).
        let interval = (total_headers / 10).max(1);
        for (index, header) in self.header_collector.iter()?.enumerate() {
            let (_, header_buf) = header?;

            if index > 0 && index % interval == 0 && total_headers > 100 {
                info!(target: "sync::stages::headers", progress = %format!("{:.2}%", (index as f64 / total_headers as f64) * 100.0), "Writing headers");
            }

            // Entries were stored as bincode-encoded `SealedHeader`s (see `poll_execute_ready`).
            let sealed_header: SealedHeader<Downloader::Header> =
                bincode::deserialize::<serde_bincode_compat::SealedHeader<'_, _>>(&header_buf)
                    .map_err(|err| StageError::Fatal(Box::new(err)))?
                    .into();

            let (header, header_hash) = sealed_header.split_ref();
            // The genesis header is already present in storage; never re-append it.
            if header.number() == 0 {
                continue
            }
            last_header_number = header.number();

            // Increase total difficulty
            td += header.difficulty();

            // Append to Headers segment
            writer.append_header(header, td, header_hash)?;
        }

        info!(target: "sync::stages::headers", total = total_headers, "Writing headers hash index");

        let mut cursor_header_numbers =
            provider.tx_ref().cursor_write::<RawTable<tables::HeaderNumbers>>()?;
        let mut first_sync = false;

        // If we only have the genesis block hash, then we are at first sync, and we can remove it,
        // add it to the collector and use tx.append on all hashes.
        if provider.tx_ref().entries::<RawTable<tables::HeaderNumbers>>()? == 1 {
            if let Some((hash, block_number)) = cursor_header_numbers.last()? {
                if block_number.value()? == 0 {
                    self.hash_collector.insert(hash.key()?, 0)?;
                    cursor_header_numbers.delete_current()?;
                    first_sync = true;
                }
            }
        }

        // Since ETL sorts all entries by hashes, we are either appending (first sync) or inserting
        // in order (further syncs).
        for (index, hash_to_number) in self.hash_collector.iter()?.enumerate() {
            let (hash, number) = hash_to_number?;

            if index > 0 && index % interval == 0 && total_headers > 100 {
                info!(target: "sync::stages::headers", progress = %format!("{:.2}%", (index as f64 / total_headers as f64) * 100.0), "Writing headers hash index");
            }

            if first_sync {
                // Table was emptied above and ETL yields hashes in sorted order, so a raw
                // append is valid and cheaper than point inserts.
                cursor_header_numbers.append(
                    RawKey::<BlockHash>::from_vec(hash),
                    &RawValue::<BlockNumber>::from_vec(number),
                )?;
            } else {
                cursor_header_numbers.upsert(
                    RawKey::<BlockHash>::from_vec(hash),
                    &RawValue::<BlockNumber>::from_vec(number),
                )?;
            }
        }

        Ok(last_header_number)
    }
}
187
188impl<Provider, P, D> Stage<Provider> for HeaderStage<P, D>
189where
190    Provider: DBProvider<Tx: DbTxMut> + StaticFileProviderFactory,
191    P: HeaderSyncGapProvider<Header = <Provider::Primitives as NodePrimitives>::BlockHeader>,
192    D: HeaderDownloader<Header = <Provider::Primitives as NodePrimitives>::BlockHeader>,
193    <Provider::Primitives as NodePrimitives>::BlockHeader: FullBlockHeader + Value,
194{
195    /// Return the id of the stage
196    fn id(&self) -> StageId {
197        StageId::Headers
198    }
199
200    fn poll_execute_ready(
201        &mut self,
202        cx: &mut Context<'_>,
203        input: ExecInput,
204    ) -> Poll<Result<(), StageError>> {
205        let current_checkpoint = input.checkpoint();
206
207        // Return if stage has already completed the gap on the ETL files
208        if self.is_etl_ready {
209            return Poll::Ready(Ok(()))
210        }
211
212        // Lookup the head and tip of the sync range
213        let local_head = self.provider.local_tip_header(current_checkpoint.block_number)?;
214        let target = SyncTarget::Tip(*self.tip.borrow());
215        let gap = HeaderSyncGap { local_head, target };
216        let tip = gap.target.tip();
217        self.sync_gap = Some(gap.clone());
218
219        // Nothing to sync
220        if gap.is_closed() {
221            info!(
222                target: "sync::stages::headers",
223                checkpoint = %current_checkpoint.block_number,
224                target = ?tip,
225                "Target block already reached"
226            );
227            self.is_etl_ready = true;
228            return Poll::Ready(Ok(()))
229        }
230
231        debug!(target: "sync::stages::headers", ?tip, head = ?gap.local_head.hash(), "Commencing sync");
232        let local_head_number = gap.local_head.number();
233
234        // let the downloader know what to sync
235        self.downloader.update_sync_gap(gap.local_head, gap.target);
236
237        // We only want to stop once we have all the headers on ETL filespace (disk).
238        loop {
239            match ready!(self.downloader.poll_next_unpin(cx)) {
240                Some(Ok(headers)) => {
241                    info!(target: "sync::stages::headers", total = headers.len(), from_block = headers.first().map(|h| h.number()), to_block = headers.last().map(|h| h.number()), "Received headers");
242                    for header in headers {
243                        let header_number = header.number();
244
245                        self.hash_collector.insert(header.hash(), header_number)?;
246                        self.header_collector.insert(
247                            header_number,
248                            Bytes::from(
249                                bincode::serialize(&serde_bincode_compat::SealedHeader::from(
250                                    &header,
251                                ))
252                                .map_err(|err| StageError::Fatal(Box::new(err)))?,
253                            ),
254                        )?;
255
256                        // Headers are downloaded in reverse, so if we reach here, we know we have
257                        // filled the gap.
258                        if header_number == local_head_number + 1 {
259                            self.is_etl_ready = true;
260                            return Poll::Ready(Ok(()))
261                        }
262                    }
263                }
264                Some(Err(HeadersDownloaderError::DetachedHead { local_head, header, error })) => {
265                    error!(target: "sync::stages::headers", %error, "Cannot attach header to head");
266                    return Poll::Ready(Err(StageError::DetachedHead {
267                        local_head: Box::new(local_head.block_with_parent()),
268                        header: Box::new(header.block_with_parent()),
269                        error,
270                    }))
271                }
272                None => return Poll::Ready(Err(StageError::ChannelClosed)),
273            }
274        }
275    }
276
277    /// Download the headers in reverse order (falling block numbers)
278    /// starting from the tip of the chain
279    fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
280        let current_checkpoint = input.checkpoint();
281
282        if self.sync_gap.as_ref().ok_or(StageError::MissingSyncGap)?.is_closed() {
283            self.is_etl_ready = false;
284            return Ok(ExecOutput::done(current_checkpoint))
285        }
286
287        // We should be here only after we have downloaded all headers into the disk buffer (ETL).
288        if !self.is_etl_ready {
289            return Err(StageError::MissingDownloadBuffer)
290        }
291
292        // Reset flag
293        self.is_etl_ready = false;
294
295        // Write the headers and related tables to DB from ETL space
296        let to_be_processed = self.hash_collector.len() as u64;
297        let last_header_number = self.write_headers(provider)?;
298
299        // Clear ETL collectors
300        self.hash_collector.clear();
301        self.header_collector.clear();
302
303        Ok(ExecOutput {
304            checkpoint: StageCheckpoint::new(last_header_number).with_headers_stage_checkpoint(
305                HeadersCheckpoint {
306                    block_range: CheckpointBlockRange {
307                        from: input.checkpoint().block_number,
308                        to: last_header_number,
309                    },
310                    progress: EntitiesCheckpoint {
311                        processed: input.checkpoint().block_number + to_be_processed,
312                        total: last_header_number,
313                    },
314                },
315            ),
316            // We only reach here if all headers have been downloaded by ETL, and pushed to DB all
317            // in one stage run.
318            done: true,
319        })
320    }
321
322    /// Unwind the stage.
323    fn unwind(
324        &mut self,
325        provider: &Provider,
326        input: UnwindInput,
327    ) -> Result<UnwindOutput, StageError> {
328        self.sync_gap.take();
329
330        // First unwind the db tables, until the unwind_to block number. use the walker to unwind
331        // HeaderNumbers based on the index in CanonicalHeaders
332        // unwind from the next block number since the unwind_to block is exclusive
333        provider
334            .tx_ref()
335            .unwind_table_by_walker::<tables::CanonicalHeaders, tables::HeaderNumbers>(
336                (input.unwind_to + 1)..,
337            )?;
338        provider.tx_ref().unwind_table_by_num::<tables::CanonicalHeaders>(input.unwind_to)?;
339        provider
340            .tx_ref()
341            .unwind_table_by_num::<tables::HeaderTerminalDifficulties>(input.unwind_to)?;
342        let unfinalized_headers_unwound =
343            provider.tx_ref().unwind_table_by_num::<tables::Headers>(input.unwind_to)?;
344
345        // determine how many headers to unwind from the static files based on the highest block and
346        // the unwind_to block
347        let static_file_provider = provider.static_file_provider();
348        let highest_block = static_file_provider
349            .get_highest_static_file_block(StaticFileSegment::Headers)
350            .unwrap_or_default();
351        let static_file_headers_to_unwind = highest_block - input.unwind_to;
352        for block_number in (input.unwind_to + 1)..=highest_block {
353            let hash = static_file_provider.block_hash(block_number)?;
354            // we have to delete from HeaderNumbers here as well as in the above unwind, since that
355            // mapping contains entries for both headers in the db and headers in static files
356            //
357            // so if we are unwinding past the lowest block in the db, we have to iterate through
358            // the HeaderNumbers entries that we'll delete in static files below
359            if let Some(header_hash) = hash {
360                provider.tx_ref().delete::<tables::HeaderNumbers>(header_hash, None)?;
361            }
362        }
363
364        // Now unwind the static files until the unwind_to block number
365        let mut writer = static_file_provider.latest_writer(StaticFileSegment::Headers)?;
366        writer.prune_headers(static_file_headers_to_unwind)?;
367
368        // Set the stage checkpoint entities processed based on how much we unwound - we add the
369        // headers unwound from static files and db
370        let stage_checkpoint =
371            input.checkpoint.headers_stage_checkpoint().map(|stage_checkpoint| HeadersCheckpoint {
372                block_range: stage_checkpoint.block_range,
373                progress: EntitiesCheckpoint {
374                    processed: stage_checkpoint.progress.processed.saturating_sub(
375                        static_file_headers_to_unwind + unfinalized_headers_unwound as u64,
376                    ),
377                    total: stage_checkpoint.progress.total,
378                },
379            });
380
381        let mut checkpoint = StageCheckpoint::new(input.unwind_to);
382        if let Some(stage_checkpoint) = stage_checkpoint {
383            checkpoint = checkpoint.with_headers_stage_checkpoint(stage_checkpoint);
384        }
385
386        Ok(UnwindOutput { checkpoint })
387    }
388}
389
390#[cfg(test)]
391mod tests {
392    use super::*;
393    use crate::test_utils::{
394        stage_test_suite, ExecuteStageTestRunner, StageTestRunner, UnwindStageTestRunner,
395    };
396    use alloy_primitives::B256;
397    use assert_matches::assert_matches;
398    use reth_ethereum_primitives::BlockBody;
399    use reth_execution_types::ExecutionOutcome;
400    use reth_primitives_traits::{RecoveredBlock, SealedBlock};
401    use reth_provider::{BlockWriter, ProviderFactory, StaticFileProviderFactory};
402    use reth_stages_api::StageUnitCheckpoint;
403    use reth_testing_utils::generators::{self, random_header, random_header_range};
404    use reth_trie::{updates::TrieUpdates, HashedPostStateSorted};
405    use std::sync::Arc;
406    use test_runner::HeadersTestRunner;
407
    mod test_runner {
        use super::*;
        use crate::test_utils::{TestRunnerError, TestStageDB};
        use reth_consensus::test_utils::TestConsensus;
        use reth_downloaders::headers::reverse_headers::{
            ReverseHeadersDownloader, ReverseHeadersDownloaderBuilder,
        };
        use reth_network_p2p::test_utils::{TestHeaderDownloader, TestHeadersClient};
        use reth_provider::{test_utils::MockNodeTypesWithDB, BlockNumReader};
        use tokio::sync::watch;

        /// Test harness wiring a [`HeaderStage`] to a test database and a
        /// factory-provided header downloader.
        pub(crate) struct HeadersTestRunner<D: HeaderDownloader> {
            /// Headers client that test cases feed headers into.
            pub(crate) client: TestHeadersClient,
            // Watch channel pair used to announce the sync target (tip) to the stage.
            channel: (watch::Sender<B256>, watch::Receiver<B256>),
            // Factory so each `stage()` call receives a fresh downloader instance.
            downloader_factory: Box<dyn Fn() -> D + Send + Sync + 'static>,
            db: TestStageDB,
        }

        impl Default for HeadersTestRunner<TestHeaderDownloader> {
            fn default() -> Self {
                let client = TestHeadersClient::default();
                Self {
                    client: client.clone(),
                    channel: watch::channel(B256::ZERO),

                    downloader_factory: Box::new(move || {
                        TestHeaderDownloader::new(client.clone(), 1000, 1000)
                    }),
                    db: TestStageDB::default(),
                }
            }
        }

        impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> StageTestRunner
            for HeadersTestRunner<D>
        {
            type S = HeaderStage<ProviderFactory<MockNodeTypesWithDB>, D>;

            fn db(&self) -> &TestStageDB {
                &self.db
            }

            /// Build a fresh stage instance wired to this runner's db, downloader and tip channel.
            fn stage(&self) -> Self::S {
                HeaderStage::new(
                    self.db.factory.clone(),
                    (*self.downloader_factory)(),
                    self.channel.1.clone(),
                    EtlConfig::default(),
                )
            }
        }

        impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> ExecuteStageTestRunner
            for HeadersTestRunner<D>
        {
            type Seed = Vec<SealedHeader>;

            /// Seed the database with headers up to the checkpoint, and return the headers the
            /// downloader should serve for the `checkpoint..=target` range (head included).
            fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
                let mut rng = generators::rng();
                let start = input.checkpoint().block_number;
                let headers = random_header_range(&mut rng, 0..start + 1, B256::ZERO);
                let head = headers.last().cloned().unwrap();
                self.db.insert_headers_with_td(headers.iter())?;

                // use previous checkpoint as seed size
                let end = input.target.unwrap_or_default() + 1;

                if start + 1 >= end {
                    return Ok(Vec::default())
                }

                let mut headers = random_header_range(&mut rng, start + 1..end, head.hash());
                headers.insert(0, head);
                Ok(headers)
            }

            /// Validate stored headers
            fn validate_execution(
                &self,
                input: ExecInput,
                output: Option<ExecOutput>,
            ) -> Result<(), TestRunnerError> {
                let initial_checkpoint = input.checkpoint().block_number;
                match output {
                    Some(output) if output.checkpoint.block_number > initial_checkpoint => {
                        let provider = self.db.factory.provider()?;
                        // Re-accumulate total difficulty starting just below the checkpoint.
                        let mut td = provider
                            .header_td_by_number(initial_checkpoint.saturating_sub(1))?
                            .unwrap_or_default();

                        for block_num in initial_checkpoint..output.checkpoint.block_number {
                            // look up the header hash
                            let hash = provider.block_hash(block_num)?.expect("no header hash");

                            // validate the header number
                            assert_eq!(provider.block_number(hash)?, Some(block_num));

                            // validate the header
                            let header = provider.header_by_number(block_num)?;
                            assert!(header.is_some());
                            let header = SealedHeader::seal_slow(header.unwrap());
                            assert_eq!(header.hash(), hash);

                            // validate the header total difficulty
                            td += header.difficulty;
                            assert_eq!(provider.header_td_by_number(block_num)?, Some(td));
                        }
                    }
                    // No progress: nothing above the checkpoint may have been written.
                    _ => self.check_no_header_entry_above(initial_checkpoint)?,
                };
                Ok(())
            }

            /// Feed the seeded headers to the client and announce a tip so the stage can run.
            async fn after_execution(&self, headers: Self::Seed) -> Result<(), TestRunnerError> {
                self.client.extend(headers.iter().map(|h| h.clone_header())).await;
                // Use the last seeded header as tip, or a random genesis-height one if empty.
                let tip = if headers.is_empty() {
                    let tip = random_header(&mut generators::rng(), 0, None);
                    self.db.insert_headers(std::iter::once(&tip))?;
                    tip.hash()
                } else {
                    headers.last().unwrap().hash()
                };
                self.send_tip(tip);
                Ok(())
            }
        }

        impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> UnwindStageTestRunner
            for HeadersTestRunner<D>
        {
            fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
                self.check_no_header_entry_above(input.unwind_to)
            }
        }

        impl HeadersTestRunner<ReverseHeadersDownloader<TestHeadersClient>> {
            /// Build a runner backed by the real reverse ("linear") headers downloader.
            pub(crate) fn with_linear_downloader() -> Self {
                let client = TestHeadersClient::default();
                Self {
                    client: client.clone(),
                    channel: watch::channel(B256::ZERO),
                    downloader_factory: Box::new(move || {
                        ReverseHeadersDownloaderBuilder::default()
                            .stream_batch_size(500)
                            .build(client.clone(), Arc::new(TestConsensus::default()))
                    }),
                    db: TestStageDB::default(),
                }
            }
        }

        impl<D: HeaderDownloader> HeadersTestRunner<D> {
            /// Assert that no header-related table contains an entry above `block`.
            pub(crate) fn check_no_header_entry_above(
                &self,
                block: BlockNumber,
            ) -> Result<(), TestRunnerError> {
                self.db
                    .ensure_no_entry_above_by_value::<tables::HeaderNumbers, _>(block, |val| val)?;
                self.db.ensure_no_entry_above::<tables::CanonicalHeaders, _>(block, |key| key)?;
                self.db.ensure_no_entry_above::<tables::Headers, _>(block, |key| key)?;
                self.db.ensure_no_entry_above::<tables::HeaderTerminalDifficulties, _>(
                    block,
                    |num| num,
                )?;
                Ok(())
            }

            /// Broadcast `tip` as the sync target for the stage under test.
            pub(crate) fn send_tip(&self, tip: B256) {
                self.channel.0.send(tip).expect("failed to send tip");
            }
        }
    }
580
    // Generates the standard battery of stage tests for this runner.
    stage_test_suite!(HeadersTestRunner, headers);

    /// Execute the stage with linear downloader, unwinds, and ensures that the database tables
    /// along with the static files are cleaned up.
    #[tokio::test]
    async fn execute_with_linear_downloader_unwind() {
        let mut runner = HeadersTestRunner::with_linear_downloader();
        let (checkpoint, previous_stage) = (1000, 1200);
        let input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(checkpoint)),
        };
        let headers = runner.seed_execution(input).expect("failed to seed execution");
        let rx = runner.execute(input);

        // Feed headers in reverse since the stage downloads tip-to-head.
        runner.client.extend(headers.iter().rev().map(|h| h.clone_header())).await;

        // skip `after_execution` hook for linear downloader
        let tip = headers.last().unwrap();
        runner.send_tip(tip.hash());

        let result = rx.await.unwrap();
        runner.db().factory.static_file_provider().commit().unwrap();
        assert_matches!(result, Ok(ExecOutput { checkpoint: StageCheckpoint {
            block_number,
            stage_checkpoint: Some(StageUnitCheckpoint::Headers(HeadersCheckpoint {
                block_range: CheckpointBlockRange {
                    from,
                    to
                },
                progress: EntitiesCheckpoint {
                    processed,
                    total,
                }
            }))
        }, done: true }) if block_number == tip.number &&
            from == checkpoint && to == previous_stage &&
            // -1 because we don't need to download the local head
            processed == checkpoint + headers.len() as u64 - 1 && total == tip.number
        );
        assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed");
        // NOTE(review): `stage()` constructs a fresh stage, so these asserts check a new
        // instance's collectors rather than the one that just ran — confirm intent.
        assert!(runner.stage().hash_collector.is_empty());
        assert!(runner.stage().header_collector.is_empty());

        // let's insert some blocks using append_blocks_with_state
        let sealed_headers =
            random_header_range(&mut generators::rng(), tip.number..tip.number + 10, tip.hash());

        // make them sealed blocks with senders by converting them to empty blocks
        let sealed_blocks = sealed_headers
            .iter()
            .map(|header| {
                RecoveredBlock::new_sealed(
                    SealedBlock::from_sealed_parts(header.clone(), BlockBody::default()),
                    vec![],
                )
            })
            .collect();

        // append the blocks
        let provider = runner.db().factory.provider_rw().unwrap();
        provider
            .append_blocks_with_state(
                sealed_blocks,
                &ExecutionOutcome::default(),
                HashedPostStateSorted::default(),
                TrieUpdates::default(),
            )
            .unwrap();
        provider.commit().unwrap();

        // now we can unwind 10 blocks
        let unwind_input = UnwindInput {
            checkpoint: StageCheckpoint::new(tip.number + 10),
            unwind_to: tip.number,
            bad_block: None,
        };

        let unwind_output = runner.unwind(unwind_input).await.unwrap();
        assert_eq!(unwind_output.checkpoint.block_number, tip.number);

        // validate the unwind, ensure that the tables are cleaned up
        assert!(runner.validate_unwind(unwind_input).is_ok());
    }
665
    /// Execute the stage with linear downloader
    #[tokio::test]
    async fn execute_with_linear_downloader() {
        let mut runner = HeadersTestRunner::with_linear_downloader();
        let (checkpoint, previous_stage) = (1000, 1200);
        let input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(checkpoint)),
        };
        let headers = runner.seed_execution(input).expect("failed to seed execution");
        let rx = runner.execute(input);

        // Feed headers in reverse since the stage downloads tip-to-head.
        runner.client.extend(headers.iter().rev().map(|h| h.clone_header())).await;

        // skip `after_execution` hook for linear downloader
        let tip = headers.last().unwrap();
        runner.send_tip(tip.hash());

        let result = rx.await.unwrap();
        runner.db().factory.static_file_provider().commit().unwrap();
        assert_matches!(result, Ok(ExecOutput { checkpoint: StageCheckpoint {
            block_number,
            stage_checkpoint: Some(StageUnitCheckpoint::Headers(HeadersCheckpoint {
                block_range: CheckpointBlockRange {
                    from,
                    to
                },
                progress: EntitiesCheckpoint {
                    processed,
                    total,
                }
            }))
        }, done: true }) if block_number == tip.number &&
            from == checkpoint && to == previous_stage &&
            // -1 because we don't need to download the local head
            processed == checkpoint + headers.len() as u64 - 1 && total == tip.number
        );
        assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed");
        // NOTE(review): `stage()` constructs a fresh stage, so these asserts check a new
        // instance's collectors rather than the one that just ran — confirm intent.
        assert!(runner.stage().hash_collector.is_empty());
        assert!(runner.stage().header_collector.is_empty());
    }
707}