Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

perf(trie): use local ThreadPool in Parallel::multiproof #13416

Merged
merged 3 commits on
Dec 19, 2024 (branch names lost in page extraction)
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 75 additions & 14 deletions crates/trie/parallel/src/proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ use reth_trie::{
};
use reth_trie_common::proof::ProofRetainer;
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
use std::sync::Arc;
use tracing::{debug, error};
use std::{sync::Arc, time::Instant};
use tracing::{debug, error, trace};

#[cfg(feature = "metrics")]
use crate::metrics::ParallelStateRootMetrics;

/// Generator of state multiproofs that computes storage proofs for changed accounts in parallel on a caller-provided rayon thread pool.
#[derive(Debug)]
pub struct ParallelProof<Factory> {
pub struct ParallelProof<'env, Factory> {
/// Consistent view of the database.
view: ConsistentDbView<Factory>,
/// The sorted collection of cached in-memory intermediate trie nodes that
Expand All @@ -46,25 +46,29 @@ pub struct ParallelProof<Factory> {
pub prefix_sets: Arc<TriePrefixSetsMut>,
/// Flag indicating whether to include branch node hash masks in the proof.
collect_branch_node_hash_masks: bool,
/// Thread pool for local tasks
thread_pool: &'env rayon::ThreadPool,
/// Parallel state root metrics.
#[cfg(feature = "metrics")]
metrics: ParallelStateRootMetrics,
}

impl<Factory> ParallelProof<Factory> {
impl<'env, Factory> ParallelProof<'env, Factory> {
/// Create new state proof generator.
pub fn new(
view: ConsistentDbView<Factory>,
nodes_sorted: Arc<TrieUpdatesSorted>,
state_sorted: Arc<HashedPostStateSorted>,
prefix_sets: Arc<TriePrefixSetsMut>,
thread_pool: &'env rayon::ThreadPool,
) -> Self {
Self {
view,
nodes_sorted,
state_sorted,
prefix_sets,
collect_branch_node_hash_masks: false,
thread_pool,
#[cfg(feature = "metrics")]
metrics: ParallelStateRootMetrics::default(),
}
Expand All @@ -77,7 +81,7 @@ impl<Factory> ParallelProof<Factory> {
}
}

impl<Factory> ParallelProof<Factory>
impl<Factory> ParallelProof<'_, Factory>
where
Factory: DatabaseProviderFactory<Provider: BlockReader>
+ StateCommitmentProvider
Expand Down Expand Up @@ -112,26 +116,50 @@ where
prefix_sets.account_prefix_set.iter().map(|nibbles| B256::from_slice(&nibbles.pack())),
prefix_sets.storage_prefix_sets.clone(),
);
let storage_root_targets_len = storage_root_targets.len();

debug!(
target: "trie::parallel_state_root",
total_targets = storage_root_targets_len,
"Starting parallel proof generation"
);

// Pre-calculate storage roots for accounts which were changed.
tracker.set_precomputed_storage_roots(storage_root_targets.len() as u64);
(Review note: fgimenez marked this conversation as resolved — "Show resolved / Hide resolved" toggle.)
debug!(target: "trie::parallel_state_root", len = storage_root_targets.len(), "pre-generating storage proofs");

let mut storage_proofs =
B256HashMap::with_capacity_and_hasher(storage_root_targets.len(), Default::default());

for (hashed_address, prefix_set) in
storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address)
{
let view = self.view.clone();
let target_slots = targets.get(&hashed_address).cloned().unwrap_or_default();

let trie_nodes_sorted = self.nodes_sorted.clone();
let hashed_state_sorted = self.state_sorted.clone();
let collect_masks = self.collect_branch_node_hash_masks;

let (tx, rx) = std::sync::mpsc::sync_channel(1);

rayon::spawn_fifo(move || {
self.thread_pool.spawn_fifo(move || {
debug!(
target: "trie::parallel",
?hashed_address,
"Starting proof calculation"
);

let task_start = Instant::now();
let result = (|| -> Result<_, ParallelStateRootError> {
let provider_start = Instant::now();
let provider_ro = view.provider_ro()?;
trace!(
target: "trie::parallel",
?hashed_address,
provider_time = ?provider_start.elapsed(),
"Got provider"
);

let cursor_start = Instant::now();
let trie_cursor_factory = InMemoryTrieCursorFactory::new(
DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
&trie_nodes_sorted,
Expand All @@ -140,19 +168,42 @@ where
DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
&hashed_state_sorted,
);
trace!(
target: "trie::parallel",
?hashed_address,
cursor_time = ?cursor_start.elapsed(),
"Created cursors"
);

StorageProof::new_hashed(
let proof_start = Instant::now();
let proof_result = StorageProof::new_hashed(
trie_cursor_factory,
hashed_cursor_factory,
hashed_address,
)
.with_prefix_set_mut(PrefixSetMut::from(prefix_set.iter().cloned()))
.with_branch_node_hash_masks(self.collect_branch_node_hash_masks)
.with_branch_node_hash_masks(collect_masks)
.storage_multiproof(target_slots)
.map_err(|e| ParallelStateRootError::Other(e.to_string()))
.map_err(|e| ParallelStateRootError::Other(e.to_string()));

trace!(
target: "trie::parallel",
?hashed_address,
proof_time = ?proof_start.elapsed(),
"Completed proof calculation"
);

proof_result
})();
if let Err(err) = tx.send(result) {
error!(target: "trie::parallel", ?hashed_address, err_content = ?err.0, "Failed to send proof result");

if let Err(e) = tx.send(result) {
error!(
target: "trie::parallel",
?hashed_address,
error = ?e,
task_time = ?task_start.elapsed(),
"Failed to send proof result"
);
}
});
storage_proofs.insert(hashed_address, rx);
Expand Down Expand Up @@ -338,12 +389,22 @@ mod tests {
let trie_cursor_factory = DatabaseTrieCursorFactory::new(provider_rw.tx_ref());
let hashed_cursor_factory = DatabaseHashedCursorFactory::new(provider_rw.tx_ref());

let num_threads =
std::thread::available_parallelism().map_or(1, |num| (num.get() / 2).max(1));

let state_root_task_pool = rayon::ThreadPoolBuilder::new()
.num_threads(num_threads)
.thread_name(|i| format!("proof-worker-{}", i))
.build()
.expect("Failed to create proof worker thread pool");

assert_eq!(
ParallelProof::new(
consistent_view,
Default::default(),
Default::default(),
Default::default()
Default::default(),
&state_root_task_pool
)
.multiproof(targets.clone())
.unwrap(),
Expand Down
Loading