Skip to content

Commit

Permalink
use local threadpool
Browse files Browse the repository at this point in the history
  • Loading branch information
fgimenez committed Dec 17, 2024
1 parent bbbbe10 commit 90d68f7
Showing 1 changed file with 104 additions and 117 deletions.
221 changes: 104 additions & 117 deletions crates/trie/parallel/src/proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,150 +108,137 @@ where
});
let prefix_sets = prefix_sets.freeze();

let storage_root_targets: Vec<_> = StorageRootTargets::new(
let storage_root_targets = StorageRootTargets::new(
prefix_sets.account_prefix_set.iter().map(|nibbles| B256::from_slice(&nibbles.pack())),
prefix_sets.storage_prefix_sets.clone(),
)
.into_iter()
.sorted_unstable_by_key(|(address, _)| *address)
.collect();
);
let storage_root_targets_len = storage_root_targets.len();

// Pre-calculate storage roots for accounts which were changed.
tracker.set_precomputed_storage_roots(storage_root_targets.len() as u64);
let num_threads = std::thread::available_parallelism()
.map_or(0, |num| num.get().saturating_sub(1).max(1));

const CHUNK_SIZE: usize = 128;
let num_chunks = storage_root_targets.len().div_ceil(CHUNK_SIZE);
// create a local thread pool with a fixed number of workers
let pool = rayon::ThreadPoolBuilder::new()
.num_threads(num_threads)
.thread_name(|i| format!("proof-worker-{}", i))
.build()
.map_err(|e| ParallelStateRootError::Other(e.to_string()))?;

debug!(
target: "trie::parallel_state_root",
total_targets = storage_root_targets.len(),
chunk_size = CHUNK_SIZE,
num_chunks,
"Starting batched proof generation"
total_targets = storage_root_targets_len,
num_threads,
"Starting parallel proof generation"
);

// Create a single channel for all proofs with appropriate capacity
let (tx, rx) = std::sync::mpsc::sync_channel(CHUNK_SIZE);
let (tx, rx) = std::sync::mpsc::sync_channel(num_threads);

