From 90d68f714f616c5d7910ecbd73a946e7aa72c18f Mon Sep 17 00:00:00 2001 From: Federico Gimenez Date: Tue, 17 Dec 2024 06:11:18 +0000 Subject: [PATCH] use local threadpool --- crates/trie/parallel/src/proof.rs | 221 ++++++++++++++---------------- 1 file changed, 104 insertions(+), 117 deletions(-) diff --git a/crates/trie/parallel/src/proof.rs b/crates/trie/parallel/src/proof.rs index 66b5f07fedb32..6ce3e888ea173 100644 --- a/crates/trie/parallel/src/proof.rs +++ b/crates/trie/parallel/src/proof.rs @@ -108,150 +108,137 @@ where }); let prefix_sets = prefix_sets.freeze(); - let storage_root_targets: Vec<_> = StorageRootTargets::new( + let storage_root_targets = StorageRootTargets::new( prefix_sets.account_prefix_set.iter().map(|nibbles| B256::from_slice(&nibbles.pack())), prefix_sets.storage_prefix_sets.clone(), - ) - .into_iter() - .sorted_unstable_by_key(|(address, _)| *address) - .collect(); + ); + let storage_root_targets_len = storage_root_targets.len(); - // Pre-calculate storage roots for accounts which were changed. - tracker.set_precomputed_storage_roots(storage_root_targets.len() as u64); + let num_threads = std::thread::available_parallelism() + .map_or(0, |num| num.get().saturating_sub(1).max(1)); - const CHUNK_SIZE: usize = 128; - let num_chunks = storage_root_targets.len().div_ceil(CHUNK_SIZE); + // create a local thread pool with a fixed number of workers + let pool = rayon::ThreadPoolBuilder::new() + .num_threads(num_threads) + .thread_name(|i| format!("proof-worker-{}", i)) + .build() + .map_err(|e| ParallelStateRootError::Other(e.to_string()))?; debug!( target: "trie::parallel_state_root", - total_targets = storage_root_targets.len(), - chunk_size = CHUNK_SIZE, - num_chunks, - "Starting batched proof generation" + total_targets = storage_root_targets_len, + num_threads, + "Starting parallel proof generation" ); - // Create a single channel for all proofs with appropriate capacity - let (tx, rx) = std::sync::mpsc::sync_channel(CHUNK_SIZE); + let (tx, rx) = std::sync::mpsc::sync_channel(num_threads); let mut storage_proofs = B256HashMap::with_capacity_and_hasher(storage_root_targets.len(), Default::default()); - for (chunk_idx, chunk) in storage_root_targets.chunks(CHUNK_SIZE).enumerate() { - let chunk_size = chunk.len(); - - debug!( - target: "trie::parallel_state_root", - chunk_idx, - chunk_size, - "Processing proof batch" - ); - - // Spawn tasks for this batch - for (hashed_address, prefix_set) in chunk { - let view = self.view.clone(); - let target_slots = targets.get(hashed_address).cloned().unwrap_or_default(); - let trie_nodes_sorted = self.nodes_sorted.clone(); - let hashed_state_sorted = self.state_sorted.clone(); - let collect_masks = self.collect_branch_node_hash_masks; - let tx = tx.clone(); - let hashed_address = *hashed_address; - let prefix_set = prefix_set.clone(); - - rayon::spawn(move || { - debug!( + + for (hashed_address, prefix_set) in + storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address) + { + let view = self.view.clone(); + let target_slots = targets.get(&hashed_address).cloned().unwrap_or_default(); + let trie_nodes_sorted = self.nodes_sorted.clone(); + let hashed_state_sorted = self.state_sorted.clone(); + let collect_masks = self.collect_branch_node_hash_masks; + let tx = tx.clone(); + + pool.spawn_fifo(move || { + debug!( + target: "trie::parallel", + ?hashed_address, + "Starting proof calculation" + ); + + let task_start = Instant::now(); + let result = (|| -> Result<_, ParallelStateRootError> { + let provider_start = Instant::now(); + let provider_ro = view.provider_ro()?; + trace!( target: "trie::parallel", ?hashed_address, - "Starting proof calculation" + provider_time_ms = provider_start.elapsed().as_millis(), + "Got provider" ); - let task_start = Instant::now(); - let result = (|| -> Result<_, ParallelStateRootError> { - let provider_start = Instant::now(); - let provider_ro = view.provider_ro()?; - trace!( - target: "trie::parallel", - ?hashed_address, - provider_time_ms = provider_start.elapsed().as_millis(), - "Got provider" - ); + let cursor_start = Instant::now(); + let trie_cursor_factory = InMemoryTrieCursorFactory::new( + DatabaseTrieCursorFactory::new(provider_ro.tx_ref()), + &trie_nodes_sorted, + ); + let hashed_cursor_factory = HashedPostStateCursorFactory::new( + DatabaseHashedCursorFactory::new(provider_ro.tx_ref()), + &hashed_state_sorted, + ); + trace!( + target: "trie::parallel", + ?hashed_address, + cursor_time_ms = cursor_start.elapsed().as_millis(), + "Created cursors" + ); - let cursor_start = Instant::now(); - let trie_cursor_factory = InMemoryTrieCursorFactory::new( - DatabaseTrieCursorFactory::new(provider_ro.tx_ref()), - &trie_nodes_sorted, - ); - let hashed_cursor_factory = HashedPostStateCursorFactory::new( - DatabaseHashedCursorFactory::new(provider_ro.tx_ref()), - &hashed_state_sorted, - ); - trace!( - target: "trie::parallel", - ?hashed_address, - cursor_time_ms = cursor_start.elapsed().as_millis(), - "Created cursors" - ); + let proof_start = Instant::now(); + let proof = StorageProof::new_hashed( + trie_cursor_factory, + hashed_cursor_factory, + hashed_address, + ) + .with_prefix_set_mut(PrefixSetMut::from(prefix_set.iter().cloned())) + .with_branch_node_hash_masks(collect_masks) + .storage_multiproof(target_slots) + .map_err(|e| ParallelStateRootError::Other(e.to_string()))?; - let proof_start = Instant::now(); - let proof = StorageProof::new_hashed( - trie_cursor_factory, - hashed_cursor_factory, - hashed_address, - ) - .with_prefix_set_mut(PrefixSetMut::from(prefix_set.iter().cloned())) - .with_branch_node_hash_masks(collect_masks) - .storage_multiproof(target_slots) - .map_err(|e| ParallelStateRootError::Other(e.to_string()))?; - - trace!( - target: "trie::parallel", - ?hashed_address, - proof_time_ms = proof_start.elapsed().as_millis(), - "Completed proof calculation" - ); + trace!( + target: "trie::parallel", + ?hashed_address, + proof_time_ms = proof_start.elapsed().as_millis(), + "Completed proof calculation" + ); - Ok((hashed_address, proof)) - })(); + Ok((hashed_address, proof)) + })(); - let task_time = task_start.elapsed(); - if let Err(e) = tx.send(result) { - error!( - target: "trie::parallel", - ?hashed_address, - error = ?e, - task_time_ms = task_time.as_millis(), - "Failed to send proof result" - ); - } - }); - } + let task_time = task_start.elapsed(); + if let Err(e) = tx.send(result) { + error!( + target: "trie::parallel", + ?hashed_address, + error = ?e, + task_time_ms = task_time.as_millis(), + "Failed to send proof result" + ); + } + }); + } - // Wait for all proofs in this batch - for _ in 0..chunk_size { - match rx.recv_timeout(std::time::Duration::from_secs(30)) { - Ok(result) => match result { - Ok((address, proof)) => { - storage_proofs.insert(address, proof); - } - Err(e) => { - error!( - target: "trie::parallel", - error = ?e, - chunk_idx, - "Proof calculation failed" - ); - return Err(e); - } - }, + // Wait for all proofs in this batch + for _ in 0..storage_root_targets_len { + match rx.recv_timeout(std::time::Duration::from_secs(30)) { + Ok(result) => match result { + Ok((address, proof)) => { + storage_proofs.insert(address, proof); + } Err(e) => { error!( target: "trie::parallel", error = ?e, - chunk_idx, - "Failed to receive proof result" + "Proof calculation failed" ); - return Err(ParallelStateRootError::Other(format!( - "Failed to receive proof result: {e:?}", - ))); + return Err(e); } + }, + Err(e) => { + error!( + target: "trie::parallel", + error = ?e, + "Failed to receive proof result" + ); + return Err(ParallelStateRootError::Other(format!( + "Failed to receive proof result: {e:?}", + ))); } } }