Skip to content

Commit

Permalink
Added feature to support
Browse files Browse the repository at this point in the history
  • Loading branch information
Eagle941 committed Aug 7, 2024
1 parent 99ddb92 commit e410118
Show file tree
Hide file tree
Showing 9 changed files with 84 additions and 45 deletions.
1 change: 1 addition & 0 deletions crates/blockifier/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ workspace = true

[features]
concurrency = []
full_visited_pcs = []
jemalloc = ["dep:tikv-jemallocator"]
testing = ["rand", "rstest"]

Expand Down
21 changes: 13 additions & 8 deletions crates/blockifier/src/blockifier/transaction_executor.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
#[cfg(feature = "concurrency")]
use std::collections::{HashMap, HashSet};
#[cfg(feature = "concurrency")]
use std::sync::Arc;
#[cfg(feature = "concurrency")]
use std::sync::Mutex;
Expand All @@ -15,7 +13,9 @@ use crate::bouncer::{Bouncer, BouncerWeights};
#[cfg(feature = "concurrency")]
use crate::concurrency::worker_logic::WorkerExecutor;
use crate::context::BlockContext;
use crate::state::cached_state::{CachedState, CommitmentStateDiff, TransactionalState};
use crate::state::cached_state::{
CachedState, CommitmentStateDiff, TransactionalState, VisitedPcs,
};
use crate::state::errors::StateError;
use crate::state::state_api::StateReader;
use crate::transaction::errors::TransactionExecutionError;
Expand Down Expand Up @@ -150,14 +150,19 @@ impl<S: StateReader> TransactionExecutor<S> {
.as_ref()
.expect(BLOCK_STATE_ACCESS_ERR)
.visited_pcs
.iter()
.map(|(class_hash, class_visited_pcs)| -> TransactionExecutorResult<_> {
.keys()
.map(|class_hash| -> TransactionExecutorResult<_> {
let contract_class = self
.block_state
.as_ref()
.expect(BLOCK_STATE_ACCESS_ERR)
.get_compiled_contract_class(*class_hash)?;
Ok((*class_hash, contract_class.get_visited_segments(class_visited_pcs)?))
let visited_pcs_set = self
.block_state
.as_ref()
.expect(BLOCK_STATE_ACCESS_ERR)
.get_set_visited_pcs(class_hash);
Ok((*class_hash, contract_class.get_visited_segments(&visited_pcs_set)?))
})
.collect::<TransactionExecutorResult<_>>()?;

Expand Down Expand Up @@ -243,7 +248,7 @@ impl<S: StateReader + Send + Sync> TransactionExecutor<S> {

let n_committed_txs = worker_executor.scheduler.get_n_committed_txs();
let mut tx_execution_results = Vec::new();
let mut visited_pcs: HashMap<ClassHash, HashSet<usize>> = HashMap::new();
let mut visited_pcs: VisitedPcs = VisitedPcs::new();
for execution_output in worker_executor.execution_outputs.iter() {
if tx_execution_results.len() >= n_committed_txs {
break;
Expand All @@ -256,7 +261,7 @@ impl<S: StateReader + Send + Sync> TransactionExecutor<S> {
tx_execution_results
.push(locked_execution_output.result.map_err(TransactionExecutorError::from));
for (class_hash, class_visited_pcs) in locked_execution_output.visited_pcs {
visited_pcs.entry(class_hash).or_default().extend(class_visited_pcs);
visited_pcs.entry(class_hash).or_default().extend(class_visited_pcs.clone());
}
}

Expand Down
7 changes: 3 additions & 4 deletions crates/blockifier/src/concurrency/versioned_state.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, Mutex, MutexGuard};

use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce};
Expand All @@ -8,7 +7,7 @@ use starknet_types_core::felt::Felt;
use crate::concurrency::versioned_storage::VersionedStorage;
use crate::concurrency::TxIndex;
use crate::execution::contract_class::ContractClass;
use crate::state::cached_state::{ContractClassMapping, StateMaps};
use crate::state::cached_state::{ContractClassMapping, StateMaps, VisitedPcs};
use crate::state::errors::StateError;
use crate::state::state_api::{StateReader, StateResult, UpdatableState};

Expand Down Expand Up @@ -202,7 +201,7 @@ impl<U: UpdatableState> VersionedState<U> {
pub fn commit_chunk_and_recover_block_state(
mut self,
n_committed_txs: usize,
visited_pcs: HashMap<ClassHash, HashSet<usize>>,
visited_pcs: VisitedPcs,
) -> U {
if n_committed_txs == 0 {
return self.into_initial_state();
Expand Down Expand Up @@ -277,7 +276,7 @@ impl<S: StateReader> UpdatableState for VersionedStateProxy<S> {
&mut self,
writes: &StateMaps,
class_hash_to_class: &ContractClassMapping,
_visited_pcs: &HashMap<ClassHash, HashSet<usize>>,
_visited_pcs: &VisitedPcs,
) {
self.state().apply_writes(self.tx_index, writes, class_hash_to_class)
}
Expand Down
10 changes: 4 additions & 6 deletions crates/blockifier/src/concurrency/worker_logic.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
use std::collections::{HashMap, HashSet};
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Mutex;
use std::thread;
use std::time::Duration;

use starknet_api::core::ClassHash;

use super::versioned_state::VersionedState;
use crate::blockifier::transaction_executor::TransactionExecutorError;
use crate::bouncer::Bouncer;
Expand All @@ -16,7 +14,7 @@ use crate::concurrency::versioned_state::ThreadSafeVersionedState;
use crate::concurrency::TxIndex;
use crate::context::BlockContext;
use crate::state::cached_state::{
ContractClassMapping, StateChanges, StateMaps, TransactionalState,
ContractClassMapping, StateChanges, StateMaps, TransactionalState, VisitedPcs,
};
use crate::state::state_api::{StateReader, UpdatableState};
use crate::transaction::objects::{TransactionExecutionInfo, TransactionExecutionResult};
Expand All @@ -34,7 +32,7 @@ pub struct ExecutionTaskOutput {
pub reads: StateMaps,
pub writes: StateMaps,
pub contract_classes: ContractClassMapping,
pub visited_pcs: HashMap<ClassHash, HashSet<usize>>,
pub visited_pcs: VisitedPcs,
pub result: TransactionExecutionResult<TransactionExecutionInfo>,
}

Expand Down Expand Up @@ -264,7 +262,7 @@ impl<'a, U: UpdatableState> WorkerExecutor<'a, U> {
pub fn commit_chunk_and_recover_block_state(
self,
n_committed_txs: usize,
visited_pcs: HashMap<ClassHash, HashSet<usize>>,
visited_pcs: VisitedPcs,
) -> U {
self.state
.into_inner_state()
Expand Down
15 changes: 11 additions & 4 deletions crates/blockifier/src/execution/entry_point_execution.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::collections::HashSet;

use cairo_vm::types::builtin_name::BuiltinName;
use cairo_vm::types::layout_name::LayoutName;
use cairo_vm::types::relocatable::{MaybeRelocatable, Relocatable};
Expand All @@ -19,6 +17,7 @@ use crate::execution::execution_utils::{
read_execution_retdata, write_felt, write_maybe_relocatable, Args, ReadOnlySegments,
};
use crate::execution::syscalls::hint_processor::SyscallHintProcessor;
use crate::state::cached_state::Pcs;
use crate::state::state_api::State;

// TODO(spapini): Try to refactor this file into a StarknetRunner struct.
Expand Down Expand Up @@ -109,7 +108,14 @@ fn register_visited_pcs(
program_segment_size: usize,
bytecode_length: usize,
) -> EntryPointExecutionResult<()> {
let mut class_visited_pcs = HashSet::new();
fn add_element(pcs: &mut Pcs, element: usize) {
#[cfg(not(feature = "full_visited_pcs"))]
pcs.insert(element);

#[cfg(feature = "full_visited_pcs")]
pcs.push(element);
}
let mut class_visited_pcs = Pcs::new();
// Relocate the trace, putting the program segment at address 1 and the execution segment right
// after it.
// TODO(lior): Avoid unnecessary relocation once the VM has a non-relocated `get_trace()`
Expand All @@ -126,10 +132,11 @@ fn register_visited_pcs(
// Jumping to a PC that is not inside the bytecode is possible. For example, to obtain
// the builtin costs. Filter out these values.
if real_pc < bytecode_length {
class_visited_pcs.insert(real_pc);
add_element(&mut class_visited_pcs, real_pc);
}
}
state.add_visited_pcs(class_hash, &class_visited_pcs);

Ok(())
}

Expand Down
7 changes: 1 addition & 6 deletions crates/blockifier/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,7 @@
// length to pointer type ([not necessarily true](https://github.com/rust-lang/rust/issues/65473),
// but it is a reasonable assumption for now), this attribute protects against potential overflow
// when converting usize to u128.
#![cfg(any(
target_pointer_width = "16",
target_pointer_width = "32",
target_pointer_width = "64",
target_pointer_width = "128"
))]
#![cfg(any(target_pointer_width = "16", target_pointer_width = "32", target_pointer_width = "64"))]

#[cfg(feature = "jemalloc")]
// Override default allocator.
Expand Down
53 changes: 47 additions & 6 deletions crates/blockifier/src/state/cached_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@ mod test;

pub type ContractClassMapping = HashMap<ClassHash, ContractClass>;

#[cfg(not(feature = "full_visited_pcs"))]
pub type Pcs = HashSet<usize>;

#[cfg(feature = "full_visited_pcs")]
pub type Pcs = Vec<usize>;

#[cfg(not(feature = "full_visited_pcs"))]
pub type VisitedPcs = HashMap<ClassHash, Pcs>;

#[cfg(feature = "full_visited_pcs")]
pub type VisitedPcs = HashMap<ClassHash, Vec<Pcs>>;

/// Caches read and write requests.
///
/// Writer functionality is builtin, whereas Reader functionality is injected through
Expand All @@ -33,7 +45,7 @@ pub struct CachedState<S: StateReader> {
pub(crate) cache: RefCell<StateCache>,
pub(crate) class_hash_to_class: RefCell<ContractClassMapping>,
/// A map from class hash to the set of PC values that were visited in the class.
pub visited_pcs: HashMap<ClassHash, HashSet<usize>>,
pub visited_pcs: VisitedPcs,
}

impl<S: StateReader> CachedState<S> {
Expand All @@ -59,6 +71,25 @@ impl<S: StateReader> CachedState<S> {
Ok(self.to_state_diff()?.into())
}

pub fn get_set_visited_pcs(&self, class_hash: &ClassHash) -> HashSet<usize> {
#[cfg(not(feature = "full_visited_pcs"))]
fn from_set(class_hash: &ClassHash, visited_pcs: &VisitedPcs) -> HashSet<usize> {
return visited_pcs.get(class_hash).unwrap().clone();
}

#[cfg(feature = "full_visited_pcs")]
fn from_set(class_hash: &ClassHash, visited_pcs: &VisitedPcs) -> HashSet<usize> {
let class_visited_pcs = visited_pcs.get(class_hash).unwrap();
let mut visited_pcs_set: HashSet<usize> = HashSet::new();
for pcs in class_visited_pcs {
visited_pcs_set.extend(pcs.iter());
}
visited_pcs_set
}

from_set(class_hash, &self.visited_pcs)
}

pub fn update_cache(
&mut self,
write_updates: &StateMaps,
Expand All @@ -73,9 +104,9 @@ impl<S: StateReader> CachedState<S> {
self.class_hash_to_class.get_mut().extend(local_contract_cache_updates);
}

pub fn update_visited_pcs_cache(&mut self, visited_pcs: &HashMap<ClassHash, HashSet<usize>>) {
pub fn update_visited_pcs_cache(&mut self, visited_pcs: &VisitedPcs) {
for (class_hash, class_visited_pcs) in visited_pcs {
self.add_visited_pcs(*class_hash, class_visited_pcs);
self.visited_pcs.entry(*class_hash).or_default().extend(class_visited_pcs.clone());
}
}

Expand Down Expand Up @@ -112,7 +143,7 @@ impl<S: StateReader> UpdatableState for CachedState<S> {
&mut self,
writes: &StateMaps,
class_hash_to_class: &ContractClassMapping,
visited_pcs: &HashMap<ClassHash, HashSet<usize>>,
visited_pcs: &VisitedPcs,
) {
// TODO(OriF,15/5/24): Reconsider the clone.
self.update_cache(writes, class_hash_to_class.clone());
Expand Down Expand Up @@ -275,8 +306,18 @@ impl<S: StateReader> State for CachedState<S> {
Ok(())
}

fn add_visited_pcs(&mut self, class_hash: ClassHash, pcs: &HashSet<usize>) {
self.visited_pcs.entry(class_hash).or_default().extend(pcs);
fn add_visited_pcs(&mut self, class_hash: ClassHash, pcs: &Pcs) {
#[cfg(not(feature = "full_visited_pcs"))]
fn from_set(visited_pcs: &mut VisitedPcs, class_hash: ClassHash, pcs: &Pcs) {
visited_pcs.entry(class_hash).or_default().extend(pcs);
}

#[cfg(feature = "full_visited_pcs")]
fn from_set(visited_pcs: &mut VisitedPcs, class_hash: ClassHash, pcs: &Pcs) {
visited_pcs.entry(class_hash).or_default().push(pcs.to_vec());
}

from_set(&mut self.visited_pcs, class_hash, pcs);
}
}

Expand Down
8 changes: 3 additions & 5 deletions crates/blockifier/src/state/state_api.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
use std::collections::{HashMap, HashSet};

use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce};
use starknet_api::state::StorageKey;
use starknet_types_core::felt::Felt;

use super::cached_state::{ContractClassMapping, StateMaps};
use super::cached_state::{ContractClassMapping, Pcs, StateMaps, VisitedPcs};
use crate::abi::abi_utils::get_fee_token_var_address;
use crate::abi::sierra_types::next_storage_key;
use crate::execution::contract_class::ContractClass;
Expand Down Expand Up @@ -107,7 +105,7 @@ pub trait State: StateReader {
/// Marks the given set of PC values as visited for the given class hash.
// TODO(lior): Once we have a BlockResources object, move this logic there. Make sure reverted
// entry points do not affect the final set of PCs.
fn add_visited_pcs(&mut self, class_hash: ClassHash, pcs: &HashSet<usize>);
fn add_visited_pcs(&mut self, class_hash: ClassHash, pcs: &Pcs);
}

/// A class defining the API for updating a state with transactions writes.
Expand All @@ -116,6 +114,6 @@ pub trait UpdatableState: StateReader {
&mut self,
writes: &StateMaps,
class_hash_to_class: &ContractClassMapping,
visited_pcs: &HashMap<ClassHash, HashSet<usize>>,
visited_pcs: &VisitedPcs,
);
}
7 changes: 1 addition & 6 deletions crates/native_blockifier/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
// The blockifier crate supports only these specific architectures.
#![cfg(any(
target_pointer_width = "16",
target_pointer_width = "32",
target_pointer_width = "64",
target_pointer_width = "128"
))]
#![cfg(any(target_pointer_width = "16", target_pointer_width = "32", target_pointer_width = "64"))]

pub mod errors;
pub mod py_block_executor;
Expand Down

0 comments on commit e410118

Please sign in to comment.