let mut storage_proofs =
B256HashMap::with_capacity_and_hasher(storage_root_targets.len(), Default::default());
for (chunk_idx, chunk) in storage_root_targets.chunks(CHUNK_SIZE).enumerate() {
let chunk_size = chunk.len();

debug!(
target: "trie::parallel_state_root",
chunk_idx,
chunk_size,
"Processing proof batch"
);

// Spawn tasks for this batch
for (hashed_address, prefix_set) in chunk {
let view = self.view.clone();
let target_slots = targets.get(hashed_address).cloned().unwrap_or_default();
let trie_nodes_sorted = self.nodes_sorted.clone();
let hashed_state_sorted = self.state_sorted.clone();
let collect_masks = self.collect_branch_node_hash_masks;
let tx = tx.clone();
let hashed_address = *hashed_address;
let prefix_set = prefix_set.clone();

rayon::spawn(move || {
debug!(

for (hashed_address, prefix_set) in
storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address)
{
let view = self.view.clone();
let target_slots = targets.get(&hashed_address).cloned().unwrap_or_default();
let trie_nodes_sorted = self.nodes_sorted.clone();
let hashed_state_sorted = self.state_sorted.clone();
let collect_masks = self.collect_branch_node_hash_masks;
let tx = tx.clone();

pool.spawn_fifo(move || {
debug!(
target: "trie::parallel",
?hashed_address,
"Starting proof calculation"
);

let task_start = Instant::now();
let result = (|| -> Result<_, ParallelStateRootError> {
let provider_start = Instant::now();
let provider_ro = view.provider_ro()?;
trace!(
target: "trie::parallel",
?hashed_address,
"Starting proof calculation"
provider_time_ms = provider_start.elapsed().as_millis(),
"Got provider"
);

let task_start = Instant::now();
let result = (|| -> Result<_, ParallelStateRootError> {
let provider_start = Instant::now();
let provider_ro = view.provider_ro()?;
trace!(
target: "trie::parallel",
?hashed_address,
provider_time_ms = provider_start.elapsed().as_millis(),
"Got provider"
);
let cursor_start = Instant::now();
let trie_cursor_factory = InMemoryTrieCursorFactory::new(
DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
&trie_nodes_sorted,
);
let hashed_cursor_factory = HashedPostStateCursorFactory::new(
DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
&hashed_state_sorted,
);
trace!(
target: "trie::parallel",
?hashed_address,
cursor_time_ms = cursor_start.elapsed().as_millis(),
"Created cursors"
);

let cursor_start = Instant::now();
let trie_cursor_factory = InMemoryTrieCursorFactory::new(
DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
&trie_nodes_sorted,
);
let hashed_cursor_factory = HashedPostStateCursorFactory::new(
DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
&hashed_state_sorted,
);
trace!(
target: "trie::parallel",
?hashed_address,
cursor_time_ms = cursor_start.elapsed().as_millis(),
"Created cursors"
);
let proof_start = Instant::now();
let proof = StorageProof::new_hashed(
trie_cursor_factory,
hashed_cursor_factory,
hashed_address,
)
.with_prefix_set_mut(PrefixSetMut::from(prefix_set.iter().cloned()))
.with_branch_node_hash_masks(collect_masks)
.storage_multiproof(target_slots)
.map_err(|e| ParallelStateRootError::Other(e.to_string()))?;

let proof_start = Instant::now();
let proof = StorageProof::new_hashed(
trie_cursor_factory,
hashed_cursor_factory,
hashed_address,
)
.with_prefix_set_mut(PrefixSetMut::from(prefix_set.iter().cloned()))
.with_branch_node_hash_masks(collect_masks)
.storage_multiproof(target_slots)
.map_err(|e| ParallelStateRootError::Other(e.to_string()))?;

trace!(
target: "trie::parallel",
?hashed_address,
proof_time_ms = proof_start.elapsed().as_millis(),
"Completed proof calculation"
);
trace!(
target: "trie::parallel",
?hashed_address,
proof_time_ms = proof_start.elapsed().as_millis(),
"Completed proof calculation"
);

Ok((hashed_address, proof))
})();
Ok((hashed_address, proof))
})();

let task_time = task_start.elapsed();
if let Err(e) = tx.send(result) {
error!(
target: "trie::parallel",
?hashed_address,
error = ?e,
task_time_ms = task_time.as_millis(),
"Failed to send proof result"
);
}
});
}
let task_time = task_start.elapsed();
if let Err(e) = tx.send(result) {
error!(
target: "trie::parallel",
?hashed_address,
error = ?e,
task_time_ms = task_time.as_millis(),
"Failed to send proof result"
);
}
});
}

// Wait for all proofs in this batch
for _ in 0..chunk_size {
match rx.recv_timeout(std::time::Duration::from_secs(30)) {
Ok(result) => match result {
Ok((address, proof)) => {
storage_proofs.insert(address, proof);
}
Err(e) => {
error!(
target: "trie::parallel",
error = ?e,
chunk_idx,
"Proof calculation failed"
);
return Err(e);
}
},
// Wait for all proofs in this batch
for _ in 0..storage_root_targets_len {
match rx.recv_timeout(std::time::Duration::from_secs(30)) {
Ok(result) => match result {
Ok((address, proof)) => {
storage_proofs.insert(address, proof);
}
Err(e) => {
error!(
target: "trie::parallel",
error = ?e,
chunk_idx,
"Failed to receive proof result"
"Proof calculation failed"
);
return Err(ParallelStateRootError::Other(format!(
"Failed to receive proof result: {e:?}",
)));
return Err(e);
}
},
Err(e) => {
error!(
target: "trie::parallel",
error = ?e,
"Failed to receive proof result"
);
return Err(ParallelStateRootError::Other(format!(
"Failed to receive proof result: {e:?}",
)));
}
}
}
Expand Down

0 comments on commit 90d68f7

Please sign in to comment.