Skip to content

Commit

Permalink
Added VisitedPcsSet as abstraction layer to visited_pcs in
Browse files Browse the repository at this point in the history
`CachedState`.
  • Loading branch information
Eagle941 committed Aug 29, 2024
1 parent 497a36f commit 9d083b3
Show file tree
Hide file tree
Showing 35 changed files with 932 additions and 322 deletions.
9 changes: 5 additions & 4 deletions crates/blockifier/src/blockifier/stateful_validator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use crate::fee::fee_checks::PostValidationReport;
use crate::state::cached_state::CachedState;
use crate::state::errors::StateError;
use crate::state::state_api::StateReader;
use crate::state::visited_pcs::VisitedPcsTrait;
use crate::transaction::account_transaction::AccountTransaction;
use crate::transaction::errors::{TransactionExecutionError, TransactionPreValidationError};
use crate::transaction::transaction_execution::Transaction;
Expand All @@ -39,12 +40,12 @@ pub enum StatefulValidatorError {
pub type StatefulValidatorResult<T> = Result<T, StatefulValidatorError>;

/// Manages state related transaction validations for pre-execution flows.
pub struct StatefulValidator<S: StateReader> {
tx_executor: TransactionExecutor<S>,
pub struct StatefulValidator<S: StateReader, V: VisitedPcsTrait> {
tx_executor: TransactionExecutor<S, V>,
}

impl<S: StateReader> StatefulValidator<S> {
pub fn create(state: CachedState<S>, block_context: BlockContext) -> Self {
impl<S: StateReader, V: VisitedPcsTrait> StatefulValidator<S, V> {
pub fn create(state: CachedState<S, V>, block_context: BlockContext) -> Self {
let tx_executor =
TransactionExecutor::new(state, block_context, TransactionExecutorConfig::default());
Self { tx_executor }
Expand Down
34 changes: 20 additions & 14 deletions crates/blockifier/src/blockifier/transaction_executor.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::fmt::Debug;
#[cfg(feature = "concurrency")]
use std::panic::{self, catch_unwind, AssertUnwindSafe};
#[cfg(feature = "concurrency")]
Expand All @@ -18,6 +19,7 @@ use crate::context::BlockContext;
use crate::state::cached_state::{CachedState, CommitmentStateDiff, TransactionalState};
use crate::state::errors::StateError;
use crate::state::state_api::StateReader;
use crate::state::visited_pcs::VisitedPcsTrait;
use crate::transaction::errors::TransactionExecutionError;
use crate::transaction::objects::TransactionExecutionInfo;
use crate::transaction::transaction_execution::Transaction;
Expand All @@ -43,7 +45,7 @@ pub type TransactionExecutorResult<T> = Result<T, TransactionExecutorError>;
pub type VisitedSegmentsMapping = Vec<(ClassHash, Vec<usize>)>;

// TODO(Gilad): make this hold TransactionContext instead of BlockContext.
pub struct TransactionExecutor<S: StateReader> {
pub struct TransactionExecutor<S: StateReader, V: VisitedPcsTrait> {
pub block_context: BlockContext,
pub bouncer: Bouncer,
// Note: this config must not affect the execution result (e.g. state diff and traces).
Expand All @@ -54,12 +56,12 @@ pub struct TransactionExecutor<S: StateReader> {
// block state to the worker executor - operating at the chunk level - and gets it back after
// committing the chunk. The block state is wrapped with an Option<_> to allow setting it to
// `None` while it is moved to the worker executor.
pub block_state: Option<CachedState<S>>,
pub block_state: Option<CachedState<S, V>>,
}

impl<S: StateReader> TransactionExecutor<S> {
impl<S: StateReader, V: VisitedPcsTrait> TransactionExecutor<S, V> {
pub fn new(
block_state: CachedState<S>,
block_state: CachedState<S, V>,
block_context: BlockContext,
config: TransactionExecutorConfig,
) -> Self {
Expand All @@ -85,9 +87,10 @@ impl<S: StateReader> TransactionExecutor<S> {
&mut self,
tx: &Transaction,
) -> TransactionExecutorResult<TransactionExecutionInfo> {
let mut transactional_state = TransactionalState::create_transactional(
self.block_state.as_mut().expect(BLOCK_STATE_ACCESS_ERR),
);
let mut transactional_state: TransactionalState<'_, _, V> =
TransactionalState::create_transactional(
self.block_state.as_mut().expect(BLOCK_STATE_ACCESS_ERR),
);
// Executing a single transaction cannot be done in a concurrent mode.
let execution_flags =
ExecutionFlags { charge_fee: true, validate: true, concurrency_mode: false };
Expand Down Expand Up @@ -157,7 +160,8 @@ impl<S: StateReader> TransactionExecutor<S> {
.as_ref()
.expect(BLOCK_STATE_ACCESS_ERR)
.get_compiled_contract_class(*class_hash)?;
Ok((*class_hash, contract_class.get_visited_segments(class_visited_pcs)?))
let class_visited_pcs = V::to_set(class_visited_pcs.clone());
Ok((*class_hash, contract_class.get_visited_segments(&class_visited_pcs)?))
})
.collect::<TransactionExecutorResult<_>>()?;

Expand All @@ -170,7 +174,9 @@ impl<S: StateReader> TransactionExecutor<S> {
}
}

impl<S: StateReader + Send + Sync> TransactionExecutor<S> {
impl<S: StateReader + Send + Sync, V: VisitedPcsTrait + Send + Sync + Debug>
TransactionExecutor<S, V>
{
/// Executes the given transactions on the state maintained by the executor.
/// Stops if and when there is no more room in the block, and returns the executed transactions'
/// results.
Expand Down Expand Up @@ -219,7 +225,7 @@ impl<S: StateReader + Send + Sync> TransactionExecutor<S> {
chunk: &[Transaction],
) -> Vec<TransactionExecutorResult<TransactionExecutionInfo>> {
use crate::concurrency::utils::AbortIfPanic;
use crate::state::cached_state::VisitedPcs;
use crate::concurrency::worker_logic::ExecutionTaskOutput;

let block_state = self.block_state.take().expect("The block state should be `Some`.");

Expand Down Expand Up @@ -263,20 +269,20 @@ impl<S: StateReader + Send + Sync> TransactionExecutor<S> {

let n_committed_txs = worker_executor.scheduler.get_n_committed_txs();
let mut tx_execution_results = Vec::new();
let mut visited_pcs: VisitedPcs = VisitedPcs::new();
let mut visited_pcs: V = V::new();
for execution_output in worker_executor.execution_outputs.iter() {
if tx_execution_results.len() >= n_committed_txs {
break;
}
let locked_execution_output = execution_output
let locked_execution_output: ExecutionTaskOutput<V> = execution_output
.lock()
.expect("Failed to lock execution output.")
.take()
.expect("Output must be ready.");
tx_execution_results
.push(locked_execution_output.result.map_err(TransactionExecutorError::from));
for (class_hash, class_visited_pcs) in locked_execution_output.visited_pcs {
visited_pcs.entry(class_hash).or_default().extend(class_visited_pcs);
for (class_hash, class_visited_pcs) in locked_execution_output.visited_pcs.iter() {
visited_pcs.extend(class_hash, class_visited_pcs);
}
}

Expand Down
5 changes: 3 additions & 2 deletions crates/blockifier/src/blockifier/transaction_executor_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::bouncer::{Bouncer, BouncerWeights};
use crate::context::BlockContext;
use crate::state::cached_state::CachedState;
use crate::state::state_api::StateReader;
use crate::state::visited_pcs::VisitedPcsTrait;
use crate::test_utils::contracts::FeatureContract;
use crate::test_utils::declare::declare_tx;
use crate::test_utils::deploy_account::deploy_account_tx;
Expand All @@ -30,8 +31,8 @@ use crate::transaction::transaction_execution::Transaction;
use crate::transaction::transactions::L1HandlerTransaction;
use crate::{declare_tx_args, deploy_account_tx_args, invoke_tx_args, nonce};

fn tx_executor_test_body<S: StateReader>(
state: CachedState<S>,
fn tx_executor_test_body<S: StateReader, V: VisitedPcsTrait>(
state: CachedState<S, V>,
block_context: BlockContext,
tx: Transaction,
expected_bouncer_weights: BouncerWeights,
Expand Down
4 changes: 3 additions & 1 deletion crates/blockifier/src/bouncer_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::bouncer::{verify_tx_weights_in_bounds, Bouncer, BouncerWeights, Built
use crate::context::BlockContext;
use crate::execution::call_info::ExecutionSummary;
use crate::state::cached_state::{StateChangesKeys, TransactionalState};
use crate::state::visited_pcs::VisitedPcsSet;
use crate::storage_key;
use crate::test_utils::initial_test_state::test_state;
use crate::transaction::errors::TransactionExecutionError;
Expand Down Expand Up @@ -187,7 +188,8 @@ fn test_bouncer_try_update(
use crate::transaction::objects::TransactionResources;

let state = &mut test_state(&BlockContext::create_for_account_testing().chain_info, 0, &[]);
let mut transactional_state = TransactionalState::create_transactional(state);
let mut transactional_state: TransactionalState<'_, _, VisitedPcsSet> =
TransactionalState::create_transactional(state);

// Setup the bouncer.
let block_max_capacity = BouncerWeights {
Expand Down
11 changes: 6 additions & 5 deletions crates/blockifier/src/concurrency/fee_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::execution::call_info::CallInfo;
use crate::fee::fee_utils::get_sequencer_balance_keys;
use crate::state::cached_state::{ContractClassMapping, StateMaps};
use crate::state::state_api::UpdatableState;
use crate::state::visited_pcs::VisitedPcsTrait;
use crate::transaction::objects::TransactionExecutionInfo;

#[cfg(test)]
Expand All @@ -22,10 +23,10 @@ mod test;
pub(crate) const STORAGE_READ_SEQUENCER_BALANCE_INDICES: (usize, usize) = (2, 3);

// Completes the fee transfer flow if needed (if the transfer was made in concurrent mode).
pub fn complete_fee_transfer_flow(
pub fn complete_fee_transfer_flow<V: VisitedPcsTrait, U: UpdatableState<T = V>>(
tx_context: &TransactionContext,
tx_execution_info: &mut TransactionExecutionInfo,
state: &mut impl UpdatableState,
state: &mut U,
) {
if tx_context.is_sequencer_the_sender() {
// When the sequencer is the sender, we use the sequential (full) fee transfer.
Expand Down Expand Up @@ -93,9 +94,9 @@ pub fn fill_sequencer_balance_reads(
storage_read_values[high_index] = high;
}

pub fn add_fee_to_sequencer_balance(
pub fn add_fee_to_sequencer_balance<V: VisitedPcsTrait, U: UpdatableState<T = V>>(
fee_token_address: ContractAddress,
state: &mut impl UpdatableState,
state: &mut U,
actual_fee: Fee,
block_context: &BlockContext,
sequencer_balance: (Felt, Felt),
Expand All @@ -120,5 +121,5 @@ pub fn add_fee_to_sequencer_balance(
]),
..StateMaps::default()
};
state.apply_writes(&writes, &ContractClassMapping::default(), &HashMap::default());
state.apply_writes(&writes, &ContractClassMapping::default(), &V::default());
}
20 changes: 13 additions & 7 deletions crates/blockifier/src/concurrency/flow_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ use starknet_api::{contract_address, felt, patricia_key};
use crate::abi::sierra_types::{SierraType, SierraU128};
use crate::concurrency::scheduler::{Scheduler, Task, TransactionStatus};
use crate::concurrency::test_utils::{safe_versioned_state_for_testing, DEFAULT_CHUNK_SIZE};
use crate::concurrency::versioned_state::ThreadSafeVersionedState;
use crate::concurrency::versioned_state::{ThreadSafeVersionedState, VersionedStateProxy};
use crate::state::cached_state::{CachedState, ContractClassMapping, StateMaps};
use crate::state::state_api::UpdatableState;
use crate::state::visited_pcs::VisitedPcsSet;
use crate::storage_key;
use crate::test_utils::dict_state_reader::DictStateReader;

Expand All @@ -27,6 +28,9 @@ fn scheduler_flow_test(
// transactions with multiple threads, where every transaction depends on its predecessor. Each
// transaction sequentially advances a counter by reading the previous value and bumping it by
// 1.

use crate::concurrency::versioned_state::VersionedStateProxy;
use crate::state::visited_pcs::VisitedPcsSet;
let scheduler = Arc::new(Scheduler::new(DEFAULT_CHUNK_SIZE));
let versioned_state =
safe_versioned_state_for_testing(CachedState::from(DictStateReader::default()));
Expand All @@ -53,7 +57,7 @@ fn scheduler_flow_test(
state_proxy.apply_writes(
&new_writes,
&ContractClassMapping::default(),
&HashMap::default(),
&VisitedPcsSet::default(),
);
scheduler.finish_execution_during_commit(tx_index);
}
Expand All @@ -66,13 +70,14 @@ fn scheduler_flow_test(
versioned_state.pin_version(tx_index).apply_writes(
&writes,
&ContractClassMapping::default(),
&HashMap::default(),
&VisitedPcsSet::default(),
);
scheduler.finish_execution(tx_index);
Task::AskForTask
}
Task::ValidationTask(tx_index) => {
let state_proxy = versioned_state.pin_version(tx_index);
let state_proxy: VersionedStateProxy<_, VisitedPcsSet> =
versioned_state.pin_version(tx_index);
let (reads, writes) =
get_reads_writes_for(Task::ValidationTask(tx_index), &versioned_state);
let read_set_valid = state_proxy.validate_reads(&reads);
Expand Down Expand Up @@ -120,11 +125,11 @@ fn scheduler_flow_test(

fn get_reads_writes_for(
task: Task,
versioned_state: &ThreadSafeVersionedState<CachedState<DictStateReader>>,
versioned_state: &ThreadSafeVersionedState<CachedState<DictStateReader, VisitedPcsSet>>,
) -> (StateMaps, StateMaps) {
match task {
Task::ExecutionTask(tx_index) => {
let state_proxy = match tx_index {
let state_proxy: VersionedStateProxy<_, VisitedPcsSet> = match tx_index {
0 => {
return (
state_maps_with_single_storage_entry(0),
Expand All @@ -146,7 +151,8 @@ fn get_reads_writes_for(
)
}
Task::ValidationTask(tx_index) => {
let state_proxy = versioned_state.pin_version(tx_index);
let state_proxy: VersionedStateProxy<_, VisitedPcsSet> =
versioned_state.pin_version(tx_index);
let tx_written_value = SierraU128::from_storage(
&state_proxy,
&contract_address!(CONTRACT_ADDRESS),
Expand Down
12 changes: 7 additions & 5 deletions crates/blockifier/src/concurrency/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::context::BlockContext;
use crate::execution::call_info::CallInfo;
use crate::state::cached_state::{CachedState, TransactionalState};
use crate::state::state_api::StateReader;
use crate::state::visited_pcs::{VisitedPcsSet, VisitedPcsTrait};
use crate::test_utils::dict_state_reader::DictStateReader;
use crate::transaction::account_transaction::AccountTransaction;
use crate::transaction::transactions::{ExecutableTransaction, ExecutionFlags};
Expand Down Expand Up @@ -61,21 +62,22 @@ macro_rules! default_scheduler {

// TODO(meshi, 01/06/2024): Consider making this a macro.
pub fn safe_versioned_state_for_testing(
block_state: CachedState<DictStateReader>,
) -> ThreadSafeVersionedState<CachedState<DictStateReader>> {
block_state: CachedState<DictStateReader, VisitedPcsSet>,
) -> ThreadSafeVersionedState<CachedState<DictStateReader, VisitedPcsSet>> {
ThreadSafeVersionedState::new(VersionedState::new(block_state))
}

// Utils.

// Note: this function does not mutate the state.
pub fn create_fee_transfer_call_info<S: StateReader>(
state: &mut CachedState<S>,
pub fn create_fee_transfer_call_info<S: StateReader, V: VisitedPcsTrait>(
state: &mut CachedState<S, V>,
account_tx: &AccountTransaction,
concurrency_mode: bool,
) -> CallInfo {
let block_context = BlockContext::create_for_account_testing();
let mut transactional_state = TransactionalState::create_transactional(state);
let mut transactional_state: TransactionalState<'_, _, V> =
TransactionalState::create_transactional(state);
let execution_flags = ExecutionFlags { charge_fee: true, validate: true, concurrency_mode };
let execution_info =
account_tx.execute_raw(&mut transactional_state, &block_context, execution_flags).unwrap();
Expand Down
25 changes: 15 additions & 10 deletions crates/blockifier/src/concurrency/versioned_state.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::marker::PhantomData;
use std::sync::{Arc, Mutex, MutexGuard};

use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce};
Expand All @@ -7,9 +8,10 @@ use starknet_types_core::felt::Felt;
use crate::concurrency::versioned_storage::VersionedStorage;
use crate::concurrency::TxIndex;
use crate::execution::contract_class::ContractClass;
use crate::state::cached_state::{ContractClassMapping, StateMaps, VisitedPcs};
use crate::state::cached_state::{ContractClassMapping, StateMaps};
use crate::state::errors::StateError;
use crate::state::state_api::{StateReader, StateResult, UpdatableState};
use crate::state::visited_pcs::VisitedPcsTrait;

#[cfg(test)]
#[path = "versioned_state_test.rs"]
Expand Down Expand Up @@ -197,11 +199,11 @@ impl<S: StateReader> VersionedState<S> {
}
}

impl<U: UpdatableState> VersionedState<U> {
impl<V: VisitedPcsTrait, U: UpdatableState<T = V>> VersionedState<U> {
pub fn commit_chunk_and_recover_block_state(
mut self,
n_committed_txs: usize,
visited_pcs: VisitedPcs,
visited_pcs: V,
) -> U {
if n_committed_txs == 0 {
return self.into_initial_state();
Expand All @@ -228,8 +230,8 @@ impl<S: StateReader> ThreadSafeVersionedState<S> {
ThreadSafeVersionedState(Arc::new(Mutex::new(versioned_state)))
}

pub fn pin_version(&self, tx_index: TxIndex) -> VersionedStateProxy<S> {
VersionedStateProxy { tx_index, state: self.0.clone() }
pub fn pin_version<V: VisitedPcsTrait>(&self, tx_index: TxIndex) -> VersionedStateProxy<S, V> {
VersionedStateProxy { tx_index, state: self.0.clone(), _marker: PhantomData }
}

pub fn into_inner_state(self) -> VersionedState<S> {
Expand All @@ -251,12 +253,13 @@ impl<S: StateReader> Clone for ThreadSafeVersionedState<S> {
}
}

pub struct VersionedStateProxy<S: StateReader> {
pub struct VersionedStateProxy<S: StateReader, V: VisitedPcsTrait> {
pub tx_index: TxIndex,
pub state: Arc<Mutex<VersionedState<S>>>,
_marker: PhantomData<V>,
}

impl<S: StateReader> VersionedStateProxy<S> {
impl<S: StateReader, V: VisitedPcsTrait> VersionedStateProxy<S, V> {
fn state(&self) -> LockedVersionedState<'_, S> {
self.state.lock().expect("Failed to acquire state lock.")
}
Expand All @@ -271,18 +274,20 @@ impl<S: StateReader> VersionedStateProxy<S> {
}

// TODO(Noa, 15/5/24): Consider using visited_pcs.
impl<S: StateReader> UpdatableState for VersionedStateProxy<S> {
impl<V: VisitedPcsTrait, S: StateReader> UpdatableState for VersionedStateProxy<S, V> {
type T = V;

fn apply_writes(
&mut self,
writes: &StateMaps,
class_hash_to_class: &ContractClassMapping,
_visited_pcs: &VisitedPcs,
_visited_pcs: &V,
) {
self.state().apply_writes(self.tx_index, writes, class_hash_to_class)
}
}

impl<S: StateReader> StateReader for VersionedStateProxy<S> {
impl<V: VisitedPcsTrait, S: StateReader> StateReader for VersionedStateProxy<S, V> {
fn get_storage_at(
&self,
contract_address: ContractAddress,
Expand Down
Loading

0 comments on commit 9d083b3

Please sign in to comment.