Skip to content

Commit

Permalink
perf(trie): use local ThreadPool in Parallel::multiproof
Browse files Browse the repository at this point in the history
  • Loading branch information
fgimenez committed Dec 18, 2024
1 parent ef033ab commit 2691434
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 61 deletions.
4 changes: 2 additions & 2 deletions crates/engine/tree/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ reth-prune.workspace = true
reth-revm.workspace = true
reth-stages-api.workspace = true
reth-tasks.workspace = true
reth-trie-db.workspace = true
reth-trie-parallel.workspace = true
reth-trie-sparse.workspace = true
reth-trie.workspace = true
Expand Down Expand Up @@ -82,6 +81,7 @@ reth-stages = { workspace = true, features = ["test-utils"] }
reth-static-file.workspace = true
reth-testing-utils.workspace = true
reth-tracing.workspace = true
reth-trie-db.workspace = true

# alloy
alloy-rlp.workspace = true
Expand Down Expand Up @@ -120,6 +120,6 @@ test-utils = [
"reth-static-file",
"reth-tracing",
"reth-trie/test-utils",
"reth-prune-types?/test-utils",
"reth-trie-db/test-utils",
"reth-prune-types?/test-utils",
]
75 changes: 28 additions & 47 deletions crates/engine/tree/src/tree/root.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,18 @@
use alloy_primitives::{map::HashSet, Address};
use derive_more::derive::Deref;
use rayon::iter::{ParallelBridge, ParallelIterator};
use reth_errors::{ProviderError, ProviderResult};
use reth_errors::ProviderError;
use reth_evm::system_calls::OnStateHook;
use reth_provider::{
providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory,
StateCommitmentProvider,
providers::ConsistentDbView, BlockReader, DatabaseProviderFactory, StateCommitmentProvider,
};
use reth_trie::{
hashed_cursor::HashedPostStateCursorFactory,
prefix_set::TriePrefixSetsMut,
proof::Proof,
trie_cursor::InMemoryTrieCursorFactory,
updates::{TrieUpdates, TrieUpdatesSorted},
HashedPostState, HashedPostStateSorted, HashedStorage, MultiProof, MultiProofTargets, Nibbles,
TrieInput,
};
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseProof, DatabaseTrieCursorFactory};
use reth_trie_parallel::root::ParallelStateRootError;
use reth_trie_parallel::{proof::ParallelProof, root::ParallelStateRootError};
use reth_trie_sparse::{
blinded::{BlindedProvider, BlindedProviderFactory},
errors::{SparseStateTrieError, SparseStateTrieResult, SparseTrieError, SparseTrieErrorKind},
Expand Down Expand Up @@ -400,20 +395,31 @@ where
state_root_message_sender: Sender<StateRootMessage<BPF>>,
) {
// Dispatch proof gathering for this state update
scope.spawn(move |_| match calculate_multiproof(config, proof_targets.clone()) {
Ok(proof) => {
let _ = state_root_message_sender.send(StateRootMessage::ProofCalculated(
Box::new(ProofCalculated {
state_update: hashed_state_update,
targets: proof_targets,
proof,
sequence_number: proof_sequence_number,
}),
));
}
Err(error) => {
let _ =
state_root_message_sender.send(StateRootMessage::ProofCalculationError(error));
scope.spawn(move |_| {
let result = ParallelProof::new(
config.consistent_view.clone(),
config.nodes_sorted.clone(),
config.state_sorted.clone(),
config.prefix_sets.clone(),
)
.with_branch_node_hash_masks(true)
.multiproof(proof_targets.clone());

match result {
Ok(proof) => {
let _ = state_root_message_sender.send(StateRootMessage::ProofCalculated(
Box::new(ProofCalculated {
state_update: hashed_state_update,
targets: proof_targets,
proof,
sequence_number: proof_sequence_number,
}),
));
}
Err(error) => {
let _ = state_root_message_sender
.send(StateRootMessage::ProofCalculationError(error.into()));
}
}
});
}
Expand Down Expand Up @@ -714,31 +720,6 @@ fn get_proof_targets(
targets
}

/// Calculate multiproof for the targets.
#[inline]
fn calculate_multiproof<Factory>(
config: StateRootConfig<Factory>,
proof_targets: MultiProofTargets,
) -> ProviderResult<MultiProof>
where
Factory: DatabaseProviderFactory<Provider: BlockReader> + StateCommitmentProvider,
{
let provider = config.consistent_view.provider_ro()?;

Ok(Proof::from_tx(provider.tx_ref())
.with_trie_cursor_factory(InMemoryTrieCursorFactory::new(
DatabaseTrieCursorFactory::new(provider.tx_ref()),
&config.nodes_sorted,
))
.with_hashed_cursor_factory(HashedPostStateCursorFactory::new(
DatabaseHashedCursorFactory::new(provider.tx_ref()),
&config.state_sorted,
))
.with_prefix_sets_mut(config.prefix_sets.as_ref().clone())
.with_branch_node_hash_masks(true)
.multiproof(proof_targets)?)
}

/// Updates the sparse trie with the given proofs and state, and returns the updated trie and the
/// time it took.
fn update_sparse_trie<
Expand Down
80 changes: 68 additions & 12 deletions crates/trie/parallel/src/proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ use reth_trie::{
};
use reth_trie_common::proof::ProofRetainer;
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
use std::sync::Arc;
use tracing::{debug, error};
use std::{sync::Arc, time::Instant};
use tracing::{debug, error, trace};

#[cfg(feature = "metrics")]
use crate::metrics::ParallelStateRootMetrics;
Expand Down Expand Up @@ -112,26 +112,58 @@ where
prefix_sets.account_prefix_set.iter().map(|nibbles| B256::from_slice(&nibbles.pack())),
prefix_sets.storage_prefix_sets.clone(),
);
let storage_root_targets_len = storage_root_targets.len();

let num_threads =
std::thread::available_parallelism().map_or(1, |num| (num.get() / 2).max(1));

// create a local thread pool with a fixed number of workers
let pool = rayon::ThreadPoolBuilder::new()
.num_threads(num_threads)
.thread_name(|i| format!("proof-worker-{}", i))
.build()
.map_err(|e| ParallelStateRootError::Other(e.to_string()))?;

debug!(
target: "trie::parallel_state_root",
total_targets = storage_root_targets_len,
num_threads,
"Starting parallel proof generation"
);

// Pre-calculate storage roots for accounts which were changed.
tracker.set_precomputed_storage_roots(storage_root_targets.len() as u64);
debug!(target: "trie::parallel_state_root", len = storage_root_targets.len(), "pre-generating storage proofs");
let mut storage_proofs =
B256HashMap::with_capacity_and_hasher(storage_root_targets.len(), Default::default());

for (hashed_address, prefix_set) in
storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address)
{
let view = self.view.clone();
let target_slots = targets.get(&hashed_address).cloned().unwrap_or_default();

let trie_nodes_sorted = self.nodes_sorted.clone();
let hashed_state_sorted = self.state_sorted.clone();
let collect_masks = self.collect_branch_node_hash_masks;

let (tx, rx) = std::sync::mpsc::sync_channel(1);

rayon::spawn_fifo(move || {
pool.spawn_fifo(move || {
debug!(
target: "trie::parallel",
?hashed_address,
"Starting proof calculation"
);

let task_start = Instant::now();
let result = (|| -> Result<_, ParallelStateRootError> {
let provider_start = Instant::now();
let provider_ro = view.provider_ro()?;
trace!(
target: "trie::parallel",
?hashed_address,
provider_time_ms = provider_start.elapsed().as_millis(),
"Got provider"
);

let cursor_start = Instant::now();
let trie_cursor_factory = InMemoryTrieCursorFactory::new(
DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
&trie_nodes_sorted,
Expand All @@ -140,19 +172,43 @@ where
DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
&hashed_state_sorted,
);
trace!(
target: "trie::parallel",
?hashed_address,
cursor_time_ms = cursor_start.elapsed().as_millis(),
"Created cursors"
);

StorageProof::new_hashed(
let proof_start = Instant::now();
let proof_result = StorageProof::new_hashed(
trie_cursor_factory,
hashed_cursor_factory,
hashed_address,
)
.with_prefix_set_mut(PrefixSetMut::from(prefix_set.iter().cloned()))
.with_branch_node_hash_masks(self.collect_branch_node_hash_masks)
.with_branch_node_hash_masks(collect_masks)
.storage_multiproof(target_slots)
.map_err(|e| ParallelStateRootError::Other(e.to_string()))
.map_err(|e| ParallelStateRootError::Other(e.to_string()));

trace!(
target: "trie::parallel",
?hashed_address,
proof_time_ms = proof_start.elapsed().as_millis(),
"Completed proof calculation"
);

proof_result
})();
if let Err(err) = tx.send(result) {
error!(target: "trie::parallel", ?hashed_address, err_content = ?err.0, "Failed to send proof result");

let task_time = task_start.elapsed();
if let Err(e) = tx.send(result) {
error!(
target: "trie::parallel",
?hashed_address,
error = ?e,
task_time_ms = task_time.as_millis(),
"Failed to send proof result"
);
}
});
storage_proofs.insert(hashed_address, rx);
Expand Down

0 comments on commit 2691434

Please sign in to comment.