diff --git a/crates/engine/tree/Cargo.toml b/crates/engine/tree/Cargo.toml index 6a6a67a5e36bb..b5b8fc743645d 100644 --- a/crates/engine/tree/Cargo.toml +++ b/crates/engine/tree/Cargo.toml @@ -32,7 +32,6 @@ reth-prune.workspace = true reth-revm.workspace = true reth-stages-api.workspace = true reth-tasks.workspace = true -reth-trie-db.workspace = true reth-trie-parallel.workspace = true reth-trie-sparse.workspace = true reth-trie.workspace = true @@ -82,6 +81,7 @@ reth-stages = { workspace = true, features = ["test-utils"] } reth-static-file.workspace = true reth-testing-utils.workspace = true reth-tracing.workspace = true +reth-trie-db.workspace = true # alloy alloy-rlp.workspace = true @@ -120,6 +120,6 @@ test-utils = [ "reth-static-file", "reth-tracing", "reth-trie/test-utils", - "reth-prune-types?/test-utils", "reth-trie-db/test-utils", + "reth-prune-types?/test-utils", ] diff --git a/crates/engine/tree/src/tree/root.rs b/crates/engine/tree/src/tree/root.rs index cb64d95d8f924..545299a041cce 100644 --- a/crates/engine/tree/src/tree/root.rs +++ b/crates/engine/tree/src/tree/root.rs @@ -3,23 +3,18 @@ use alloy_primitives::{map::HashSet, Address}; use derive_more::derive::Deref; use rayon::iter::{ParallelBridge, ParallelIterator}; -use reth_errors::{ProviderError, ProviderResult}; +use reth_errors::ProviderError; use reth_evm::system_calls::OnStateHook; use reth_provider::{ - providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, - StateCommitmentProvider, + providers::ConsistentDbView, BlockReader, DatabaseProviderFactory, StateCommitmentProvider, }; use reth_trie::{ - hashed_cursor::HashedPostStateCursorFactory, prefix_set::TriePrefixSetsMut, - proof::Proof, - trie_cursor::InMemoryTrieCursorFactory, updates::{TrieUpdates, TrieUpdatesSorted}, HashedPostState, HashedPostStateSorted, HashedStorage, MultiProof, MultiProofTargets, Nibbles, TrieInput, }; -use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseProof, DatabaseTrieCursorFactory}; -use reth_trie_parallel::root::ParallelStateRootError; +use reth_trie_parallel::{proof::ParallelProof, root::ParallelStateRootError}; use reth_trie_sparse::{ blinded::{BlindedProvider, BlindedProviderFactory}, errors::{SparseStateTrieError, SparseStateTrieResult, SparseTrieError, SparseTrieErrorKind}, @@ -400,20 +395,31 @@ where state_root_message_sender: Sender>, ) { // Dispatch proof gathering for this state update - scope.spawn(move |_| match calculate_multiproof(config, proof_targets.clone()) { - Ok(proof) => { - let _ = state_root_message_sender.send(StateRootMessage::ProofCalculated( - Box::new(ProofCalculated { - state_update: hashed_state_update, - targets: proof_targets, - proof, - sequence_number: proof_sequence_number, - }), - )); - } - Err(error) => { - let _ = - state_root_message_sender.send(StateRootMessage::ProofCalculationError(error)); + scope.spawn(move |_| { + let result = ParallelProof::new( + config.consistent_view.clone(), + config.nodes_sorted.clone(), + config.state_sorted.clone(), + config.prefix_sets.clone(), + ) + .with_branch_node_hash_masks(true) + .multiproof(proof_targets.clone()); + + match result { + Ok(proof) => { + let _ = state_root_message_sender.send(StateRootMessage::ProofCalculated( + Box::new(ProofCalculated { + state_update: hashed_state_update, + targets: proof_targets, + proof, + sequence_number: proof_sequence_number, + }), + )); + } + Err(error) => { + let _ = state_root_message_sender + .send(StateRootMessage::ProofCalculationError(error.into())); + } } }); } @@ -714,31 +720,6 @@ fn get_proof_targets( targets } -/// Calculate multiproof for the targets. -#[inline] -fn calculate_multiproof( - config: StateRootConfig, - proof_targets: MultiProofTargets, -) -> ProviderResult -where - Factory: DatabaseProviderFactory + StateCommitmentProvider, -{ - let provider = config.consistent_view.provider_ro()?; - - Ok(Proof::from_tx(provider.tx_ref()) - .with_trie_cursor_factory(InMemoryTrieCursorFactory::new( - DatabaseTrieCursorFactory::new(provider.tx_ref()), - &config.nodes_sorted, - )) - .with_hashed_cursor_factory(HashedPostStateCursorFactory::new( - DatabaseHashedCursorFactory::new(provider.tx_ref()), - &config.state_sorted, - )) - .with_prefix_sets_mut(config.prefix_sets.as_ref().clone()) - .with_branch_node_hash_masks(true) - .multiproof(proof_targets)?) -} - /// Updates the sparse trie with the given proofs and state, and returns the updated trie and the /// time it took. fn update_sparse_trie< diff --git a/crates/trie/parallel/src/proof.rs b/crates/trie/parallel/src/proof.rs index 1d9daff5c9206..6893d2234be21 100644 --- a/crates/trie/parallel/src/proof.rs +++ b/crates/trie/parallel/src/proof.rs @@ -24,8 +24,8 @@ use reth_trie::{ }; use reth_trie_common::proof::ProofRetainer; use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory}; -use std::sync::Arc; -use tracing::{debug, error}; +use std::{sync::Arc, time::Instant}; +use tracing::{debug, error, trace}; #[cfg(feature = "metrics")] use crate::metrics::ParallelStateRootMetrics; @@ -112,26 +112,58 @@ where prefix_sets.account_prefix_set.iter().map(|nibbles| B256::from_slice(&nibbles.pack())), prefix_sets.storage_prefix_sets.clone(), ); + let storage_root_targets_len = storage_root_targets.len(); + + let num_threads = + std::thread::available_parallelism().map_or(1, |num| (num.get() / 2).max(1)); + + // create a local thread pool with a fixed number of workers + let pool = rayon::ThreadPoolBuilder::new() + .num_threads(num_threads) + .thread_name(|i| format!("proof-worker-{}", i)) + .build() + .map_err(|e| ParallelStateRootError::Other(e.to_string()))?; + + debug!( + target: "trie::parallel_state_root", + total_targets = storage_root_targets_len, + num_threads, + "Starting parallel proof generation" + ); - // Pre-calculate storage roots for accounts which were changed. - tracker.set_precomputed_storage_roots(storage_root_targets.len() as u64); - debug!(target: "trie::parallel_state_root", len = storage_root_targets.len(), "pre-generating storage proofs"); let mut storage_proofs = B256HashMap::with_capacity_and_hasher(storage_root_targets.len(), Default::default()); + for (hashed_address, prefix_set) in storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address) { let view = self.view.clone(); let target_slots = targets.get(&hashed_address).cloned().unwrap_or_default(); - let trie_nodes_sorted = self.nodes_sorted.clone(); let hashed_state_sorted = self.state_sorted.clone(); + let collect_masks = self.collect_branch_node_hash_masks; let (tx, rx) = std::sync::mpsc::sync_channel(1); - rayon::spawn_fifo(move || { + pool.spawn_fifo(move || { + debug!( + target: "trie::parallel", + ?hashed_address, + "Starting proof calculation" + ); + + let task_start = Instant::now(); let result = (|| -> Result<_, ParallelStateRootError> { + let provider_start = Instant::now(); let provider_ro = view.provider_ro()?; + trace!( + target: "trie::parallel", + ?hashed_address, + provider_time_ms = provider_start.elapsed().as_millis(), + "Got provider" + ); + + let cursor_start = Instant::now(); let trie_cursor_factory = InMemoryTrieCursorFactory::new( DatabaseTrieCursorFactory::new(provider_ro.tx_ref()), &trie_nodes_sorted, @@ -140,19 +172,43 @@ where DatabaseHashedCursorFactory::new(provider_ro.tx_ref()), &hashed_state_sorted, ); + trace!( + target: "trie::parallel", + ?hashed_address, + cursor_time_ms = cursor_start.elapsed().as_millis(), + "Created cursors" + ); - StorageProof::new_hashed( + let proof_start = Instant::now(); + let proof_result = StorageProof::new_hashed( trie_cursor_factory, hashed_cursor_factory, hashed_address, ) .with_prefix_set_mut(PrefixSetMut::from(prefix_set.iter().cloned())) - .with_branch_node_hash_masks(self.collect_branch_node_hash_masks) + .with_branch_node_hash_masks(collect_masks) .storage_multiproof(target_slots) - .map_err(|e| ParallelStateRootError::Other(e.to_string())) + .map_err(|e| ParallelStateRootError::Other(e.to_string())); + + trace!( + target: "trie::parallel", + ?hashed_address, + proof_time_ms = proof_start.elapsed().as_millis(), + "Completed proof calculation" + ); + + proof_result })(); - if let Err(err) = tx.send(result) { - error!(target: "trie::parallel", ?hashed_address, err_content = ?err.0, "Failed to send proof result"); + + let task_time = task_start.elapsed(); + if let Err(e) = tx.send(result) { + error!( + target: "trie::parallel", + ?hashed_address, + error = ?e, + task_time_ms = task_time.as_millis(), + "Failed to send proof result" + ); } }); storage_proofs.insert(hashed_address, rx);