reth_stages/stages/headers.rs

1use alloy_consensus::BlockHeader;
2use alloy_eips::{eip1898::BlockWithParent, NumHash};
3use alloy_primitives::{BlockHash, BlockNumber, Bytes, B256};
4use futures_util::StreamExt;
5use reth_config::config::EtlConfig;
6use reth_consensus::HeaderValidator;
7use reth_db::{table::Value, tables, transaction::DbTx, RawKey, RawTable, RawValue};
8use reth_db_api::{
9    cursor::{DbCursorRO, DbCursorRW},
10    transaction::DbTxMut,
11    DbTxUnwindExt,
12};
13use reth_etl::Collector;
14use reth_network_p2p::headers::{downloader::HeaderDownloader, error::HeadersDownloaderError};
15use reth_primitives::{NodePrimitives, SealedHeader, StaticFileSegment};
16use reth_primitives_traits::{serde_bincode_compat, FullBlockHeader};
17use reth_provider::{
18    providers::StaticFileWriter, BlockHashReader, DBProvider, HeaderProvider, HeaderSyncGap,
19    HeaderSyncGapProvider, StaticFileProviderFactory,
20};
21use reth_stages_api::{
22    BlockErrorKind, CheckpointBlockRange, EntitiesCheckpoint, ExecInput, ExecOutput,
23    HeadersCheckpoint, Stage, StageCheckpoint, StageError, StageId, UnwindInput, UnwindOutput,
24};
25use reth_storage_errors::provider::ProviderError;
26use std::{
27    sync::Arc,
28    task::{ready, Context, Poll},
29};
30use tokio::sync::watch;
31use tracing::*;
32
33/// The headers stage.
34///
35/// The headers stage downloads all block headers from the highest block in storage to
36/// the perceived highest block on the network.
37///
38/// The headers are processed and data is inserted into static files, as well as into the
39/// [`HeaderNumbers`][reth_db::tables::HeaderNumbers] table.
40///
41/// NOTE: This stage downloads headers in reverse and pushes them to the ETL [`Collector`]. It then
42/// proceeds to push them sequentially to static files. The stage checkpoint is not updated until
43/// this stage is done.
44#[derive(Debug)]
45pub struct HeaderStage<Provider, Downloader: HeaderDownloader> {
46    /// Database handle.
47    provider: Provider,
48    /// Strategy for downloading the headers
49    downloader: Downloader,
50    /// The tip for the stage.
51    tip: watch::Receiver<B256>,
52    /// Consensus client implementation
53    consensus: Arc<dyn HeaderValidator<Downloader::Header>>,
54    /// Current sync gap.
55    sync_gap: Option<HeaderSyncGap<Downloader::Header>>,
56    /// ETL collector with `HeaderHash` -> `BlockNumber`
57    hash_collector: Collector<BlockHash, BlockNumber>,
58    /// ETL collector with `BlockNumber` -> `BincodeSealedHeader`
59    header_collector: Collector<BlockNumber, Bytes>,
60    /// Returns true if the ETL collector has all necessary headers to fill the gap.
61    is_etl_ready: bool,
62}
63
64// === impl HeaderStage ===
65
66impl<Provider, Downloader> HeaderStage<Provider, Downloader>
67where
68    Downloader: HeaderDownloader,
69{
70    /// Create a new header stage
71    pub fn new(
72        database: Provider,
73        downloader: Downloader,
74        tip: watch::Receiver<B256>,
75        consensus: Arc<dyn HeaderValidator<Downloader::Header>>,
76        etl_config: EtlConfig,
77    ) -> Self {
78        Self {
79            provider: database,
80            downloader,
81            tip,
82            consensus,
83            sync_gap: None,
84            hash_collector: Collector::new(etl_config.file_size / 2, etl_config.dir.clone()),
85            header_collector: Collector::new(etl_config.file_size / 2, etl_config.dir),
86            is_etl_ready: false,
87        }
88    }
89
90    /// Write downloaded headers to storage from ETL.
91    ///
92    /// Writes to static files ( `Header | HeaderTD | HeaderHash` ) and [`tables::HeaderNumbers`]
93    /// database table.
94    fn write_headers<P>(&mut self, provider: &P) -> Result<BlockNumber, StageError>
95    where
96        P: DBProvider<Tx: DbTxMut> + StaticFileProviderFactory,
97        Downloader: HeaderDownloader<Header = <P::Primitives as NodePrimitives>::BlockHeader>,
98        <P::Primitives as NodePrimitives>::BlockHeader: Value + FullBlockHeader,
99    {
100        let total_headers = self.header_collector.len();
101
102        info!(target: "sync::stages::headers", total = total_headers, "Writing headers");
103
104        let static_file_provider = provider.static_file_provider();
105
106        // Consistency check of expected headers in static files vs DB is done on provider::sync_gap
107        // when poll_execute_ready is polled.
108        let mut last_header_number = static_file_provider
109            .get_highest_static_file_block(StaticFileSegment::Headers)
110            .unwrap_or_default();
111
112        // Find the latest total difficulty
113        let mut td = static_file_provider
114            .header_td_by_number(last_header_number)?
115            .ok_or(ProviderError::TotalDifficultyNotFound(last_header_number))?;
116
117        // Although headers were downloaded in reverse order, the collector iterates it in ascending
118        // order
119        let mut writer = static_file_provider.latest_writer(StaticFileSegment::Headers)?;
120        let interval = (total_headers / 10).max(1);
121        for (index, header) in self.header_collector.iter()?.enumerate() {
122            let (_, header_buf) = header?;
123
124            if index > 0 && index % interval == 0 && total_headers > 100 {
125                info!(target: "sync::stages::headers", progress = %format!("{:.2}%", (index as f64 / total_headers as f64) * 100.0), "Writing headers");
126            }
127
128            let sealed_header: SealedHeader<Downloader::Header> =
129                bincode::deserialize::<serde_bincode_compat::SealedHeader<'_, _>>(&header_buf)
130                    .map_err(|err| StageError::Fatal(Box::new(err)))?
131                    .into();
132
133            let (header, header_hash) = sealed_header.split();
134            if header.number() == 0 {
135                continue
136            }
137            last_header_number = header.number();
138
139            // Increase total difficulty
140            td += header.difficulty();
141
142            // Header validation
143            self.consensus.validate_header_with_total_difficulty(&header, td).map_err(|error| {
144                StageError::Block {
145                    block: Box::new(BlockWithParent::new(
146                        header.parent_hash(),
147                        NumHash::new(header.number(), header_hash),
148                    )),
149                    error: BlockErrorKind::Validation(error),
150                }
151            })?;
152
153            // Append to Headers segment
154            writer.append_header(&header, td, &header_hash)?;
155        }
156
157        info!(target: "sync::stages::headers", total = total_headers, "Writing headers hash index");
158
159        let mut cursor_header_numbers =
160            provider.tx_ref().cursor_write::<RawTable<tables::HeaderNumbers>>()?;
161        let mut first_sync = false;
162
163        // If we only have the genesis block hash, then we are at first sync, and we can remove it,
164        // add it to the collector and use tx.append on all hashes.
165        if provider.tx_ref().entries::<RawTable<tables::HeaderNumbers>>()? == 1 {
166            if let Some((hash, block_number)) = cursor_header_numbers.last()? {
167                if block_number.value()? == 0 {
168                    self.hash_collector.insert(hash.key()?, 0)?;
169                    cursor_header_numbers.delete_current()?;
170                    first_sync = true;
171                }
172            }
173        }
174
175        // Since ETL sorts all entries by hashes, we are either appending (first sync) or inserting
176        // in order (further syncs).
177        for (index, hash_to_number) in self.hash_collector.iter()?.enumerate() {
178            let (hash, number) = hash_to_number?;
179
180            if index > 0 && index % interval == 0 && total_headers > 100 {
181                info!(target: "sync::stages::headers", progress = %format!("{:.2}%", (index as f64 / total_headers as f64) * 100.0), "Writing headers hash index");
182            }
183
184            if first_sync {
185                cursor_header_numbers.append(
186                    RawKey::<BlockHash>::from_vec(hash),
187                    RawValue::<BlockNumber>::from_vec(number),
188                )?;
189            } else {
190                cursor_header_numbers.insert(
191                    RawKey::<BlockHash>::from_vec(hash),
192                    RawValue::<BlockNumber>::from_vec(number),
193                )?;
194            }
195        }
196
197        Ok(last_header_number)
198    }
199}
200
201impl<Provider, P, D> Stage<Provider> for HeaderStage<P, D>
202where
203    Provider: DBProvider<Tx: DbTxMut> + StaticFileProviderFactory,
204    P: HeaderSyncGapProvider<Header = <Provider::Primitives as NodePrimitives>::BlockHeader>,
205    D: HeaderDownloader<Header = <Provider::Primitives as NodePrimitives>::BlockHeader>,
206    <Provider::Primitives as NodePrimitives>::BlockHeader: FullBlockHeader + Value,
207{
208    /// Return the id of the stage
209    fn id(&self) -> StageId {
210        StageId::Headers
211    }
212
213    fn poll_execute_ready(
214        &mut self,
215        cx: &mut Context<'_>,
216        input: ExecInput,
217    ) -> Poll<Result<(), StageError>> {
218        let current_checkpoint = input.checkpoint();
219
220        // Return if stage has already completed the gap on the ETL files
221        if self.is_etl_ready {
222            return Poll::Ready(Ok(()))
223        }
224
225        // Lookup the head and tip of the sync range
226        let gap = self.provider.sync_gap(self.tip.clone(), current_checkpoint.block_number)?;
227        let tip = gap.target.tip();
228        self.sync_gap = Some(gap.clone());
229
230        // Nothing to sync
231        if gap.is_closed() {
232            info!(
233                target: "sync::stages::headers",
234                checkpoint = %current_checkpoint.block_number,
235                target = ?tip,
236                "Target block already reached"
237            );
238            self.is_etl_ready = true;
239            return Poll::Ready(Ok(()))
240        }
241
242        debug!(target: "sync::stages::headers", ?tip, head = ?gap.local_head.hash(), "Commencing sync");
243        let local_head_number = gap.local_head.number();
244
245        // let the downloader know what to sync
246        self.downloader.update_sync_gap(gap.local_head, gap.target);
247
248        // We only want to stop once we have all the headers on ETL filespace (disk).
249        loop {
250            match ready!(self.downloader.poll_next_unpin(cx)) {
251                Some(Ok(headers)) => {
252                    info!(target: "sync::stages::headers", total = headers.len(), from_block = headers.first().map(|h| h.number()), to_block = headers.last().map(|h| h.number()), "Received headers");
253                    for header in headers {
254                        let header_number = header.number();
255
256                        self.hash_collector.insert(header.hash(), header_number)?;
257                        self.header_collector.insert(
258                            header_number,
259                            Bytes::from(
260                                bincode::serialize(&serde_bincode_compat::SealedHeader::from(
261                                    &header,
262                                ))
263                                .map_err(|err| StageError::Fatal(Box::new(err)))?,
264                            ),
265                        )?;
266
267                        // Headers are downloaded in reverse, so if we reach here, we know we have
268                        // filled the gap.
269                        if header_number == local_head_number + 1 {
270                            self.is_etl_ready = true;
271                            return Poll::Ready(Ok(()))
272                        }
273                    }
274                }
275                Some(Err(HeadersDownloaderError::DetachedHead { local_head, header, error })) => {
276                    error!(target: "sync::stages::headers", %error, "Cannot attach header to head");
277                    return Poll::Ready(Err(StageError::DetachedHead {
278                        local_head: Box::new(local_head.block_with_parent()),
279                        header: Box::new(header.block_with_parent()),
280                        error,
281                    }))
282                }
283                None => return Poll::Ready(Err(StageError::ChannelClosed)),
284            }
285        }
286    }
287
288    /// Download the headers in reverse order (falling block numbers)
289    /// starting from the tip of the chain
290    fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
291        let current_checkpoint = input.checkpoint();
292
293        if self.sync_gap.as_ref().ok_or(StageError::MissingSyncGap)?.is_closed() {
294            self.is_etl_ready = false;
295            return Ok(ExecOutput::done(current_checkpoint))
296        }
297
298        // We should be here only after we have downloaded all headers into the disk buffer (ETL).
299        if !self.is_etl_ready {
300            return Err(StageError::MissingDownloadBuffer)
301        }
302
303        // Reset flag
304        self.is_etl_ready = false;
305
306        // Write the headers and related tables to DB from ETL space
307        let to_be_processed = self.hash_collector.len() as u64;
308        let last_header_number = self.write_headers(provider)?;
309
310        // Clear ETL collectors
311        self.hash_collector.clear();
312        self.header_collector.clear();
313
314        Ok(ExecOutput {
315            checkpoint: StageCheckpoint::new(last_header_number).with_headers_stage_checkpoint(
316                HeadersCheckpoint {
317                    block_range: CheckpointBlockRange {
318                        from: input.checkpoint().block_number,
319                        to: last_header_number,
320                    },
321                    progress: EntitiesCheckpoint {
322                        processed: input.checkpoint().block_number + to_be_processed,
323                        total: last_header_number,
324                    },
325                },
326            ),
327            // We only reach here if all headers have been downloaded by ETL, and pushed to DB all
328            // in one stage run.
329            done: true,
330        })
331    }
332
333    /// Unwind the stage.
334    fn unwind(
335        &mut self,
336        provider: &Provider,
337        input: UnwindInput,
338    ) -> Result<UnwindOutput, StageError> {
339        self.sync_gap.take();
340
341        // First unwind the db tables, until the unwind_to block number. use the walker to unwind
342        // HeaderNumbers based on the index in CanonicalHeaders
343        // unwind from the next block number since the unwind_to block is exclusive
344        provider
345            .tx_ref()
346            .unwind_table_by_walker::<tables::CanonicalHeaders, tables::HeaderNumbers>(
347                (input.unwind_to + 1)..,
348            )?;
349        provider.tx_ref().unwind_table_by_num::<tables::CanonicalHeaders>(input.unwind_to)?;
350        provider
351            .tx_ref()
352            .unwind_table_by_num::<tables::HeaderTerminalDifficulties>(input.unwind_to)?;
353        let unfinalized_headers_unwound =
354            provider.tx_ref().unwind_table_by_num::<tables::Headers>(input.unwind_to)?;
355
356        // determine how many headers to unwind from the static files based on the highest block and
357        // the unwind_to block
358        let static_file_provider = provider.static_file_provider();
359        let highest_block = static_file_provider
360            .get_highest_static_file_block(StaticFileSegment::Headers)
361            .unwrap_or_default();
362        let static_file_headers_to_unwind = highest_block - input.unwind_to;
363        for block_number in (input.unwind_to + 1)..=highest_block {
364            let hash = static_file_provider.block_hash(block_number)?;
365            // we have to delete from HeaderNumbers here as well as in the above unwind, since that
366            // mapping contains entries for both headers in the db and headers in static files
367            //
368            // so if we are unwinding past the lowest block in the db, we have to iterate through
369            // the HeaderNumbers entries that we'll delete in static files below
370            if let Some(header_hash) = hash {
371                provider.tx_ref().delete::<tables::HeaderNumbers>(header_hash, None)?;
372            }
373        }
374
375        // Now unwind the static files until the unwind_to block number
376        let mut writer = static_file_provider.latest_writer(StaticFileSegment::Headers)?;
377        writer.prune_headers(static_file_headers_to_unwind)?;
378
379        // Set the stage checkpoint entities processed based on how much we unwound - we add the
380        // headers unwound from static files and db
381        let stage_checkpoint =
382            input.checkpoint.headers_stage_checkpoint().map(|stage_checkpoint| HeadersCheckpoint {
383                block_range: stage_checkpoint.block_range,
384                progress: EntitiesCheckpoint {
385                    processed: stage_checkpoint.progress.processed.saturating_sub(
386                        static_file_headers_to_unwind + unfinalized_headers_unwound as u64,
387                    ),
388                    total: stage_checkpoint.progress.total,
389                },
390            });
391
392        let mut checkpoint = StageCheckpoint::new(input.unwind_to);
393        if let Some(stage_checkpoint) = stage_checkpoint {
394            checkpoint = checkpoint.with_headers_stage_checkpoint(stage_checkpoint);
395        }
396
397        Ok(UnwindOutput { checkpoint })
398    }
399}
400
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{
        stage_test_suite, ExecuteStageTestRunner, StageTestRunner, UnwindStageTestRunner,
    };
    use alloy_primitives::B256;
    use assert_matches::assert_matches;
    use reth_execution_types::ExecutionOutcome;
    use reth_primitives::{BlockBody, SealedBlock, SealedBlockWithSenders};
    use reth_provider::{BlockWriter, ProviderFactory, StaticFileProviderFactory};
    use reth_stages_api::StageUnitCheckpoint;
    use reth_testing_utils::generators::{self, random_header, random_header_range};
    use reth_trie::{updates::TrieUpdates, HashedPostStateSorted};
    use test_runner::{HeadersTestRunner, LinearHeadersTestRunner};

    mod test_runner {
        use super::*;
        use crate::test_utils::{TestRunnerError, TestStageDB};
        use reth_consensus::test_utils::TestConsensus;
        use reth_downloaders::headers::reverse_headers::{
            ReverseHeadersDownloader, ReverseHeadersDownloaderBuilder,
        };
        use reth_network_p2p::test_utils::{TestHeaderDownloader, TestHeadersClient};
        use reth_provider::{test_utils::MockNodeTypesWithDB, BlockNumReader};
        use tokio::sync::watch;

        /// Test runner driving a `HeaderStage` whose downloader is built by `downloader_factory`.
        pub(crate) struct HeadersTestRunner<D: HeaderDownloader> {
            pub(crate) client: TestHeadersClient,
            channel: (watch::Sender<B256>, watch::Receiver<B256>),
            downloader_factory: Box<dyn Fn() -> D + Send + Sync + 'static>,
            db: TestStageDB,
            consensus: Arc<TestConsensus>,
        }

        /// Runner backed by the reverse-headers (linear) downloader.
        pub(crate) type LinearHeadersTestRunner =
            HeadersTestRunner<ReverseHeadersDownloader<TestHeadersClient>>;

        impl Default for HeadersTestRunner<TestHeaderDownloader> {
            fn default() -> Self {
                let client = TestHeadersClient::default();
                Self {
                    client: client.clone(),
                    channel: watch::channel(B256::ZERO),
                    consensus: Arc::new(TestConsensus::default()),
                    downloader_factory: Box::new(move || {
                        TestHeaderDownloader::new(
                            client.clone(),
                            Arc::new(TestConsensus::default()),
                            1000,
                            1000,
                        )
                    }),
                    db: TestStageDB::default(),
                }
            }
        }

        impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> StageTestRunner
            for HeadersTestRunner<D>
        {
            type S = HeaderStage<ProviderFactory<MockNodeTypesWithDB>, D>;

            fn db(&self) -> &TestStageDB {
                &self.db
            }

            fn stage(&self) -> Self::S {
                HeaderStage::new(
                    self.db.factory.clone(),
                    (*self.downloader_factory)(),
                    self.channel.1.clone(),
                    self.consensus.clone(),
                    EtlConfig::default(),
                )
            }
        }

        impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> ExecuteStageTestRunner
            for HeadersTestRunner<D>
        {
            type Seed = Vec<SealedHeader>;

            fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
                let mut rng = generators::rng();
                let start = input.checkpoint().block_number;
                let headers = random_header_range(&mut rng, 0..start + 1, B256::ZERO);
                let head = headers.last().cloned().unwrap();
                self.db.insert_headers_with_td(headers.iter())?;

                // headers above the local head are generated up to and including the target
                let end = input.target.unwrap_or_default() + 1;

                if start + 1 >= end {
                    return Ok(Vec::default())
                }

                let mut headers = random_header_range(&mut rng, start + 1..end, head.hash());
                headers.insert(0, head);
                Ok(headers)
            }

            /// Validate stored headers
            fn validate_execution(
                &self,
                input: ExecInput,
                output: Option<ExecOutput>,
            ) -> Result<(), TestRunnerError> {
                let initial_checkpoint = input.checkpoint().block_number;
                match output {
                    Some(output) if output.checkpoint.block_number > initial_checkpoint => {
                        let provider = self.db.factory.provider()?;
                        let mut td = provider
                            .header_td_by_number(initial_checkpoint.saturating_sub(1))?
                            .unwrap_or_default();

                        for block_num in initial_checkpoint..output.checkpoint.block_number {
                            // look up the header hash
                            let hash = provider.block_hash(block_num)?.expect("no header hash");

                            // validate the header number
                            assert_eq!(provider.block_number(hash)?, Some(block_num));

                            // validate the header
                            let header = provider.header_by_number(block_num)?;
                            assert!(header.is_some());
                            let header = SealedHeader::seal(header.unwrap());
                            assert_eq!(header.hash(), hash);

                            // validate the header total difficulty
                            td += header.difficulty;
                            assert_eq!(
                                provider.header_td_by_number(block_num)?.map(Into::into),
                                Some(td)
                            );
                        }
                    }
                    _ => self.check_no_header_entry_above(initial_checkpoint)?,
                };
                Ok(())
            }

            async fn after_execution(&self, headers: Self::Seed) -> Result<(), TestRunnerError> {
                self.client.extend(headers.iter().map(|h| h.clone().unseal())).await;
                let tip = if headers.is_empty() {
                    let tip = random_header(&mut generators::rng(), 0, None);
                    self.db.insert_headers(std::iter::once(&tip))?;
                    tip.hash()
                } else {
                    headers.last().unwrap().hash()
                };
                self.send_tip(tip);
                Ok(())
            }
        }

        impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> UnwindStageTestRunner
            for HeadersTestRunner<D>
        {
            fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
                self.check_no_header_entry_above(input.unwind_to)
            }
        }

        impl HeadersTestRunner<ReverseHeadersDownloader<TestHeadersClient>> {
            pub(crate) fn with_linear_downloader() -> Self {
                let client = TestHeadersClient::default();
                Self {
                    client: client.clone(),
                    channel: watch::channel(B256::ZERO),
                    downloader_factory: Box::new(move || {
                        ReverseHeadersDownloaderBuilder::default()
                            .stream_batch_size(500)
                            .build(client.clone(), Arc::new(TestConsensus::default()))
                    }),
                    db: TestStageDB::default(),
                    consensus: Arc::new(TestConsensus::default()),
                }
            }
        }

        impl<D: HeaderDownloader> HeadersTestRunner<D> {
            pub(crate) fn check_no_header_entry_above(
                &self,
                block: BlockNumber,
            ) -> Result<(), TestRunnerError> {
                self.db
                    .ensure_no_entry_above_by_value::<tables::HeaderNumbers, _>(block, |val| val)?;
                self.db.ensure_no_entry_above::<tables::CanonicalHeaders, _>(block, |key| key)?;
                self.db.ensure_no_entry_above::<tables::Headers, _>(block, |key| key)?;
                self.db.ensure_no_entry_above::<tables::HeaderTerminalDifficulties, _>(
                    block,
                    |num| num,
                )?;
                Ok(())
            }

            pub(crate) fn send_tip(&self, tip: B256) {
                self.channel.0.send(tip).expect("failed to send tip");
            }
        }
    }

    stage_test_suite!(HeadersTestRunner, headers);

    /// Seeds `runner`, runs the stage with the linear downloader from `checkpoint` up to
    /// `previous_stage`, and checks the result.
    ///
    /// It asserts the returned checkpoint, the stored headers, and that both ETL collectors are
    /// empty afterwards. Returns the tip header the stage synced to.
    async fn execute_linear_and_validate(
        runner: &mut LinearHeadersTestRunner,
        checkpoint: BlockNumber,
        previous_stage: BlockNumber,
    ) -> SealedHeader {
        let input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(checkpoint)),
        };
        let headers = runner.seed_execution(input).expect("failed to seed execution");
        let rx = runner.execute(input);

        runner.client.extend(headers.iter().rev().map(|h| h.clone().unseal())).await;

        // skip `after_execution` hook for linear downloader
        let tip = headers.last().unwrap().clone();
        runner.send_tip(tip.hash());

        let result = rx.await.unwrap();
        runner.db().factory.static_file_provider().commit().unwrap();
        assert_matches!(result, Ok(ExecOutput { checkpoint: StageCheckpoint {
            block_number,
            stage_checkpoint: Some(StageUnitCheckpoint::Headers(HeadersCheckpoint {
                block_range: CheckpointBlockRange {
                    from,
                    to
                },
                progress: EntitiesCheckpoint {
                    processed,
                    total,
                }
            }))
        }, done: true }) if block_number == tip.number &&
            from == checkpoint && to == previous_stage &&
            // -1 because we don't need to download the local head
            processed == checkpoint + headers.len() as u64 - 1 && total == tip.number
        );
        assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed");
        assert!(runner.stage().hash_collector.is_empty());
        assert!(runner.stage().header_collector.is_empty());

        tip
    }

    /// Execute the stage with linear downloader, unwinds, and ensures that the database tables
    /// along with the static files are cleaned up.
    #[tokio::test]
    async fn execute_with_linear_downloader_unwind() {
        let mut runner = HeadersTestRunner::with_linear_downloader();
        let tip = execute_linear_and_validate(&mut runner, 1000, 1200).await;

        // let's insert some blocks using append_blocks_with_state
        let sealed_headers =
            random_header_range(&mut generators::rng(), tip.number..tip.number + 10, tip.hash());

        // make them sealed blocks with senders by converting them to empty blocks
        let sealed_blocks = sealed_headers
            .iter()
            .map(|header| {
                SealedBlockWithSenders::new(
                    SealedBlock::new(header.clone(), BlockBody::default()),
                    vec![],
                )
                .unwrap()
            })
            .collect();

        // append the blocks
        let provider = runner.db().factory.provider_rw().unwrap();
        provider
            .append_blocks_with_state(
                sealed_blocks,
                ExecutionOutcome::default(),
                HashedPostStateSorted::default(),
                TrieUpdates::default(),
            )
            .unwrap();
        provider.commit().unwrap();

        // now we can unwind 10 blocks
        let unwind_input = UnwindInput {
            checkpoint: StageCheckpoint::new(tip.number + 10),
            unwind_to: tip.number,
            bad_block: None,
        };

        let unwind_output = runner.unwind(unwind_input).await.unwrap();
        assert_eq!(unwind_output.checkpoint.block_number, tip.number);

        // validate the unwind, ensure that the tables are cleaned up
        assert!(runner.validate_unwind(unwind_input).is_ok());
    }

    /// Execute the stage with linear downloader
    #[tokio::test]
    async fn execute_with_linear_downloader() {
        let mut runner = HeadersTestRunner::with_linear_downloader();
        execute_linear_and_validate(&mut runner, 1000, 1200).await;
    }
}