use alloy_primitives::{Address, TxNumber};
use reth_config::config::SenderRecoveryConfig;
use reth_consensus::ConsensusError;
use reth_db::{static_file::TransactionMask, table::Value, tables, RawValue};
use reth_db_api::{
    cursor::{DbCursorRO, DbCursorRW},
    transaction::{DbTx, DbTxMut},
    DbTxUnwindExt,
};
use reth_primitives::{GotExpected, NodePrimitives, StaticFileSegment};
use reth_primitives_traits::SignedTransaction;
use reth_provider::{
    BlockReader, DBProvider, HeaderProvider, ProviderError, PruneCheckpointReader,
    StaticFileProviderFactory, StatsReader,
};
use reth_prune_types::PruneSegment;
use reth_stages_api::{
    BlockErrorKind, EntitiesCheckpoint, ExecInput, ExecOutput, Stage, StageCheckpoint, StageError,
    StageId, UnwindInput, UnwindOutput,
};
use std::{fmt::Debug, ops::Range, sync::mpsc};
use thiserror::Error;
use tracing::*;
24
/// Number of transactions recovered by a single rayon task. Every chunk of this size gets its own
/// result channel, so results can be consumed in chunk order.
const WORKER_CHUNK_SIZE: usize = 100;

/// Number of transactions handed to the recovery thread in one message. Each batch is split into
/// chunks of [`WORKER_CHUNK_SIZE`] transactions, i.e. one batch is 1_000 chunks (100_000
/// transactions).
const BATCH_SIZE: usize = 1_000 * WORKER_CHUNK_SIZE;
32
/// Sending half of a per-chunk result channel.
///
/// A recovery worker sends one message per transaction of its chunk: either the
/// `(transaction number, recovered sender)` pair, or the boxed error that made it stop.
type RecoveryResultSender = mpsc::Sender<Result<(u64, Address), Box<SenderRecoveryStageError>>>;
35
/// Stage that recovers the signer of every transaction and stores it in the
/// [`tables::TransactionSenders`] table, keyed by transaction number.
#[derive(Clone, Debug)]
pub struct SenderRecoveryStage {
    /// Transaction threshold passed to `next_block_range_with_transaction_threshold` when
    /// choosing how much to process in one `execute` call. It is also the threshold given to
    /// `unwind_block_range_with_threshold` when unwinding.
    pub commit_threshold: u64,
}
45
46impl SenderRecoveryStage {
47 pub const fn new(config: SenderRecoveryConfig) -> Self {
49 Self { commit_threshold: config.commit_threshold }
50 }
51}
52
53impl Default for SenderRecoveryStage {
54 fn default() -> Self {
55 Self { commit_threshold: 5_000_000 }
56 }
57}
58
59impl<Provider> Stage<Provider> for SenderRecoveryStage
60where
61 Provider: DBProvider<Tx: DbTxMut>
62 + BlockReader
63 + StaticFileProviderFactory<Primitives: NodePrimitives<SignedTx: Value + SignedTransaction>>
64 + StatsReader
65 + PruneCheckpointReader,
66{
67 fn id(&self) -> StageId {
69 StageId::SenderRecovery
70 }
71
72 fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
77 if input.target_reached() {
78 return Ok(ExecOutput::done(input.checkpoint()))
79 }
80
81 let (tx_range, block_range, is_final_range) =
82 input.next_block_range_with_transaction_threshold(provider, self.commit_threshold)?;
83 let end_block = *block_range.end();
84
85 if tx_range.is_empty() {
87 info!(target: "sync::stages::sender_recovery", ?tx_range, "Target transaction already reached");
88 return Ok(ExecOutput {
89 checkpoint: StageCheckpoint::new(end_block)
90 .with_entities_stage_checkpoint(stage_checkpoint(provider)?),
91 done: is_final_range,
92 })
93 }
94
95 let mut senders_cursor = provider.tx_ref().cursor_write::<tables::TransactionSenders>()?;
97
98 info!(target: "sync::stages::sender_recovery", ?tx_range, "Recovering senders");
99
100 let batch = tx_range
102 .clone()
103 .step_by(BATCH_SIZE)
104 .map(|start| start..std::cmp::min(start + BATCH_SIZE as u64, tx_range.end))
105 .collect::<Vec<Range<u64>>>();
106
107 let tx_batch_sender = setup_range_recovery(provider);
108
109 for range in batch {
110 recover_range(range, provider, tx_batch_sender.clone(), &mut senders_cursor)?;
111 }
112
113 Ok(ExecOutput {
114 checkpoint: StageCheckpoint::new(end_block)
115 .with_entities_stage_checkpoint(stage_checkpoint(provider)?),
116 done: is_final_range,
117 })
118 }
119
120 fn unwind(
122 &mut self,
123 provider: &Provider,
124 input: UnwindInput,
125 ) -> Result<UnwindOutput, StageError> {
126 let (_, unwind_to, _) = input.unwind_block_range_with_threshold(self.commit_threshold);
127
128 let latest_tx_id = provider
130 .block_body_indices(unwind_to)?
131 .ok_or(ProviderError::BlockBodyIndicesNotFound(unwind_to))?
132 .last_tx_num();
133 provider.tx_ref().unwind_table_by_num::<tables::TransactionSenders>(latest_tx_id)?;
134
135 Ok(UnwindOutput {
136 checkpoint: StageCheckpoint::new(unwind_to)
137 .with_entities_stage_checkpoint(stage_checkpoint(provider)?),
138 })
139 }
140}
141
142fn recover_range<Provider, CURSOR>(
143 tx_range: Range<u64>,
144 provider: &Provider,
145 tx_batch_sender: mpsc::Sender<Vec<(Range<u64>, RecoveryResultSender)>>,
146 senders_cursor: &mut CURSOR,
147) -> Result<(), StageError>
148where
149 Provider: DBProvider + HeaderProvider + StaticFileProviderFactory,
150 CURSOR: DbCursorRW<tables::TransactionSenders>,
151{
152 debug!(target: "sync::stages::sender_recovery", ?tx_range, "Sending batch for processing");
153
154 let (chunks, receivers): (Vec<_>, Vec<_>) = tx_range
156 .clone()
157 .step_by(WORKER_CHUNK_SIZE)
158 .map(|start| {
159 let range = start..std::cmp::min(start + WORKER_CHUNK_SIZE as u64, tx_range.end);
160 let (tx, rx) = mpsc::channel();
161 ((range, tx), rx)
163 })
164 .unzip();
165
166 if let Some(err) = tx_batch_sender.send(chunks).err() {
167 return Err(StageError::Fatal(err.into()));
168 }
169
170 debug!(target: "sync::stages::sender_recovery", ?tx_range, "Appending recovered senders to the database");
171
172 let mut processed_transactions = 0;
173 for channel in receivers {
174 while let Ok(recovered) = channel.recv() {
175 let (tx_id, sender) = match recovered {
176 Ok(result) => result,
177 Err(error) => {
178 return match *error {
179 SenderRecoveryStageError::FailedRecovery(err) => {
180 let block_number = provider
182 .tx_ref()
183 .get::<tables::TransactionBlocks>(err.tx)?
184 .ok_or(ProviderError::BlockNumberForTransactionIndexNotFound)?;
185
186 let sealed_header =
189 provider.sealed_header(block_number)?.ok_or_else(|| {
190 ProviderError::HeaderNotFound(block_number.into())
191 })?;
192
193 Err(StageError::Block {
194 block: Box::new(sealed_header.block_with_parent()),
195 error: BlockErrorKind::Validation(
196 ConsensusError::TransactionSignerRecoveryError,
197 ),
198 })
199 }
200 SenderRecoveryStageError::StageError(err) => Err(err),
201 SenderRecoveryStageError::RecoveredSendersMismatch(expectation) => {
202 Err(StageError::Fatal(
203 SenderRecoveryStageError::RecoveredSendersMismatch(expectation)
204 .into(),
205 ))
206 }
207 }
208 }
209 };
210 senders_cursor.append(tx_id, sender)?;
211 processed_transactions += 1;
212 }
213 }
214 debug!(target: "sync::stages::sender_recovery", ?tx_range, "Finished recovering senders batch");
215
216 let expected = tx_range.end - tx_range.start;
218 if processed_transactions != expected {
219 return Err(StageError::Fatal(
220 SenderRecoveryStageError::RecoveredSendersMismatch(GotExpected {
221 got: processed_transactions,
222 expected,
223 })
224 .into(),
225 ));
226 }
227 Ok(())
228}
229
/// Spawns a dedicated OS thread that recovers transaction senders for the chunks it receives. It
/// returns the sending half of that thread's work queue.
///
/// Each message is a batch: a list of `(transaction range, result channel)` pairs. For every
/// chunk, the thread:
/// 1. reads the raw, still-encoded transactions from the `Transactions` static file segment;
/// 2. hands decoding and signer recovery to the global rayon pool via `rayon::spawn`.
///
/// Every per-transaction result is sent on the chunk's channel. When the rayon task ends it drops
/// that channel's sender, which tells the receiver the chunk is finished.
///
/// The thread exits once every clone of the returned sender has been dropped.
fn setup_range_recovery<Provider>(
    provider: &Provider,
) -> mpsc::Sender<Vec<(Range<u64>, RecoveryResultSender)>>
where
    Provider: DBProvider
        + HeaderProvider
        + StaticFileProviderFactory<Primitives: NodePrimitives<SignedTx: Value + SignedTransaction>>,
{
    let (tx_sender, tx_receiver) = mpsc::channel::<Vec<(Range<u64>, RecoveryResultSender)>>();
    let static_file_provider = provider.static_file_provider();

    // NOTE(review): a plain OS thread is used rather than an async runtime's blocking pool,
    // presumably so recovery can keep running while the runtime shuts down — confirm.
    std::thread::spawn(move || {
        // `recv` fails once every sender is dropped, which ends this thread.
        while let Ok(chunks) = tx_receiver.recv() {
            for (chunk_range, recovered_senders_tx) in chunks {
                // Fetch raw values only; decoding happens on the rayon worker below.
                let chunk = match static_file_provider.fetch_range_with_predicate(
                    StaticFileSegment::Transactions,
                    chunk_range.clone(),
                    |cursor, number| {
                        Ok(cursor
                            .get_one::<TransactionMask<
                                RawValue<<Provider::Primitives as NodePrimitives>::SignedTx>,
                            >>(number.into())?
                            .map(|tx| (number, tx)))
                    },
                    |_| true,
                ) {
                    Ok(chunk) => chunk,
                    Err(err) => {
                        // Report the failure on this chunk's channel and stop handling the rest
                        // of the batch. The remaining chunks' senders are dropped, so their
                        // receivers simply see a closed channel.
                        let _ = recovered_senders_tx
                            .send(Err(Box::new(SenderRecoveryStageError::StageError(err.into()))));
                        break
                    }
                };

                // Decode and recover on the global rayon pool. The channel sender moves into the
                // task and is dropped when the task finishes.
                rayon::spawn(move || {
                    let mut rlp_buf = Vec::with_capacity(128);
                    for (number, tx) in chunk {
                        let res = tx
                            .value()
                            .map_err(|err| {
                                Box::new(SenderRecoveryStageError::StageError(err.into()))
                            })
                            .and_then(|tx| recover_sender((number, tx), &mut rlp_buf));

                        let is_err = res.is_err();

                        // Send failures are ignored: the receiver may already have given up.
                        let _ = recovered_senders_tx.send(res);

                        // Stop at the first error; the consumer returns on the first error
                        // anyway, so nothing after it would be appended.
                        if is_err {
                            break
                        }
                    }
                });
            }
        }
    });
    tx_sender
}
304
305#[inline]
306fn recover_sender<T: SignedTransaction>(
307 (tx_id, tx): (TxNumber, T),
308 rlp_buf: &mut Vec<u8>,
309) -> Result<(u64, Address), Box<SenderRecoveryStageError>> {
310 rlp_buf.clear();
311 let sender = tx
317 .recover_signer_unchecked_with_buf(rlp_buf)
318 .ok_or(SenderRecoveryStageError::FailedRecovery(FailedSenderRecoveryError { tx: tx_id }))?;
319
320 Ok((tx_id, sender))
321}
322
323fn stage_checkpoint<Provider>(provider: &Provider) -> Result<EntitiesCheckpoint, StageError>
324where
325 Provider: StatsReader + StaticFileProviderFactory + PruneCheckpointReader,
326{
327 let pruned_entries = provider
328 .get_prune_checkpoint(PruneSegment::SenderRecovery)?
329 .and_then(|checkpoint| checkpoint.tx_number)
330 .unwrap_or_default();
331 Ok(EntitiesCheckpoint {
332 processed: provider.count_entries::<tables::TransactionSenders>()? as u64 + pruned_entries,
336 total: provider.static_file_provider().count_entries::<tables::Transactions>()? as u64,
340 })
341}
342
/// Errors produced while recovering transaction senders.
///
/// Recovery workers send these, boxed, over [`RecoveryResultSender`] channels. They are turned
/// into [`StageError`]s where the results are consumed.
#[derive(Error, Debug)]
#[error(transparent)]
enum SenderRecoveryStageError {
    /// The signer of a transaction could not be recovered.
    #[error(transparent)]
    FailedRecovery(#[from] FailedSenderRecoveryError),

    /// The number of recovered senders did not match the number of transactions in the range.
    #[error("mismatched sender count during recovery: {_0}")]
    RecoveredSendersMismatch(GotExpected<u64>),

    /// Any other stage error raised by a worker, e.g. a static-file read or decoding failure.
    #[error(transparent)]
    StageError(#[from] StageError),
}
358
/// Error returned when the signer of a transaction cannot be recovered.
#[derive(Error, Debug)]
#[error("sender recovery failed for transaction {tx}")]
struct FailedSenderRecoveryError {
    /// Number of the transaction whose sender could not be recovered.
    tx: TxNumber,
}
365
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{
        stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, StorageKind,
        TestRunnerError, TestStageDB, UnwindStageTestRunner,
    };
    use alloy_primitives::{BlockNumber, B256};
    use assert_matches::assert_matches;
    use reth_db_api::cursor::DbCursorRO;
    use reth_primitives::{SealedBlock, TransactionSigned};
    use reth_primitives_traits::SignedTransaction;
    use reth_provider::{
        providers::StaticFileWriter, BlockBodyIndicesProvider, DatabaseProviderFactory,
        PruneCheckpointWriter, StaticFileProviderFactory, TransactionsProvider,
    };
    use reth_prune_types::{PruneCheckpoint, PruneMode};
    use reth_stages_api::StageUnitCheckpoint;
    use reth_testing_utils::generators::{
        self, random_block, random_block_range, BlockParams, BlockRangeParams,
    };

    stage_test_suite_ext!(SenderRecoveryTestRunner, sender_recovery);

    /// Executes over a range in which exactly one block holds a single transaction. Expects the
    /// stage to reach the target with one processed entity.
    #[tokio::test]
    async fn execute_single_transaction() {
        let (previous_stage, stage_progress) = (500, 100);
        let mut rng = generators::rng();

        // Set up the runner
        let runner = SenderRecoveryTestRunner::default();
        let input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(stage_progress)),
        };

