1use alloy_consensus::BlockHeader;
2use alloy_eips::{eip1898::BlockWithParent, NumHash};
3use alloy_primitives::{BlockHash, BlockNumber, Bytes, B256};
4use futures_util::StreamExt;
5use reth_config::config::EtlConfig;
6use reth_consensus::HeaderValidator;
7use reth_db::{table::Value, tables, transaction::DbTx, RawKey, RawTable, RawValue};
8use reth_db_api::{
9 cursor::{DbCursorRO, DbCursorRW},
10 transaction::DbTxMut,
11 DbTxUnwindExt,
12};
13use reth_etl::Collector;
14use reth_network_p2p::headers::{downloader::HeaderDownloader, error::HeadersDownloaderError};
15use reth_primitives::{NodePrimitives, SealedHeader, StaticFileSegment};
16use reth_primitives_traits::{serde_bincode_compat, FullBlockHeader};
17use reth_provider::{
18 providers::StaticFileWriter, BlockHashReader, DBProvider, HeaderProvider, HeaderSyncGap,
19 HeaderSyncGapProvider, StaticFileProviderFactory,
20};
21use reth_stages_api::{
22 BlockErrorKind, CheckpointBlockRange, EntitiesCheckpoint, ExecInput, ExecOutput,
23 HeadersCheckpoint, Stage, StageCheckpoint, StageError, StageId, UnwindInput, UnwindOutput,
24};
25use reth_storage_errors::provider::ProviderError;
26use std::{
27 sync::Arc,
28 task::{ready, Context, Poll},
29};
30use tokio::sync::watch;
31use tracing::*;
32
/// The headers stage.
///
/// [`Stage::poll_execute_ready`] computes the header sync gap and buffers the downloaded headers
/// in two ETL collectors. [`Stage::execute`] then writes the buffered headers to the `Headers`
/// static-file segment and their hashes to the `HeaderNumbers` table.
#[derive(Debug)]
pub struct HeaderStage<Provider, Downloader: HeaderDownloader> {
    /// Provider used to compute the header sync gap.
    provider: Provider,
    /// Downloader polled for the headers needed to close the sync gap.
    downloader: Downloader,
    /// Receiver for the sync target (tip) hash; passed to the sync-gap computation.
    tip: watch::Receiver<B256>,
    /// Validator applied to every header (together with its total difficulty) before it is
    /// written.
    consensus: Arc<dyn HeaderValidator<Downloader::Header>>,
    /// Sync gap computed by the most recent `poll_execute_ready`; cleared on unwind.
    sync_gap: Option<HeaderSyncGap<Downloader::Header>>,
    /// ETL collector mapping header hash to block number; used to fill `HeaderNumbers`.
    hash_collector: Collector<BlockHash, BlockNumber>,
    /// ETL collector mapping block number to the bincode-encoded sealed header.
    header_collector: Collector<BlockNumber, Bytes>,
    /// Set once the collectors are filled (or the gap is already closed) so `execute` may run.
    is_etl_ready: bool,
}
63
64impl<Provider, Downloader> HeaderStage<Provider, Downloader>
67where
68 Downloader: HeaderDownloader,
69{
70 pub fn new(
72 database: Provider,
73 downloader: Downloader,
74 tip: watch::Receiver<B256>,
75 consensus: Arc<dyn HeaderValidator<Downloader::Header>>,
76 etl_config: EtlConfig,
77 ) -> Self {
78 Self {
79 provider: database,
80 downloader,
81 tip,
82 consensus,
83 sync_gap: None,
84 hash_collector: Collector::new(etl_config.file_size / 2, etl_config.dir.clone()),
85 header_collector: Collector::new(etl_config.file_size / 2, etl_config.dir),
86 is_etl_ready: false,
87 }
88 }
89
90 fn write_headers<P>(&mut self, provider: &P) -> Result<BlockNumber, StageError>
95 where
96 P: DBProvider<Tx: DbTxMut> + StaticFileProviderFactory,
97 Downloader: HeaderDownloader<Header = <P::Primitives as NodePrimitives>::BlockHeader>,
98 <P::Primitives as NodePrimitives>::BlockHeader: Value + FullBlockHeader,
99 {
100 let total_headers = self.header_collector.len();
101
102 info!(target: "sync::stages::headers", total = total_headers, "Writing headers");
103
104 let static_file_provider = provider.static_file_provider();
105
106 let mut last_header_number = static_file_provider
109 .get_highest_static_file_block(StaticFileSegment::Headers)
110 .unwrap_or_default();
111
112 let mut td = static_file_provider
114 .header_td_by_number(last_header_number)?
115 .ok_or(ProviderError::TotalDifficultyNotFound(last_header_number))?;
116
117 let mut writer = static_file_provider.latest_writer(StaticFileSegment::Headers)?;
120 let interval = (total_headers / 10).max(1);
121 for (index, header) in self.header_collector.iter()?.enumerate() {
122 let (_, header_buf) = header?;
123
124 if index > 0 && index % interval == 0 && total_headers > 100 {
125 info!(target: "sync::stages::headers", progress = %format!("{:.2}%", (index as f64 / total_headers as f64) * 100.0), "Writing headers");
126 }
127
128 let sealed_header: SealedHeader<Downloader::Header> =
129 bincode::deserialize::<serde_bincode_compat::SealedHeader<'_, _>>(&header_buf)
130 .map_err(|err| StageError::Fatal(Box::new(err)))?
131 .into();
132
133 let (header, header_hash) = sealed_header.split();
134 if header.number() == 0 {
135 continue
136 }
137 last_header_number = header.number();
138
139 td += header.difficulty();
141
142 self.consensus.validate_header_with_total_difficulty(&header, td).map_err(|error| {
144 StageError::Block {
145 block: Box::new(BlockWithParent::new(
146 header.parent_hash(),
147 NumHash::new(header.number(), header_hash),
148 )),
149 error: BlockErrorKind::Validation(error),
150 }
151 })?;
152
153 writer.append_header(&header, td, &header_hash)?;
155 }
156
157 info!(target: "sync::stages::headers", total = total_headers, "Writing headers hash index");
158
159 let mut cursor_header_numbers =
160 provider.tx_ref().cursor_write::<RawTable<tables::HeaderNumbers>>()?;
161 let mut first_sync = false;
162
163 if provider.tx_ref().entries::<RawTable<tables::HeaderNumbers>>()? == 1 {
166 if let Some((hash, block_number)) = cursor_header_numbers.last()? {
167 if block_number.value()? == 0 {
168 self.hash_collector.insert(hash.key()?, 0)?;
169 cursor_header_numbers.delete_current()?;
170 first_sync = true;
171 }
172 }
173 }
174
175 for (index, hash_to_number) in self.hash_collector.iter()?.enumerate() {
178 let (hash, number) = hash_to_number?;
179
180 if index > 0 && index % interval == 0 && total_headers > 100 {
181 info!(target: "sync::stages::headers", progress = %format!("{:.2}%", (index as f64 / total_headers as f64) * 100.0), "Writing headers hash index");
182 }
183
184 if first_sync {
185 cursor_header_numbers.append(
186 RawKey::<BlockHash>::from_vec(hash),
187 RawValue::<BlockNumber>::from_vec(number),
188 )?;
189 } else {
190 cursor_header_numbers.insert(
191 RawKey::<BlockHash>::from_vec(hash),
192 RawValue::<BlockNumber>::from_vec(number),
193 )?;
194 }
195 }
196
197 Ok(last_header_number)
198 }
199}
200
201impl<Provider, P, D> Stage<Provider> for HeaderStage<P, D>
202where
203 Provider: DBProvider<Tx: DbTxMut> + StaticFileProviderFactory,
204 P: HeaderSyncGapProvider<Header = <Provider::Primitives as NodePrimitives>::BlockHeader>,
205 D: HeaderDownloader<Header = <Provider::Primitives as NodePrimitives>::BlockHeader>,
206 <Provider::Primitives as NodePrimitives>::BlockHeader: FullBlockHeader + Value,
207{
208 fn id(&self) -> StageId {
210 StageId::Headers
211 }
212
213 fn poll_execute_ready(
214 &mut self,
215 cx: &mut Context<'_>,
216 input: ExecInput,
217 ) -> Poll<Result<(), StageError>> {
218 let current_checkpoint = input.checkpoint();
219
220 if self.is_etl_ready {
222 return Poll::Ready(Ok(()))
223 }
224
225 let gap = self.provider.sync_gap(self.tip.clone(), current_checkpoint.block_number)?;
227 let tip = gap.target.tip();
228 self.sync_gap = Some(gap.clone());
229
230 if gap.is_closed() {
232 info!(
233 target: "sync::stages::headers",
234 checkpoint = %current_checkpoint.block_number,
235 target = ?tip,
236 "Target block already reached"
237 );
238 self.is_etl_ready = true;
239 return Poll::Ready(Ok(()))
240 }
241
242 debug!(target: "sync::stages::headers", ?tip, head = ?gap.local_head.hash(), "Commencing sync");
243 let local_head_number = gap.local_head.number();
244
245 self.downloader.update_sync_gap(gap.local_head, gap.target);
247
248 loop {
250 match ready!(self.downloader.poll_next_unpin(cx)) {
251 Some(Ok(headers)) => {
252 info!(target: "sync::stages::headers", total = headers.len(), from_block = headers.first().map(|h| h.number()), to_block = headers.last().map(|h| h.number()), "Received headers");
253 for header in headers {
254 let header_number = header.number();
255
256 self.hash_collector.insert(header.hash(), header_number)?;
257 self.header_collector.insert(
258 header_number,
259 Bytes::from(
260 bincode::serialize(&serde_bincode_compat::SealedHeader::from(
261 &header,
262 ))
263 .map_err(|err| StageError::Fatal(Box::new(err)))?,
264 ),
265 )?;
266
267 if header_number == local_head_number + 1 {
270 self.is_etl_ready = true;
271 return Poll::Ready(Ok(()))
272 }
273 }
274 }
275 Some(Err(HeadersDownloaderError::DetachedHead { local_head, header, error })) => {
276 error!(target: "sync::stages::headers", %error, "Cannot attach header to head");
277 return Poll::Ready(Err(StageError::DetachedHead {
278 local_head: Box::new(local_head.block_with_parent()),
279 header: Box::new(header.block_with_parent()),
280 error,
281 }))
282 }
283 None => return Poll::Ready(Err(StageError::ChannelClosed)),
284 }
285 }
286 }
287
288 fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
291 let current_checkpoint = input.checkpoint();
292
293 if self.sync_gap.as_ref().ok_or(StageError::MissingSyncGap)?.is_closed() {
294 self.is_etl_ready = false;
295 return Ok(ExecOutput::done(current_checkpoint))
296 }
297
298 if !self.is_etl_ready {
300 return Err(StageError::MissingDownloadBuffer)
301 }
302
303 self.is_etl_ready = false;
305
306 let to_be_processed = self.hash_collector.len() as u64;
308 let last_header_number = self.write_headers(provider)?;
309
310 self.hash_collector.clear();
312 self.header_collector.clear();
313
314 Ok(ExecOutput {
315 checkpoint: StageCheckpoint::new(last_header_number).with_headers_stage_checkpoint(
316 HeadersCheckpoint {
317 block_range: CheckpointBlockRange {
318 from: input.checkpoint().block_number,
319 to: last_header_number,
320 },
321 progress: EntitiesCheckpoint {
322 processed: input.checkpoint().block_number + to_be_processed,
323 total: last_header_number,
324 },
325 },
326 ),
327 done: true,
330 })
331 }
332
333 fn unwind(
335 &mut self,
336 provider: &Provider,
337 input: UnwindInput,
338 ) -> Result<UnwindOutput, StageError> {
339 self.sync_gap.take();
340
341 provider
345 .tx_ref()
346 .unwind_table_by_walker::<tables::CanonicalHeaders, tables::HeaderNumbers>(
347 (input.unwind_to + 1)..,
348 )?;
349 provider.tx_ref().unwind_table_by_num::<tables::CanonicalHeaders>(input.unwind_to)?;
350 provider
351 .tx_ref()
352 .unwind_table_by_num::<tables::HeaderTerminalDifficulties>(input.unwind_to)?;
353 let unfinalized_headers_unwound =
354 provider.tx_ref().unwind_table_by_num::<tables::Headers>(input.unwind_to)?;
355
356 let static_file_provider = provider.static_file_provider();
359 let highest_block = static_file_provider
360 .get_highest_static_file_block(StaticFileSegment::Headers)
361 .unwrap_or_default();
362 let static_file_headers_to_unwind = highest_block - input.unwind_to;
363 for block_number in (input.unwind_to + 1)..=highest_block {
364 let hash = static_file_provider.block_hash(block_number)?;
365 if let Some(header_hash) = hash {
371 provider.tx_ref().delete::<tables::HeaderNumbers>(header_hash, None)?;
372 }
373 }
374
375 let mut writer = static_file_provider.latest_writer(StaticFileSegment::Headers)?;
377 writer.prune_headers(static_file_headers_to_unwind)?;
378
379 let stage_checkpoint =
382 input.checkpoint.headers_stage_checkpoint().map(|stage_checkpoint| HeadersCheckpoint {
383 block_range: stage_checkpoint.block_range,
384 progress: EntitiesCheckpoint {
385 processed: stage_checkpoint.progress.processed.saturating_sub(
386 static_file_headers_to_unwind + unfinalized_headers_unwound as u64,
387 ),
388 total: stage_checkpoint.progress.total,
389 },
390 });
391
392 let mut checkpoint = StageCheckpoint::new(input.unwind_to);
393 if let Some(stage_checkpoint) = stage_checkpoint {
394 checkpoint = checkpoint.with_headers_stage_checkpoint(stage_checkpoint);
395 }
396
397 Ok(UnwindOutput { checkpoint })
398 }
399}
400
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{
        stage_test_suite, ExecuteStageTestRunner, StageTestRunner, UnwindStageTestRunner,
    };
    use alloy_primitives::B256;
    use assert_matches::assert_matches;
    use reth_execution_types::ExecutionOutcome;
    use reth_primitives::{BlockBody, SealedBlock, SealedBlockWithSenders};
    use reth_provider::{BlockWriter, ProviderFactory, StaticFileProviderFactory};
    use reth_stages_api::StageUnitCheckpoint;
    use reth_testing_utils::generators::{self, random_header, random_header_range};
    use reth_trie::{updates::TrieUpdates, HashedPostStateSorted};
    use test_runner::HeadersTestRunner;

    mod test_runner {
        use super::*;
        use crate::test_utils::{TestRunnerError, TestStageDB};
        use reth_consensus::test_utils::TestConsensus;
        use reth_downloaders::headers::reverse_headers::{
            ReverseHeadersDownloader, ReverseHeadersDownloaderBuilder,
        };
        use reth_network_p2p::test_utils::{TestHeaderDownloader, TestHeadersClient};
        use reth_provider::{test_utils::MockNodeTypesWithDB, BlockNumReader};
        use tokio::sync::watch;

        /// Test runner for the headers stage, generic over the header downloader `D`.
        pub(crate) struct HeadersTestRunner<D: HeaderDownloader> {
            /// Client that serves headers to the downloader; tests feed it via `extend`.
            pub(crate) client: TestHeadersClient,
            /// Tip channel. Each stage gets a clone of the receiver; `send_tip` uses the sender.
            channel: (watch::Sender<B256>, watch::Receiver<B256>),
            /// Builds a fresh downloader for every stage instance returned by `stage()`.
            downloader_factory: Box<dyn Fn() -> D + Send + Sync + 'static>,
            db: TestStageDB,
            consensus: Arc<TestConsensus>,
        }

        impl Default for HeadersTestRunner<TestHeaderDownloader> {
            /// Runner whose downloader factory builds a `TestHeaderDownloader` over a clone of
            /// the runner's `TestHeadersClient`.
            fn default() -> Self {
                let client = TestHeadersClient::default();
                Self {
                    client: client.clone(),
                    channel: watch::channel(B256::ZERO),
                    consensus: Arc::new(TestConsensus::default()),
                    downloader_factory: Box::new(move || {
                        TestHeaderDownloader::new(
                            client.clone(),
                            Arc::new(TestConsensus::default()),
                            1000,
                            1000,
                        )
                    }),
                    db: TestStageDB::default(),
                }
            }
        }

        impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> StageTestRunner
            for HeadersTestRunner<D>
        {
            type S = HeaderStage<ProviderFactory<MockNodeTypesWithDB>, D>;

            fn db(&self) -> &TestStageDB {
                &self.db
            }

            /// Builds a new stage from the test DB, a fresh downloader, the tip receiver and the
            /// runner's consensus, using the default ETL config.
            fn stage(&self) -> Self::S {
                HeaderStage::new(
                    self.db.factory.clone(),
                    (*self.downloader_factory)(),
                    self.channel.1.clone(),
                    self.consensus.clone(),
                    EtlConfig::default(),
                )
            }
        }

        impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> ExecuteStageTestRunner
            for HeadersTestRunner<D>
        {
            type Seed = Vec<SealedHeader>;

            /// Inserts random headers `0..=checkpoint` (with total difficulty) into the DB.
            ///
            /// Returns the headers the stage is expected to download: the local head followed by
            /// `checkpoint + 1..=target`. Returns an empty vec when the target is not above the
            /// checkpoint.
            fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
                let mut rng = generators::rng();
                let start = input.checkpoint().block_number;
                let headers = random_header_range(&mut rng, 0..start + 1, B256::ZERO);
                let head = headers.last().cloned().unwrap();
                self.db.insert_headers_with_td(headers.iter())?;

                // Exclusive upper bound of the headers to generate.
                let end = input.target.unwrap_or_default() + 1;

                if start + 1 >= end {
                    return Ok(Vec::default())
                }

                let mut headers = random_header_range(&mut rng, start + 1..end, head.hash());
                headers.insert(0, head);
                Ok(headers)
            }

            /// Validates the outcome of an execution.
            ///
            /// If the stage advanced, checks each block in `initial_checkpoint..new_checkpoint`:
            /// its canonical hash, its hash → number mapping, its stored header, and the
            /// cumulative total difficulty. Otherwise checks that no header entry was written
            /// above the initial checkpoint.
            fn validate_execution(
                &self,
                input: ExecInput,
                output: Option<ExecOutput>,
            ) -> Result<(), TestRunnerError> {
                let initial_checkpoint = input.checkpoint().block_number;
                match output {
                    Some(output) if output.checkpoint.block_number > initial_checkpoint => {
                        let provider = self.db.factory.provider()?;
                        let mut td = provider
                            .header_td_by_number(initial_checkpoint.saturating_sub(1))?
                            .unwrap_or_default();

                        for block_num in initial_checkpoint..output.checkpoint.block_number {
                            let hash = provider.block_hash(block_num)?.expect("no header hash");

                            // Hash -> number mapping.
                            assert_eq!(provider.block_number(hash)?, Some(block_num));

                            // Stored header exists and hashes to the canonical hash.
                            let header = provider.header_by_number(block_num)?;
                            assert!(header.is_some());
                            let header = SealedHeader::seal(header.unwrap());
                            assert_eq!(header.hash(), hash);

                            // Cumulative total difficulty.
                            td += header.difficulty;
                            assert_eq!(
                                provider.header_td_by_number(block_num)?.map(Into::into),
                                Some(td)
                            );
                        }
                    }
                    _ => self.check_no_header_entry_above(initial_checkpoint)?,
                };
                Ok(())
            }

            /// Serves the seeded headers from the test client and publishes the tip.
            ///
            /// With no seeded headers, a random header (number 0) is inserted into the DB and its
            /// hash is used as the tip.
            async fn after_execution(&self, headers: Self::Seed) -> Result<(), TestRunnerError> {
                self.client.extend(headers.iter().map(|h| h.clone().unseal())).await;
                let tip = if headers.is_empty() {
                    let tip = random_header(&mut generators::rng(), 0, None);
                    self.db.insert_headers(std::iter::once(&tip))?;
                    tip.hash()
                } else {
                    headers.last().unwrap().hash()
                };
                self.send_tip(tip);
                Ok(())
            }
        }

        impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> UnwindStageTestRunner
            for HeadersTestRunner<D>
        {
            /// After an unwind, no header entry may remain above `unwind_to`.
            fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
                self.check_no_header_entry_above(input.unwind_to)
            }
        }

        impl HeadersTestRunner<ReverseHeadersDownloader<TestHeadersClient>> {
            /// Runner backed by a `ReverseHeadersDownloader` with a stream batch size of 500.
            pub(crate) fn with_linear_downloader() -> Self {
                let client = TestHeadersClient::default();
                Self {
                    client: client.clone(),
                    channel: watch::channel(B256::ZERO),
                    downloader_factory: Box::new(move || {
                        ReverseHeadersDownloaderBuilder::default()
                            .stream_batch_size(500)
                            .build(client.clone(), Arc::new(TestConsensus::default()))
                    }),
                    db: TestStageDB::default(),
                    consensus: Arc::new(TestConsensus::default()),
                }
            }
        }

        impl<D: HeaderDownloader> HeadersTestRunner<D> {
            /// Checks that none of the header tables (`HeaderNumbers` by value; `CanonicalHeaders`,
            /// `Headers`, `HeaderTerminalDifficulties` by key) has an entry above `block`.
            pub(crate) fn check_no_header_entry_above(
                &self,
                block: BlockNumber,
            ) -> Result<(), TestRunnerError> {
                self.db
                    .ensure_no_entry_above_by_value::<tables::HeaderNumbers, _>(block, |val| val)?;
                self.db.ensure_no_entry_above::<tables::CanonicalHeaders, _>(block, |key| key)?;
                self.db.ensure_no_entry_above::<tables::Headers, _>(block, |key| key)?;
                self.db.ensure_no_entry_above::<tables::HeaderTerminalDifficulties, _>(
                    block,
                    |num| num,
                )?;
                Ok(())
            }

            /// Publishes `tip` on the tip channel; panics if every receiver was dropped.
            pub(crate) fn send_tip(&self, tip: B256) {
                self.channel.0.send(tip).expect("failed to send tip");
            }
        }
    }

    // Instantiates the shared stage test suite from `crate::test_utils` for `HeadersTestRunner`.
    stage_test_suite!(HeadersTestRunner, headers);

    /// Executes the stage from block 1000 to 1200 with the reverse downloader.
    ///
    /// Then it appends ten random blocks numbered `tip.number..tip.number + 10`, unwinds back to
    /// the tip, and checks that no header entries remain above it.
    #[tokio::test]
    async fn execute_with_linear_downloader_unwind() {
        let mut runner = HeadersTestRunner::with_linear_downloader();
        let (checkpoint, previous_stage) = (1000, 1200);
        let input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(checkpoint)),
        };
        let headers = runner.seed_execution(input).expect("failed to seed execution");
        // Start the stage before the client has any headers; the order below matters.
        let rx = runner.execute(input);

        runner.client.extend(headers.iter().rev().map(|h| h.clone().unseal())).await;

        // Publishing the tip lets the stage compute its sync gap.
        let tip = headers.last().unwrap();
        runner.send_tip(tip.hash());

        let result = rx.await.unwrap();
        runner.db().factory.static_file_provider().commit().unwrap();
        assert_matches!(result, Ok(ExecOutput { checkpoint: StageCheckpoint {
            block_number,
            stage_checkpoint: Some(StageUnitCheckpoint::Headers(HeadersCheckpoint {
                block_range: CheckpointBlockRange {
                    from,
                    to
                },
                progress: EntitiesCheckpoint {
                    processed,
                    total,
                }
            }))
        }, done: true }) if block_number == tip.number &&
            from == checkpoint && to == previous_stage &&
            // `headers` includes the already-stored local head, hence the `- 1`.
            processed == checkpoint + headers.len() as u64 - 1 && total == tip.number
        );
        assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed");
        assert!(runner.stage().hash_collector.is_empty());
        assert!(runner.stage().header_collector.is_empty());

        // Append extra blocks so there is something above the tip to unwind.
        let sealed_headers =
            random_header_range(&mut generators::rng(), tip.number..tip.number + 10, tip.hash());

        let sealed_blocks = sealed_headers
            .iter()
            .map(|header| {
                SealedBlockWithSenders::new(
                    SealedBlock::new(header.clone(), BlockBody::default()),
                    vec![],
                )
                .unwrap()
            })
            .collect();

        let provider = runner.db().factory.provider_rw().unwrap();
        provider
            .append_blocks_with_state(
                sealed_blocks,
                ExecutionOutcome::default(),
                HashedPostStateSorted::default(),
                TrieUpdates::default(),
            )
            .unwrap();
        provider.commit().unwrap();

        // Unwind from `tip.number + 10` back to the tip.
        let unwind_input = UnwindInput {
            checkpoint: StageCheckpoint::new(tip.number + 10),
            unwind_to: tip.number,
            bad_block: None,
        };

        let unwind_output = runner.unwind(unwind_input).await.unwrap();
        assert_eq!(unwind_output.checkpoint.block_number, tip.number);

        assert!(runner.validate_unwind(unwind_input).is_ok());
    }

    /// Executes the stage from block 1000 to 1200 with the reverse downloader.
    ///
    /// Checks the resulting headers checkpoint and that both collectors are empty afterwards.
    #[tokio::test]
    async fn execute_with_linear_downloader() {
        let mut runner = HeadersTestRunner::with_linear_downloader();
        let (checkpoint, previous_stage) = (1000, 1200);
        let input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(checkpoint)),
        };
        let headers = runner.seed_execution(input).expect("failed to seed execution");
        // Start the stage before the client has any headers; the order below matters.
        let rx = runner.execute(input);

        runner.client.extend(headers.iter().rev().map(|h| h.clone().unseal())).await;

        // Publishing the tip lets the stage compute its sync gap.
        let tip = headers.last().unwrap();
        runner.send_tip(tip.hash());

        let result = rx.await.unwrap();
        runner.db().factory.static_file_provider().commit().unwrap();
        assert_matches!(result, Ok(ExecOutput { checkpoint: StageCheckpoint {
            block_number,
            stage_checkpoint: Some(StageUnitCheckpoint::Headers(HeadersCheckpoint {
                block_range: CheckpointBlockRange {
                    from,
                    to
                },
                progress: EntitiesCheckpoint {
                    processed,
                    total,
                }
            }))
        }, done: true }) if block_number == tip.number &&
            from == checkpoint && to == previous_stage &&
            // `headers` includes the already-stored local head, hence the `- 1`.
            processed == checkpoint + headers.len() as u64 - 1 && total == tip.number
        );
        assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed");
        assert!(runner.stage().hash_collector.is_empty());
        assert!(runner.stage().header_collector.is_empty());
    }
}