diff --git a/crates/trie/parallel/src/proof.rs b/crates/trie/parallel/src/proof.rs
index 1d9daff5c920..149a53a1e4b2 100644
--- a/crates/trie/parallel/src/proof.rs
+++ b/crates/trie/parallel/src/proof.rs
@@ -24,15 +24,15 @@ use reth_trie::{
 };
 use reth_trie_common::proof::ProofRetainer;
 use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
-use std::sync::Arc;
-use tracing::{debug, error};
+use std::{sync::Arc, time::Instant};
+use tracing::{debug, error, trace};
 
 #[cfg(feature = "metrics")]
 use crate::metrics::ParallelStateRootMetrics;
 
 /// TODO:
 #[derive(Debug)]
-pub struct ParallelProof<Factory> {
+pub struct ParallelProof<'env, Factory> {
     /// Consistent view of the database.
     view: ConsistentDbView<Factory>,
     /// The sorted collection of cached in-memory intermediate trie nodes that
@@ -46,18 +46,21 @@ pub struct ParallelProof<Factory> {
     pub prefix_sets: Arc<TriePrefixSetsMut>,
     /// Flag indicating whether to include branch node hash masks in the proof.
     collect_branch_node_hash_masks: bool,
+    /// Thread pool for local tasks
+    thread_pool: &'env rayon::ThreadPool,
     /// Parallel state root metrics.
     #[cfg(feature = "metrics")]
     metrics: ParallelStateRootMetrics,
 }
 
-impl<Factory> ParallelProof<Factory> {
+impl<'env, Factory> ParallelProof<'env, Factory> {
     /// Create new state proof generator.
     pub fn new(
         view: ConsistentDbView<Factory>,
         nodes_sorted: Arc<TrieUpdatesSorted>,
         state_sorted: Arc<HashedPostStateSorted>,
         prefix_sets: Arc<TriePrefixSetsMut>,
+        thread_pool: &'env rayon::ThreadPool,
     ) -> Self {
         Self {
             view,
@@ -65,6 +68,7 @@ impl<Factory> ParallelProof<Factory> {
             state_sorted,
             prefix_sets,
             collect_branch_node_hash_masks: false,
+            thread_pool,
             #[cfg(feature = "metrics")]
             metrics: ParallelStateRootMetrics::default(),
         }
@@ -77,7 +81,7 @@ impl<Factory> ParallelProof<Factory> {
     }
 }
 
-impl<Factory> ParallelProof<Factory>
+impl<Factory> ParallelProof<'_, Factory>
 where
     Factory: DatabaseProviderFactory<Provider: BlockReader>
         + StateCommitmentProvider
@@ -112,26 +116,50 @@ where
             prefix_sets.account_prefix_set.iter().map(|nibbles| B256::from_slice(&nibbles.pack())),
             prefix_sets.storage_prefix_sets.clone(),
         );
+        let storage_root_targets_len = storage_root_targets.len();
+
+        debug!(
+            target: "trie::parallel_state_root",
+            total_targets = storage_root_targets_len,
+            "Starting parallel proof generation"
+        );
 
         // Pre-calculate storage roots for accounts which were changed.
         tracker.set_precomputed_storage_roots(storage_root_targets.len() as u64);
-        debug!(target: "trie::parallel_state_root", len = storage_root_targets.len(), "pre-generating storage proofs");
+
         let mut storage_proofs =
             B256HashMap::with_capacity_and_hasher(storage_root_targets.len(), Default::default());
+
         for (hashed_address, prefix_set) in
             storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address)
         {
             let view = self.view.clone();
             let target_slots = targets.get(&hashed_address).cloned().unwrap_or_default();
-
             let trie_nodes_sorted = self.nodes_sorted.clone();
             let hashed_state_sorted = self.state_sorted.clone();
+            let collect_masks = self.collect_branch_node_hash_masks;
 
             let (tx, rx) = std::sync::mpsc::sync_channel(1);
 
-            rayon::spawn_fifo(move || {
+            self.thread_pool.spawn_fifo(move || {
+                debug!(
+                    target: "trie::parallel",
+                    ?hashed_address,
+                    "Starting proof calculation"
+                );
+
+                let task_start = Instant::now();
                 let result = (|| -> Result<_, ParallelStateRootError> {
+                    let provider_start = Instant::now();
                     let provider_ro = view.provider_ro()?;
+                    trace!(
+                        target: "trie::parallel",
+                        ?hashed_address,
+                        provider_time = ?provider_start.elapsed(),
+                        "Got provider"
+                    );
+
+                    let cursor_start = Instant::now();
                     let trie_cursor_factory = InMemoryTrieCursorFactory::new(
                         DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
                         &trie_nodes_sorted,
@@ -140,19 +168,42 @@ where
                         DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
                         &hashed_state_sorted,
                     );
+                    trace!(
+                        target: "trie::parallel",
+                        ?hashed_address,
+                        cursor_time = ?cursor_start.elapsed(),
+                        "Created cursors"
+                    );
 
-                    StorageProof::new_hashed(
+                    let proof_start = Instant::now();
+                    let proof_result = StorageProof::new_hashed(
                         trie_cursor_factory,
                         hashed_cursor_factory,
                         hashed_address,
                     )
                     .with_prefix_set_mut(PrefixSetMut::from(prefix_set.iter().cloned()))
-                    .with_branch_node_hash_masks(self.collect_branch_node_hash_masks)
+                    .with_branch_node_hash_masks(collect_masks)
                     .storage_multiproof(target_slots)
-                    .map_err(|e| ParallelStateRootError::Other(e.to_string()))
+                    .map_err(|e| ParallelStateRootError::Other(e.to_string()));
+
+                    trace!(
+                        target: "trie::parallel",
+                        ?hashed_address,
+                        proof_time = ?proof_start.elapsed(),
+                        "Completed proof calculation"
+                    );
+
+                    proof_result
                 })();
-                if let Err(err) = tx.send(result) {
-                    error!(target: "trie::parallel", ?hashed_address, err_content = ?err.0, "Failed to send proof result");
+
+                if let Err(e) = tx.send(result) {
+                    error!(
+                        target: "trie::parallel",
+                        ?hashed_address,
+                        error = ?e,
+                        task_time = ?task_start.elapsed(),
+                        "Failed to send proof result"
+                    );
                 }
             });
             storage_proofs.insert(hashed_address, rx);
@@ -338,12 +389,22 @@ mod tests {
         let trie_cursor_factory = DatabaseTrieCursorFactory::new(provider_rw.tx_ref());
         let hashed_cursor_factory = DatabaseHashedCursorFactory::new(provider_rw.tx_ref());
 
+        let num_threads =
+            std::thread::available_parallelism().map_or(1, |num| (num.get() / 2).max(1));
+
+        let state_root_task_pool = rayon::ThreadPoolBuilder::new()
+            .num_threads(num_threads)
+            .thread_name(|i| format!("proof-worker-{}", i))
+            .build()
+            .expect("Failed to create proof worker thread pool");
+
         assert_eq!(
             ParallelProof::new(
                 consistent_view,
                 Default::default(),
                 Default::default(),
-                Default::default()
+                Default::default(),
+                &state_root_task_pool
            )
            .multiproof(targets.clone())
            .unwrap(),
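Usage note (not part of the diff): a minimal sketch of how a caller wires up the new `thread_pool` parameter, mirroring the test above. `consistent_view`, `targets`, and the `proof_pool` name are assumptions; the pool-sizing heuristic (half the available cores, at least one) is simply the one used in the test.

```rust
// Sketch only: `consistent_view` and `targets` are assumed to exist in scope.
let num_threads =
    std::thread::available_parallelism().map_or(1, |num| (num.get() / 2).max(1));

// Dedicated rayon pool; storage proof tasks are spawned onto it via `spawn_fifo`.
let proof_pool = rayon::ThreadPoolBuilder::new()
    .num_threads(num_threads)
    .thread_name(|i| format!("proof-worker-{}", i))
    .build()
    .expect("Failed to create proof worker thread pool");

// The pool is borrowed for `'env`, so it must outlive the `ParallelProof` using it.
let multiproof = ParallelProof::new(
    consistent_view,
    Default::default(), // nodes_sorted
    Default::default(), // state_sorted
    Default::default(), // prefix_sets
    &proof_pool,
)
.multiproof(targets)
.unwrap();
```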