        // Insert blocks where only `stage_progress + 10` carries a transaction
        let non_empty_block_number = stage_progress + 10;
        let blocks = (stage_progress..=input.target())
            .map(|number| {
                random_block(
                    &mut rng,
                    number,
                    BlockParams {
                        tx_count: Some((number == non_empty_block_number) as u8),
                        ..Default::default()
                    },
                )
            })
            .collect::<Vec<_>>();
        runner
            .db
            .insert_blocks(blocks.iter(), StorageKind::Static)
            .expect("failed to insert blocks");

        let rx = runner.execute(input);

        // Assert the successful result
        let result = rx.await.unwrap();
        assert_matches!(
            result,
            Ok(ExecOutput { checkpoint: StageCheckpoint {
                block_number,
                stage_checkpoint: Some(StageUnitCheckpoint::Entities(EntitiesCheckpoint {
                    processed: 1,
                    total: 1
                }))
            }, done: true }) if block_number == previous_stage
        );

        // Validate the stage execution
        assert!(runner.validate_execution(input, result.ok()).is_ok(), "execution validation");
    }

    /// Executes with a small commit threshold so the first run stops early. A second run with an
    /// unlimited threshold then finishes the range.
    #[tokio::test]
    async fn execute_intermediate_commit() {
        let mut rng = generators::rng();

        let threshold = 10;
        let mut runner = SenderRecoveryTestRunner::default();
        runner.set_threshold(threshold);
        // The input range exceeds the threshold
        let (stage_progress, previous_stage) = (1000, 1100);

        // Manually seed once with the full input range; the tx count range is high enough to hit
        // the threshold
        let seed = random_block_range(
            &mut rng,
            stage_progress + 1..=previous_stage,
            BlockRangeParams { parent: Some(B256::ZERO), tx_count: 0..4, ..Default::default() },
        );
        runner
            .db
            .insert_blocks(seed.iter(), StorageKind::Static)
            .expect("failed to seed execution");

        let total_transactions = runner
            .db
            .factory
            .static_file_provider()
            .count_entries::<tables::Transactions>()
            .unwrap() as u64;

        let first_input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(stage_progress)),
        };

        // Execute the first time
        let result = runner.execute(first_input).await.unwrap();
        let mut tx_count = 0;
        let expected_progress = seed
            .iter()
            .find(|x| {
                tx_count += x.body.transactions.len();
                tx_count as u64 > threshold
            })
            .map(|x| x.number)
            .unwrap_or(previous_stage);
        assert_matches!(result, Ok(_));
        assert_eq!(
            result.unwrap(),
            ExecOutput {
                checkpoint: StageCheckpoint::new(expected_progress).with_entities_stage_checkpoint(
                    EntitiesCheckpoint {
                        processed: runner.db.table::<tables::TransactionSenders>().unwrap().len()
                            as u64,
                        total: total_transactions
                    }
                ),
                done: false
            }
        );

        // Execute the second time, completing the range
        runner.set_threshold(u64::MAX);
        let second_input = ExecInput {
            target: Some(previous_stage),
            checkpoint: Some(StageCheckpoint::new(expected_progress)),
        };
        let result = runner.execute(second_input).await.unwrap();
        assert_matches!(result, Ok(_));
        assert_eq!(
            result.as_ref().unwrap(),
            &ExecOutput {
                checkpoint: StageCheckpoint::new(previous_stage).with_entities_stage_checkpoint(
                    EntitiesCheckpoint { processed: total_transactions, total: total_transactions }
                ),
                done: true
            }
        );

        assert!(runner.validate_execution(first_input, result.ok()).is_ok(), "validation failed");
    }

    /// Checks that `stage_checkpoint` counts senders removed by pruning as processed.
    #[test]
    fn stage_checkpoint_pruned() {
        let db = TestStageDB::default();
        let mut rng = generators::rng();

        let blocks = random_block_range(
            &mut rng,
            0..=100,
            BlockRangeParams { parent: Some(B256::ZERO), tx_count: 0..10, ..Default::default() },
        );
        db.insert_blocks(blocks.iter(), StorageKind::Static).expect("insert blocks");

        let max_pruned_block = 30;
        let max_processed_block = 70;

        let mut tx_senders = Vec::new();
        let mut tx_number = 0;
        for block in &blocks[..=max_processed_block] {
            for transaction in &block.body.transactions {
                if block.number > max_pruned_block {
                    tx_senders
                        .push((tx_number, transaction.recover_signer().expect("recover signer")));
                }
                tx_number += 1;
            }
        }
        db.insert_transaction_senders(tx_senders).expect("insert tx hash numbers");

        let provider = db.factory.provider_rw().unwrap();
        provider
            .save_prune_checkpoint(
                PruneSegment::SenderRecovery,
                PruneCheckpoint {
                    block_number: Some(max_pruned_block),
                    tx_number: Some(
                        blocks[..=max_pruned_block as usize]
                            .iter()
                            .map(|block| block.body.transactions.len() as u64)
                            .sum(),
                    ),
                    prune_mode: PruneMode::Full,
                },
            )
            .expect("save stage checkpoint");
        provider.commit().expect("commit");

        let provider = db.factory.database_provider_rw().unwrap();
        assert_eq!(
            stage_checkpoint(&provider).expect("stage checkpoint"),
            EntitiesCheckpoint {
                processed: blocks[..=max_processed_block]
                    .iter()
                    .map(|block| block.body.transactions.len() as u64)
                    .sum(),
                total: blocks.iter().map(|block| block.body.transactions.len() as u64).sum()
            }
        );
    }

    struct SenderRecoveryTestRunner {
        db: TestStageDB,
        threshold: u64,
    }

    impl Default for SenderRecoveryTestRunner {
        fn default() -> Self {
            Self { threshold: 1000, db: TestStageDB::default() }
        }
    }

    impl SenderRecoveryTestRunner {
        fn set_threshold(&mut self, threshold: u64) {
            self.threshold = threshold;
        }

        /// Asserts that no sender is stored above the last transaction of `block`. If `block`
        /// has no body indices, asserts that the senders table is empty.
        fn ensure_no_senders_by_block(&self, block: BlockNumber) -> Result<(), TestRunnerError> {
            // Read-only access is enough here; no write transaction is needed.
            let body_indices = self.db.factory.provider()?.block_body_indices(block)?;
            match body_indices {
                Some(body) => self.db.ensure_no_entry_above::<tables::TransactionSenders, _>(
                    body.last_tx_num(),
                    |key| key,
                )?,
                None => {
                    assert!(self.db.table_is_empty::<tables::TransactionSenders>()?);
                }
            };

            Ok(())
        }
    }

    impl StageTestRunner for SenderRecoveryTestRunner {
        type S = SenderRecoveryStage;

        fn db(&self) -> &TestStageDB {
            &self.db
        }

        fn stage(&self) -> Self::S {
            SenderRecoveryStage { commit_threshold: self.threshold }
        }
    }

    impl ExecuteStageTestRunner for SenderRecoveryTestRunner {
        type Seed = Vec<SealedBlock>;

        fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
            let mut rng = generators::rng();
            let stage_progress = input.checkpoint().block_number;
            let end = input.target();

            let blocks = random_block_range(
                &mut rng,
                stage_progress..=end,
                BlockRangeParams { parent: Some(B256::ZERO), tx_count: 0..2, ..Default::default() },
            );
            self.db.insert_blocks(blocks.iter(), StorageKind::Static)?;
            Ok(blocks)
        }

        /// Checks every block reported as processed, and only those.
        ///
        /// For each transaction in `input.next_block()..=output.checkpoint.block_number`, the
        /// stored sender must equal the signer recovered from the transaction itself.
        fn validate_execution(
            &self,
            input: ExecInput,
            output: Option<ExecOutput>,
        ) -> Result<(), TestRunnerError> {
            match output {
                Some(output) => {
                    let provider = self.db.factory.provider()?;
                    let start_block = input.next_block();
                    let end_block = output.checkpoint.block_number;

                    if start_block > end_block {
                        return Ok(())
                    }

                    // `walk_range` includes `start_block` and stops at `end_block`. By contrast,
                    // `seek_exact` followed by `next()` would skip the first block and run past
                    // the checkpoint.
                    let mut body_cursor =
                        provider.tx_ref().cursor_read::<tables::BlockBodyIndices>()?;

                    for entry in body_cursor.walk_range(start_block..=end_block)? {
                        let (_, body) = entry?;
                        for tx_id in body.tx_num_range() {
                            let transaction: TransactionSigned = provider
                                .transaction_by_id_unhashed(tx_id)?
                                .map(|tx| {
                                    TransactionSigned::new_unhashed(tx.transaction, tx.signature)
                                })
                                .expect("no transaction entry");
                            let signer =
                                transaction.recover_signer().expect("failed to recover signer");
                            assert_eq!(Some(signer), provider.transaction_sender(tx_id)?)
                        }
                    }
                }
                None => self.ensure_no_senders_by_block(input.checkpoint().block_number)?,
            };

            Ok(())
        }
    }

    impl UnwindStageTestRunner for SenderRecoveryTestRunner {
        fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
            self.ensure_no_senders_by_block(input.unwind_to)
        }
    }
}