1use alloy_consensus::BlockHeader;
2use alloy_primitives::{BlockHash, BlockNumber, Bytes, B256};
3use futures_util::StreamExt;
4use reth_config::config::EtlConfig;
5use reth_db_api::{
6 cursor::{DbCursorRO, DbCursorRW},
7 table::Value,
8 tables,
9 transaction::{DbTx, DbTxMut},
10 DbTxUnwindExt, RawKey, RawTable, RawValue,
11};
12use reth_etl::Collector;
13use reth_network_p2p::headers::{
14 downloader::{HeaderDownloader, HeaderSyncGap, SyncTarget},
15 error::HeadersDownloaderError,
16};
17use reth_primitives_traits::{serde_bincode_compat, FullBlockHeader, NodePrimitives, SealedHeader};
18use reth_provider::{
19 providers::StaticFileWriter, BlockHashReader, DBProvider, HeaderProvider,
20 HeaderSyncGapProvider, StaticFileProviderFactory,
21};
22use reth_stages_api::{
23 CheckpointBlockRange, EntitiesCheckpoint, ExecInput, ExecOutput, HeadersCheckpoint, Stage,
24 StageCheckpoint, StageError, StageId, UnwindInput, UnwindOutput,
25};
26use reth_static_file_types::StaticFileSegment;
27use reth_storage_errors::provider::ProviderError;
28use std::task::{ready, Context, Poll};
29
30use tokio::sync::watch;
31use tracing::*;
32
33#[derive(Debug)]
45pub struct HeaderStage<Provider, Downloader: HeaderDownloader> {
46 provider: Provider,
48 downloader: Downloader,
50 tip: watch::Receiver<B256>,
54 sync_gap: Option<HeaderSyncGap<Downloader::Header>>,
56 hash_collector: Collector<BlockHash, BlockNumber>,
58 header_collector: Collector<BlockNumber, Bytes>,
60 is_etl_ready: bool,
62}
63
64impl<Provider, Downloader> HeaderStage<Provider, Downloader>
67where
68 Downloader: HeaderDownloader,
69{
70 pub fn new(
72 database: Provider,
73 downloader: Downloader,
74 tip: watch::Receiver<B256>,
75 etl_config: EtlConfig,
76 ) -> Self {
77 Self {
78 provider: database,
79 downloader,
80 tip,
81 sync_gap: None,
82 hash_collector: Collector::new(etl_config.file_size / 2, etl_config.dir.clone()),
83 header_collector: Collector::new(etl_config.file_size / 2, etl_config.dir),
84 is_etl_ready: false,
85 }
86 }
87
88 fn write_headers<P>(&mut self, provider: &P) -> Result<BlockNumber, StageError>
93 where
94 P: DBProvider<Tx: DbTxMut> + StaticFileProviderFactory,
95 Downloader: HeaderDownloader<Header = <P::Primitives as NodePrimitives>::BlockHeader>,
96 <P::Primitives as NodePrimitives>::BlockHeader: Value + FullBlockHeader,
97 {
98 let total_headers = self.header_collector.len();
99
100 info!(target: "sync::stages::headers", total = total_headers, "Writing headers");
101
102 let static_file_provider = provider.static_file_provider();
103
104 let mut last_header_number = static_file_provider
107 .get_highest_static_file_block(StaticFileSegment::Headers)
108 .unwrap_or_default();
109
110 let mut td = static_file_provider
112 .header_td_by_number(last_header_number)?
113 .ok_or(ProviderError::TotalDifficultyNotFound(last_header_number))?;
114
115 let mut writer = static_file_provider.latest_writer(StaticFileSegment::Headers)?;
118 let interval = (total_headers / 10).max(1);
119 for (index, header) in self.header_collector.iter()?.enumerate() {
120 let (_, header_buf) = header?;
121
122 if index > 0 && index % interval == 0 && total_headers > 100 {
123 info!(target: "sync::stages::headers", progress = %format!("{:.2}%", (index as f64 / total_headers as f64) * 100.0), "Writing headers");
124 }
125
126 let sealed_header: SealedHeader<Downloader::Header> =
127 bincode::deserialize::<serde_bincode_compat::SealedHeader<'_, _>>(&header_buf)
128 .map_err(|err| StageError::Fatal(Box::new(err)))?
129 .into();
130
131 let (header, header_hash) = sealed_header.split_ref();
132 if header.number() == 0 {
133 continue
134 }
135 last_header_number = header.number();
136
137 td += header.difficulty();
139
140 writer.append_header(header, td, header_hash)?;
142 }
143
144 info!(target: "sync::stages::headers", total = total_headers, "Writing headers hash index");
145
146 let mut cursor_header_numbers =
147 provider.tx_ref().cursor_write::<RawTable<tables::HeaderNumbers>>()?;
148 let mut first_sync = false;
149
150 if provider.tx_ref().entries::<RawTable<tables::HeaderNumbers>>()? == 1 {
153 if let Some((hash, block_number)) = cursor_header_numbers.last()? {
154 if block_number.value()? == 0 {
155 self.hash_collector.insert(hash.key()?, 0)?;
156 cursor_header_numbers.delete_current()?;
157 first_sync = true;
158 }
159 }
160 }
161
162 for (index, hash_to_number) in self.hash_collector.iter()?.enumerate() {
165 let (hash, number) = hash_to_number?;
166
167 if index > 0 && index % interval == 0 && total_headers > 100 {
168 info!(target: "sync::stages::headers", progress = %format!("{:.2}%", (index as f64 / total_headers as f64) * 100.0), "Writing headers hash index");
169 }
170
171 if first_sync {
172 cursor_header_numbers.append(
173 RawKey::<BlockHash>::from_vec(hash),
174 &RawValue::<BlockNumber>::from_vec(number),
175 )?;
176 } else {
177 cursor_header_numbers.upsert(
178 RawKey::<BlockHash>::from_vec(hash),
179 &RawValue::<BlockNumber>::from_vec(number),
180 )?;
181 }
182 }
183
184 Ok(last_header_number)
185 }
186}
187
188impl<Provider, P, D> Stage<Provider> for HeaderStage<P, D>
189where
190 Provider: DBProvider<Tx: DbTxMut> + StaticFileProviderFactory,
191 P: HeaderSyncGapProvider<Header = <Provider::Primitives as NodePrimitives>::BlockHeader>,
192 D: HeaderDownloader<Header = <Provider::Primitives as NodePrimitives>::BlockHeader>,
193 <Provider::Primitives as NodePrimitives>::BlockHeader: FullBlockHeader + Value,
194{
195 fn id(&self) -> StageId {
197 StageId::Headers
198 }
199
200 fn poll_execute_ready(
201 &mut self,
202 cx: &mut Context<'_>,
203 input: ExecInput,
204 ) -> Poll<Result<(), StageError>> {
205 let current_checkpoint = input.checkpoint();
206
207 if self.is_etl_ready {
209 return Poll::Ready(Ok(()))
210 }
211
212 let local_head = self.provider.local_tip_header(current_checkpoint.block_number)?;
214 let target = SyncTarget::Tip(*self.tip.borrow());
215 let gap = HeaderSyncGap { local_head, target };
216 let tip = gap.target.tip();
217 self.sync_gap = Some(gap.clone());
218
219 if gap.is_closed() {
221 info!(
222 target: "sync::stages::headers",
223 checkpoint = %current_checkpoint.block_number,
224 target = ?tip,
225 "Target block already reached"
226 );
227 self.is_etl_ready = true;
228 return Poll::Ready(Ok(()))
229 }
230
231 debug!(target: "sync::stages::headers", ?tip, head = ?gap.local_head.hash(), "Commencing sync");
232 let local_head_number = gap.local_head.number();
233
234 self.downloader.update_sync_gap(gap.local_head, gap.target);
236
237 loop {
239 match ready!(self.downloader.poll_next_unpin(cx)) {
240 Some(Ok(headers)) => {
241 info!(target: "sync::stages::headers", total = headers.len(), from_block = headers.first().map(|h| h.number()), to_block = headers.last().map(|h| h.number()), "Received headers");
242 for header in headers {
243 let header_number = header.number();
244
245 self.hash_collector.insert(header.hash(), header_number)?;
246 self.header_collector.insert(
247 header_number,
248 Bytes::from(
249 bincode::serialize(&serde_bincode_compat::SealedHeader::from(
250 &header,
251 ))
252 .map_err(|err| StageError::Fatal(Box::new(err)))?,
253 ),
254 )?;
255
256 if header_number == local_head_number + 1 {
259 self.is_etl_ready = true;
260 return Poll::Ready(Ok(()))
261 }
262 }
263 }
264 Some(Err(HeadersDownloaderError::DetachedHead { local_head, header, error })) => {
265 error!(target: "sync::stages::headers", %error, "Cannot attach header to head");
266 return Poll::Ready(Err(StageError::DetachedHead {
267 local_head: Box::new(local_head.block_with_parent()),
268 header: Box::new(header.block_with_parent()),
269 error,
270 }))
271 }
272 None => return Poll::Ready(Err(StageError::ChannelClosed)),
273 }
274 }
275 }
276
277 fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
280 let current_checkpoint = input.checkpoint();
281
282 if self.sync_gap.as_ref().ok_or(StageError::MissingSyncGap)?.is_closed() {
283 self.is_etl_ready = false;
284 return Ok(ExecOutput::done(current_checkpoint))
285 }
286
287 if !self.is_etl_ready {
289 return Err(StageError::MissingDownloadBuffer)
290 }
291
292 self.is_etl_ready = false;
294
295 let to_be_processed = self.hash_collector.len() as u64;
297 let last_header_number = self.write_headers(provider)?;
298
299 self.hash_collector.clear();
301 self.header_collector.clear();
302
303 Ok(ExecOutput {
304 checkpoint: StageCheckpoint::new(last_header_number).with_headers_stage_checkpoint(
305 HeadersCheckpoint {
306 block_range: CheckpointBlockRange {
307 from: input.checkpoint().block_number,
308 to: last_header_number,
309 },
310 progress: EntitiesCheckpoint {
311 processed: input.checkpoint().block_number + to_be_processed,
312 total: last_header_number,
313 },
314 },
315 ),
316 done: true,
319 })
320 }
321
322 fn unwind(
324 &mut self,
325 provider: &Provider,
326 input: UnwindInput,
327 ) -> Result<UnwindOutput, StageError> {
328 self.sync_gap.take();
329
330 provider
334 .tx_ref()
335 .unwind_table_by_walker::<tables::CanonicalHeaders, tables::HeaderNumbers>(
336 (input.unwind_to + 1)..,
337 )?;
338 provider.tx_ref().unwind_table_by_num::<tables::CanonicalHeaders>(input.unwind_to)?;
339 provider
340 .tx_ref()
341 .unwind_table_by_num::<tables::HeaderTerminalDifficulties>(input.unwind_to)?;
342 let unfinalized_headers_unwound =
343 provider.tx_ref().unwind_table_by_num::<tables::Headers>(input.unwind_to)?;
344
345 let static_file_provider = provider.static_file_provider();
348 let highest_block = static_file_provider
349 .get_highest_static_file_block(StaticFileSegment::Headers)
350 .unwrap_or_default();
351 let static_file_headers_to_unwind = highest_block - input.unwind_to;
352 for block_number in (input.unwind_to + 1)..=highest_block {
353 let hash = static_file_provider.block_hash(block_number)?;
354 if let Some(header_hash) = hash {
360 provider.tx_ref().delete::<tables::HeaderNumbers>(header_hash, None)?;
361 }
362 }
363
364 let mut writer = static_file_provider.latest_writer(StaticFileSegment::Headers)?;
366 writer.prune_headers(static_file_headers_to_unwind)?;
367
368 let stage_checkpoint =
371 input.checkpoint.headers_stage_checkpoint().map(|stage_checkpoint| HeadersCheckpoint {
372 block_range: stage_checkpoint.block_range,
373 progress: EntitiesCheckpoint {
374 processed: stage_checkpoint.progress.processed.saturating_sub(
375 static_file_headers_to_unwind + unfinalized_headers_unwound as u64,
376 ),
377 total: stage_checkpoint.progress.total,
378 },
379 });
380
381 let mut checkpoint = StageCheckpoint::new(input.unwind_to);
382 if let Some(stage_checkpoint) = stage_checkpoint {
383 checkpoint = checkpoint.with_headers_stage_checkpoint(stage_checkpoint);
384 }
385
386 Ok(UnwindOutput { checkpoint })
387 }
388}
389
// Tests for the headers stage: the shared stage test suite plus two end-to-end runs that use
// the reverse (linear) headers downloader.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{
        stage_test_suite, ExecuteStageTestRunner, StageTestRunner, UnwindStageTestRunner,
    };
    use alloy_primitives::B256;
    use assert_matches::assert_matches;
    use reth_ethereum_primitives::BlockBody;
    use reth_execution_types::ExecutionOutcome;
    use reth_primitives_traits::{RecoveredBlock, SealedBlock};
    use reth_provider::{BlockWriter, ProviderFactory, StaticFileProviderFactory};
    use reth_stages_api::StageUnitCheckpoint;
    use reth_testing_utils::generators::{self, random_header, random_header_range};
    use reth_trie::{updates::TrieUpdates, HashedPostStateSorted};
    use std::sync::Arc;
    use test_runner::HeadersTestRunner;

    // Test-runner plumbing used by `stage_test_suite!` and the tests below.
    mod test_runner {
        use super::*;
        use crate::test_utils::{TestRunnerError, TestStageDB};
        use reth_consensus::test_utils::TestConsensus;
        use reth_downloaders::headers::reverse_headers::{
            ReverseHeadersDownloader, ReverseHeadersDownloaderBuilder,
        };
        use reth_network_p2p::test_utils::{TestHeaderDownloader, TestHeadersClient};
        use reth_provider::{test_utils::MockNodeTypesWithDB, BlockNumReader};
        use tokio::sync::watch;

        /// Test runner holding a headers client, the tip channel, a factory that builds a fresh
        /// downloader for each stage instance, and the test database.
        pub(crate) struct HeadersTestRunner<D: HeaderDownloader> {
            pub(crate) client: TestHeadersClient,
            // (sender, receiver) pair for the sync-target tip hash.
            channel: (watch::Sender<B256>, watch::Receiver<B256>),
            downloader_factory: Box<dyn Fn() -> D + Send + Sync + 'static>,
            db: TestStageDB,
        }

        impl Default for HeadersTestRunner<TestHeaderDownloader> {
            /// Builds a runner backed by `TestHeaderDownloader` over a shared test client.
            fn default() -> Self {
                let client = TestHeadersClient::default();
                Self {
                    client: client.clone(),
                    channel: watch::channel(B256::ZERO),

                    downloader_factory: Box::new(move || {
                        TestHeaderDownloader::new(client.clone(), 1000, 1000)
                    }),
                    db: TestStageDB::default(),
                }
            }
        }

        impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> StageTestRunner
            for HeadersTestRunner<D>
        {
            type S = HeaderStage<ProviderFactory<MockNodeTypesWithDB>, D>;

            fn db(&self) -> &TestStageDB {
                &self.db
            }

            /// Creates a new stage with a freshly built downloader and a clone of the tip
            /// receiver.
            fn stage(&self) -> Self::S {
                HeaderStage::new(
                    self.db.factory.clone(),
                    (*self.downloader_factory)(),
                    self.channel.1.clone(),
                    EtlConfig::default(),
                )
            }
        }

        impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> ExecuteStageTestRunner
            for HeadersTestRunner<D>
        {
            type Seed = Vec<SealedHeader>;

            /// Inserts random headers `0..=checkpoint` (with total difficulty) into the DB.
            /// Returns the local head followed by random headers up to and including
            /// `input.target`, or an empty vector if there is nothing beyond the checkpoint.
            fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
                let mut rng = generators::rng();
                let start = input.checkpoint().block_number;
                let headers = random_header_range(&mut rng, 0..start + 1, B256::ZERO);
                let head = headers.last().cloned().unwrap();
                self.db.insert_headers_with_td(headers.iter())?;

                // Exclusive upper bound of the headers still to be downloaded.
                let end = input.target.unwrap_or_default() + 1;

                if start + 1 >= end {
                    return Ok(Vec::default())
                }

                let mut headers = random_header_range(&mut rng, start + 1..end, head.hash());
                // Keep the local head first so that the chain links to the seeded DB.
                headers.insert(0, head);
                Ok(headers)
            }

            /// If execution advanced, checks the hash, hash-to-number index, header and
            /// cumulative total difficulty of every block from the initial checkpoint up to the
            /// output checkpoint (exclusive). Otherwise checks that nothing was written above
            /// the initial checkpoint.
            fn validate_execution(
                &self,
                input: ExecInput,
                output: Option<ExecOutput>,
            ) -> Result<(), TestRunnerError> {
                let initial_checkpoint = input.checkpoint().block_number;
                match output {
                    Some(output) if output.checkpoint.block_number > initial_checkpoint => {
                        let provider = self.db.factory.provider()?;
                        // Total difficulty of the block before the first validated one.
                        let mut td = provider
                            .header_td_by_number(initial_checkpoint.saturating_sub(1))?
                            .unwrap_or_default();

                        for block_num in initial_checkpoint..output.checkpoint.block_number {
                            let hash = provider.block_hash(block_num)?.expect("no header hash");

                            // The hash index must point back to this block number.
                            assert_eq!(provider.block_number(hash)?, Some(block_num));

                            // The stored header must re-hash to the canonical hash.
                            let header = provider.header_by_number(block_num)?;
                            assert!(header.is_some());
                            let header = SealedHeader::seal_slow(header.unwrap());
                            assert_eq!(header.hash(), hash);

                            // The stored TD must equal the running sum of difficulties.
                            td += header.difficulty;
                            assert_eq!(provider.header_td_by_number(block_num)?, Some(td));
                        }
                    }
                    _ => self.check_no_header_entry_above(initial_checkpoint)?,
                };
                Ok(())
            }

            /// Feeds the seeded headers to the client and sends the tip. With no seeded headers,
            /// a random genesis-number header is inserted into the DB and its hash is used as
            /// the tip.
            async fn after_execution(&self, headers: Self::Seed) -> Result<(), TestRunnerError> {
                self.client.extend(headers.iter().map(|h| h.clone_header())).await;
                let tip = if headers.is_empty() {
                    let tip = random_header(&mut generators::rng(), 0, None);
                    self.db.insert_headers(std::iter::once(&tip))?;
                    tip.hash()
                } else {
                    headers.last().unwrap().hash()
                };
                self.send_tip(tip);
                Ok(())
            }
        }

        impl<D: HeaderDownloader<Header = alloy_consensus::Header> + 'static> UnwindStageTestRunner
            for HeadersTestRunner<D>
        {
            /// After an unwind, no header-related table may hold entries above `unwind_to`.
            fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
                self.check_no_header_entry_above(input.unwind_to)
            }
        }

        impl HeadersTestRunner<ReverseHeadersDownloader<TestHeadersClient>> {
            /// Builds a runner that uses the real `ReverseHeadersDownloader` with a batch size
            /// of 500 and a `TestConsensus`.
            pub(crate) fn with_linear_downloader() -> Self {
                let client = TestHeadersClient::default();
                Self {
                    client: client.clone(),
                    channel: watch::channel(B256::ZERO),
                    downloader_factory: Box::new(move || {
                        ReverseHeadersDownloaderBuilder::default()
                            .stream_batch_size(500)
                            .build(client.clone(), Arc::new(TestConsensus::default()))
                    }),
                    db: TestStageDB::default(),
                }
            }
        }

        impl<D: HeaderDownloader> HeadersTestRunner<D> {
            /// Asserts that `HeaderNumbers` (by value), `CanonicalHeaders`, `Headers` and
            /// `HeaderTerminalDifficulties` hold no entry above `block`.
            pub(crate) fn check_no_header_entry_above(
                &self,
                block: BlockNumber,
            ) -> Result<(), TestRunnerError> {
                self.db
                    .ensure_no_entry_above_by_value::<tables::HeaderNumbers, _>(block, |val| val)?;
                self.db.ensure_no_entry_above::<tables::CanonicalHeaders, _>(block, |key| key)?;
                self.db.ensure_no_entry_above::<tables::Headers, _>(block, |key| key)?;
                self.db.ensure_no_entry_above::<tables::HeaderTerminalDifficulties, _>(
                    block,
                    |num| num,
                )?;
                Ok(())
            }

            /// Publishes `tip` as the new sync target.
            pub(crate) fn send_tip(&self, tip: B256) {
                self.channel.0.send(tip).expect("failed to send tip");
            }
        }
    }

    // Generic execute/unwind tests shared by all stages.
    stage_test_suite!(HeadersTestRunner, headers);

    /// Runs the stage with the linear downloader from 1000 to 1200, appends 10 more blocks,
    /// then unwinds back to the tip and checks the header tables.
    #[tokio::test]
    async fn execute_with_linear_downloader_unwind() {
        let mut runner = HeadersTestRunner::with_linear_downloader();
        let (checkpoint, previous_stage) = (1000, 1200);
        let input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(checkpoint)),
        };
        let headers = runner.seed_execution(input).expect("failed to seed execution");
        let rx = runner.execute(input);

        // Serve the headers to the client in reverse (descending) order.
        runner.client.extend(headers.iter().rev().map(|h| h.clone_header())).await;

        // Publish the tip so the downloader knows the target.
        let tip = headers.last().unwrap();
        runner.send_tip(tip.hash());

        let result = rx.await.unwrap();
        runner.db().factory.static_file_provider().commit().unwrap();
        assert_matches!(result, Ok(ExecOutput { checkpoint: StageCheckpoint {
            block_number,
            stage_checkpoint: Some(StageUnitCheckpoint::Headers(HeadersCheckpoint {
                block_range: CheckpointBlockRange {
                    from,
                    to
                },
                progress: EntitiesCheckpoint {
                    processed,
                    total,
                }
            }))
        }, done: true }) if block_number == tip.number &&
            from == checkpoint && to == previous_stage &&
            // `headers[0]` is the seeded local head, which is not collected again, hence `- 1`.
            processed == checkpoint + headers.len() as u64 - 1 && total == tip.number
        );
        assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed");
        assert!(runner.stage().hash_collector.is_empty());
        assert!(runner.stage().header_collector.is_empty());

        // Generate 10 random headers whose parent is the tip (numbers tip.number..tip.number+10).
        let sealed_headers =
            random_header_range(&mut generators::rng(), tip.number..tip.number + 10, tip.hash());

        // Wrap them into empty-bodied blocks.
        let sealed_blocks = sealed_headers
            .iter()
            .map(|header| {
                RecoveredBlock::new_sealed(
                    SealedBlock::from_sealed_parts(header.clone(), BlockBody::default()),
                    vec![],
                )
            })
            .collect();

        // Write those blocks through the provider so there is data above the tip to unwind.
        let provider = runner.db().factory.provider_rw().unwrap();
        provider
            .append_blocks_with_state(
                sealed_blocks,
                &ExecutionOutcome::default(),
                HashedPostStateSorted::default(),
                TrieUpdates::default(),
            )
            .unwrap();
        provider.commit().unwrap();

        // Unwind back to the tip.
        let unwind_input = UnwindInput {
            checkpoint: StageCheckpoint::new(tip.number + 10),
            unwind_to: tip.number,
            bad_block: None,
        };

        let unwind_output = runner.unwind(unwind_input).await.unwrap();
        assert_eq!(unwind_output.checkpoint.block_number, tip.number);

        // Nothing may remain above the unwind target.
        assert!(runner.validate_unwind(unwind_input).is_ok());
    }

    /// Runs the stage with the linear downloader from 1000 to 1200 and checks the output
    /// checkpoint, the written data and that the ETL collectors are empty afterwards.
    #[tokio::test]
    async fn execute_with_linear_downloader() {
        let mut runner = HeadersTestRunner::with_linear_downloader();
        let (checkpoint, previous_stage) = (1000, 1200);
        let input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(checkpoint)),
        };
        let headers = runner.seed_execution(input).expect("failed to seed execution");
        let rx = runner.execute(input);

        // Serve the headers to the client in reverse (descending) order.
        runner.client.extend(headers.iter().rev().map(|h| h.clone_header())).await;

        // Publish the tip so the downloader knows the target.
        let tip = headers.last().unwrap();
        runner.send_tip(tip.hash());

        let result = rx.await.unwrap();
        runner.db().factory.static_file_provider().commit().unwrap();
        assert_matches!(result, Ok(ExecOutput { checkpoint: StageCheckpoint {
            block_number,
            stage_checkpoint: Some(StageUnitCheckpoint::Headers(HeadersCheckpoint {
                block_range: CheckpointBlockRange {
                    from,
                    to
                },
                progress: EntitiesCheckpoint {
                    processed,
                    total,
                }
            }))
        }, done: true }) if block_number == tip.number &&
            from == checkpoint && to == previous_stage &&
            // `headers[0]` is the seeded local head, which is not collected again, hence `- 1`.
            processed == checkpoint + headers.len() as u64 - 1 && total == tip.number
        );
        assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed");
        assert!(runner.stage().hash_collector.is_empty());
        assert!(runner.stage().header_collector.is_empty());
    }
}