diff --git a/crates/blockifier/src/blockifier/stateful_validator.rs b/crates/blockifier/src/blockifier/stateful_validator.rs index bb89e3d5fc..9220b82832 100644 --- a/crates/blockifier/src/blockifier/stateful_validator.rs +++ b/crates/blockifier/src/blockifier/stateful_validator.rs @@ -15,6 +15,7 @@ use crate::fee::fee_checks::PostValidationReport; use crate::state::cached_state::CachedState; use crate::state::errors::StateError; use crate::state::state_api::StateReader; +use crate::state::visited_pcs::VisitedPcsTrait; use crate::transaction::account_transaction::AccountTransaction; use crate::transaction::errors::{TransactionExecutionError, TransactionPreValidationError}; use crate::transaction::transaction_execution::Transaction; @@ -39,12 +40,12 @@ pub enum StatefulValidatorError { pub type StatefulValidatorResult = Result; /// Manages state related transaction validations for pre-execution flows. -pub struct StatefulValidator { - tx_executor: TransactionExecutor, +pub struct StatefulValidator { + tx_executor: TransactionExecutor, } -impl StatefulValidator { - pub fn create(state: CachedState, block_context: BlockContext) -> Self { +impl StatefulValidator { + pub fn create(state: CachedState, block_context: BlockContext) -> Self { let tx_executor = TransactionExecutor::new(state, block_context, TransactionExecutorConfig::default()); Self { tx_executor } diff --git a/crates/blockifier/src/blockifier/transaction_executor.rs b/crates/blockifier/src/blockifier/transaction_executor.rs index 4cb04c6341..fc02629241 100644 --- a/crates/blockifier/src/blockifier/transaction_executor.rs +++ b/crates/blockifier/src/blockifier/transaction_executor.rs @@ -1,3 +1,4 @@ +use std::fmt::Debug; #[cfg(feature = "concurrency")] use std::panic::{self, catch_unwind, AssertUnwindSafe}; #[cfg(feature = "concurrency")] @@ -18,6 +19,7 @@ use crate::context::BlockContext; use crate::state::cached_state::{CachedState, CommitmentStateDiff, TransactionalState}; use crate::state::errors::StateError; use crate::state::state_api::StateReader; +use crate::state::visited_pcs::VisitedPcsTrait; use crate::transaction::errors::TransactionExecutionError; use crate::transaction::objects::TransactionExecutionInfo; use crate::transaction::transaction_execution::Transaction; @@ -43,7 +45,7 @@ pub type TransactionExecutorResult = Result; pub type VisitedSegmentsMapping = Vec<(ClassHash, Vec)>; // TODO(Gilad): make this hold TransactionContext instead of BlockContext. -pub struct TransactionExecutor { +pub struct TransactionExecutor { pub block_context: BlockContext, pub bouncer: Bouncer, // Note: this config must not affect the execution result (e.g. state diff and traces). @@ -54,12 +56,12 @@ pub struct TransactionExecutor { // block state to the worker executor - operating at the chunk level - and gets it back after // committing the chunk. The block state is wrapped with an Option<_> to allow setting it to // `None` while it is moved to the worker executor. - pub block_state: Option>, + pub block_state: Option>, } -impl TransactionExecutor { +impl TransactionExecutor { pub fn new( - block_state: CachedState, + block_state: CachedState, block_context: BlockContext, config: TransactionExecutorConfig, ) -> Self { @@ -85,9 +87,10 @@ impl TransactionExecutor { &mut self, tx: &Transaction, ) -> TransactionExecutorResult { - let mut transactional_state = TransactionalState::create_transactional( - self.block_state.as_mut().expect(BLOCK_STATE_ACCESS_ERR), - ); + let mut transactional_state: TransactionalState<'_, _, V> = + TransactionalState::create_transactional( + self.block_state.as_mut().expect(BLOCK_STATE_ACCESS_ERR), + ); // Executing a single transaction cannot be done in a concurrent mode. let execution_flags = ExecutionFlags { charge_fee: true, validate: true, concurrency_mode: false }; @@ -157,7 +160,8 @@ impl TransactionExecutor { .as_ref() .expect(BLOCK_STATE_ACCESS_ERR) .get_compiled_contract_class(*class_hash)?; - Ok((*class_hash, contract_class.get_visited_segments(class_visited_pcs)?)) + let class_visited_pcs = V::to_set(class_visited_pcs.clone()); + Ok((*class_hash, contract_class.get_visited_segments(&class_visited_pcs)?)) }) .collect::>()?; @@ -170,7 +174,9 @@ impl TransactionExecutor { } } -impl TransactionExecutor { +impl + TransactionExecutor +{ /// Executes the given transactions on the state maintained by the executor. /// Stops if and when there is no more room in the block, and returns the executed transactions' /// results. @@ -219,7 +225,7 @@ impl TransactionExecutor { chunk: &[Transaction], ) -> Vec> { use crate::concurrency::utils::AbortIfPanic; - use crate::state::cached_state::VisitedPcs; + use crate::concurrency::worker_logic::ExecutionTaskOutput; let block_state = self.block_state.take().expect("The block state should be `Some`."); @@ -263,20 +269,20 @@ impl TransactionExecutor { let n_committed_txs = worker_executor.scheduler.get_n_committed_txs(); let mut tx_execution_results = Vec::new(); - let mut visited_pcs: VisitedPcs = VisitedPcs::new(); + let mut visited_pcs: V = V::new(); for execution_output in worker_executor.execution_outputs.iter() { if tx_execution_results.len() >= n_committed_txs { break; } - let locked_execution_output = execution_output + let locked_execution_output: ExecutionTaskOutput = execution_output .lock() .expect("Failed to lock execution output.") .take() .expect("Output must be ready."); tx_execution_results .push(locked_execution_output.result.map_err(TransactionExecutorError::from)); - for (class_hash, class_visited_pcs) in locked_execution_output.visited_pcs { - visited_pcs.entry(class_hash).or_default().extend(class_visited_pcs); + for (class_hash, class_visited_pcs) in locked_execution_output.visited_pcs.iter() { + visited_pcs.extend(class_hash, class_visited_pcs); } } diff --git a/crates/blockifier/src/blockifier/transaction_executor_test.rs b/crates/blockifier/src/blockifier/transaction_executor_test.rs index b0139de8ba..d575603aac 100644 --- a/crates/blockifier/src/blockifier/transaction_executor_test.rs +++ b/crates/blockifier/src/blockifier/transaction_executor_test.rs @@ -13,6 +13,7 @@ use crate::bouncer::{Bouncer, BouncerWeights}; use crate::context::BlockContext; use crate::state::cached_state::CachedState; use crate::state::state_api::StateReader; +use crate::state::visited_pcs::VisitedPcsTrait; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::declare::declare_tx; use crate::test_utils::deploy_account::deploy_account_tx; @@ -30,8 +31,8 @@ use crate::transaction::transaction_execution::Transaction; use crate::transaction::transactions::L1HandlerTransaction; use crate::{declare_tx_args, deploy_account_tx_args, invoke_tx_args, nonce}; -fn tx_executor_test_body( - state: CachedState, +fn tx_executor_test_body( + state: CachedState, block_context: BlockContext, tx: Transaction, expected_bouncer_weights: BouncerWeights, diff --git a/crates/blockifier/src/bouncer_test.rs b/crates/blockifier/src/bouncer_test.rs index 9976ca0f13..d0a744704a 100644 --- a/crates/blockifier/src/bouncer_test.rs +++ b/crates/blockifier/src/bouncer_test.rs @@ -13,6 +13,7 @@ use crate::bouncer::{verify_tx_weights_in_bounds, Bouncer, BouncerWeights, Built use crate::context::BlockContext; use crate::execution::call_info::ExecutionSummary; use crate::state::cached_state::{StateChangesKeys, TransactionalState}; +use crate::state::visited_pcs::VisitedPcsSet; use crate::storage_key; use crate::test_utils::initial_test_state::test_state; use crate::transaction::errors::TransactionExecutionError; @@ -187,7 +188,8 @@ fn test_bouncer_try_update( use crate::transaction::objects::TransactionResources; let state = &mut test_state(&BlockContext::create_for_account_testing().chain_info, 0, &[]); - let mut transactional_state = TransactionalState::create_transactional(state); + let mut transactional_state: TransactionalState<'_, _, VisitedPcsSet> = + TransactionalState::create_transactional(state); // Setup the bouncer. let block_max_capacity = BouncerWeights { diff --git a/crates/blockifier/src/concurrency/fee_utils.rs b/crates/blockifier/src/concurrency/fee_utils.rs index b9ad04942e..7288cf03f7 100644 --- a/crates/blockifier/src/concurrency/fee_utils.rs +++ b/crates/blockifier/src/concurrency/fee_utils.rs @@ -10,6 +10,7 @@ use crate::execution::call_info::CallInfo; use crate::fee::fee_utils::get_sequencer_balance_keys; use crate::state::cached_state::{ContractClassMapping, StateMaps}; use crate::state::state_api::UpdatableState; +use crate::state::visited_pcs::VisitedPcsTrait; use crate::transaction::objects::TransactionExecutionInfo; #[cfg(test)] @@ -22,10 +23,10 @@ mod test; pub(crate) const STORAGE_READ_SEQUENCER_BALANCE_INDICES: (usize, usize) = (2, 3); // Completes the fee transfer flow if needed (if the transfer was made in concurrent mode). -pub fn complete_fee_transfer_flow( +pub fn complete_fee_transfer_flow>( tx_context: &TransactionContext, tx_execution_info: &mut TransactionExecutionInfo, - state: &mut impl UpdatableState, + state: &mut U, ) { if tx_context.is_sequencer_the_sender() { // When the sequencer is the sender, we use the sequential (full) fee transfer. @@ -93,9 +94,9 @@ pub fn fill_sequencer_balance_reads( storage_read_values[high_index] = high; } -pub fn add_fee_to_sequencer_balance( +pub fn add_fee_to_sequencer_balance>( fee_token_address: ContractAddress, - state: &mut impl UpdatableState, + state: &mut U, actual_fee: Fee, block_context: &BlockContext, sequencer_balance: (Felt, Felt), @@ -120,5 +121,5 @@ pub fn add_fee_to_sequencer_balance( ]), ..StateMaps::default() }; - state.apply_writes(&writes, &ContractClassMapping::default(), &HashMap::default()); + state.apply_writes(&writes, &ContractClassMapping::default(), &V::default()); } diff --git a/crates/blockifier/src/concurrency/flow_test.rs b/crates/blockifier/src/concurrency/flow_test.rs index c89644940b..4e6a7e1cac 100644 --- a/crates/blockifier/src/concurrency/flow_test.rs +++ b/crates/blockifier/src/concurrency/flow_test.rs @@ -9,9 +9,10 @@ use starknet_api::{contract_address, felt, patricia_key}; use crate::abi::sierra_types::{SierraType, SierraU128}; use crate::concurrency::scheduler::{Scheduler, Task, TransactionStatus}; use crate::concurrency::test_utils::{safe_versioned_state_for_testing, DEFAULT_CHUNK_SIZE}; -use crate::concurrency::versioned_state::ThreadSafeVersionedState; +use crate::concurrency::versioned_state::{ThreadSafeVersionedState, VersionedStateProxy}; use crate::state::cached_state::{CachedState, ContractClassMapping, StateMaps}; use crate::state::state_api::UpdatableState; +use crate::state::visited_pcs::VisitedPcsSet; use crate::storage_key; use crate::test_utils::dict_state_reader::DictStateReader; @@ -27,6 +28,9 @@ fn scheduler_flow_test( // transactions with multiple threads, where every transaction depends on its predecessor. Each // transaction sequentially advances a counter by reading the previous value and bumping it by // 1. + + use crate::concurrency::versioned_state::VersionedStateProxy; + use crate::state::visited_pcs::VisitedPcsSet; let scheduler = Arc::new(Scheduler::new(DEFAULT_CHUNK_SIZE)); let versioned_state = safe_versioned_state_for_testing(CachedState::from(DictStateReader::default())); @@ -53,7 +57,7 @@ fn scheduler_flow_test( state_proxy.apply_writes( &new_writes, &ContractClassMapping::default(), - &HashMap::default(), + &VisitedPcsSet::default(), ); scheduler.finish_execution_during_commit(tx_index); } @@ -66,13 +70,14 @@ fn scheduler_flow_test( versioned_state.pin_version(tx_index).apply_writes( &writes, &ContractClassMapping::default(), - &HashMap::default(), + &VisitedPcsSet::default(), ); scheduler.finish_execution(tx_index); Task::AskForTask } Task::ValidationTask(tx_index) => { - let state_proxy = versioned_state.pin_version(tx_index); + let state_proxy: VersionedStateProxy<_, VisitedPcsSet> = + versioned_state.pin_version(tx_index); let (reads, writes) = get_reads_writes_for(Task::ValidationTask(tx_index), &versioned_state); let read_set_valid = state_proxy.validate_reads(&reads); @@ -120,11 +125,11 @@ fn scheduler_flow_test( fn get_reads_writes_for( task: Task, - versioned_state: &ThreadSafeVersionedState>, + versioned_state: &ThreadSafeVersionedState>, ) -> (StateMaps, StateMaps) { match task { Task::ExecutionTask(tx_index) => { - let state_proxy = match tx_index { + let state_proxy: VersionedStateProxy<_, VisitedPcsSet> = match tx_index { 0 => { return ( state_maps_with_single_storage_entry(0), @@ -146,7 +151,8 @@ fn get_reads_writes_for( ) } Task::ValidationTask(tx_index) => { - let state_proxy = versioned_state.pin_version(tx_index); + let state_proxy: VersionedStateProxy<_, VisitedPcsSet> = + versioned_state.pin_version(tx_index); let tx_written_value = SierraU128::from_storage( &state_proxy, &contract_address!(CONTRACT_ADDRESS), diff --git a/crates/blockifier/src/concurrency/test_utils.rs b/crates/blockifier/src/concurrency/test_utils.rs index 87722b1171..617760e629 100644 --- a/crates/blockifier/src/concurrency/test_utils.rs +++ b/crates/blockifier/src/concurrency/test_utils.rs @@ -7,6 +7,7 @@ use crate::context::BlockContext; use crate::execution::call_info::CallInfo; use crate::state::cached_state::{CachedState, TransactionalState}; use crate::state::state_api::StateReader; +use crate::state::visited_pcs::{VisitedPcsSet, VisitedPcsTrait}; use crate::test_utils::dict_state_reader::DictStateReader; use crate::transaction::account_transaction::AccountTransaction; use crate::transaction::transactions::{ExecutableTransaction, ExecutionFlags}; @@ -61,21 +62,22 @@ macro_rules! default_scheduler { // TODO(meshi, 01/06/2024): Consider making this a macro. pub fn safe_versioned_state_for_testing( - block_state: CachedState, -) -> ThreadSafeVersionedState> { + block_state: CachedState, +) -> ThreadSafeVersionedState> { ThreadSafeVersionedState::new(VersionedState::new(block_state)) } // Utils. // Note: this function does not mutate the state. -pub fn create_fee_transfer_call_info( - state: &mut CachedState, +pub fn create_fee_transfer_call_info( + state: &mut CachedState, account_tx: &AccountTransaction, concurrency_mode: bool, ) -> CallInfo { let block_context = BlockContext::create_for_account_testing(); - let mut transactional_state = TransactionalState::create_transactional(state); + let mut transactional_state: TransactionalState<'_, _, V> = + TransactionalState::create_transactional(state); let execution_flags = ExecutionFlags { charge_fee: true, validate: true, concurrency_mode }; let execution_info = account_tx.execute_raw(&mut transactional_state, &block_context, execution_flags).unwrap(); diff --git a/crates/blockifier/src/concurrency/versioned_state.rs b/crates/blockifier/src/concurrency/versioned_state.rs index fe80fa38e3..249a66a7ce 100644 --- a/crates/blockifier/src/concurrency/versioned_state.rs +++ b/crates/blockifier/src/concurrency/versioned_state.rs @@ -1,3 +1,4 @@ +use std::marker::PhantomData; use std::sync::{Arc, Mutex, MutexGuard}; use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce}; @@ -7,9 +8,10 @@ use starknet_types_core::felt::Felt; use crate::concurrency::versioned_storage::VersionedStorage; use crate::concurrency::TxIndex; use crate::execution::contract_class::ContractClass; -use crate::state::cached_state::{ContractClassMapping, StateMaps, VisitedPcs}; +use crate::state::cached_state::{ContractClassMapping, StateMaps}; use crate::state::errors::StateError; use crate::state::state_api::{StateReader, StateResult, UpdatableState}; +use crate::state::visited_pcs::VisitedPcsTrait; #[cfg(test)] #[path = "versioned_state_test.rs"] @@ -197,11 +199,11 @@ impl VersionedState { } } -impl VersionedState { +impl> VersionedState { pub fn commit_chunk_and_recover_block_state( mut self, n_committed_txs: usize, - visited_pcs: VisitedPcs, + visited_pcs: V, ) -> U { if n_committed_txs == 0 { return self.into_initial_state(); @@ -228,8 +230,8 @@ impl ThreadSafeVersionedState { ThreadSafeVersionedState(Arc::new(Mutex::new(versioned_state))) } - pub fn pin_version(&self, tx_index: TxIndex) -> VersionedStateProxy { - VersionedStateProxy { tx_index, state: self.0.clone() } + pub fn pin_version(&self, tx_index: TxIndex) -> VersionedStateProxy { + VersionedStateProxy { tx_index, state: self.0.clone(), _marker: PhantomData } } pub fn into_inner_state(self) -> VersionedState { @@ -251,12 +253,13 @@ impl Clone for ThreadSafeVersionedState { } } -pub struct VersionedStateProxy { +pub struct VersionedStateProxy { pub tx_index: TxIndex, pub state: Arc>>, + _marker: PhantomData, } -impl VersionedStateProxy { +impl VersionedStateProxy { fn state(&self) -> LockedVersionedState<'_, S> { self.state.lock().expect("Failed to acquire state lock.") } @@ -271,18 +274,20 @@ impl VersionedStateProxy { } // TODO(Noa, 15/5/24): Consider using visited_pcs. -impl UpdatableState for VersionedStateProxy { +impl UpdatableState for VersionedStateProxy { + type T = V; + fn apply_writes( &mut self, writes: &StateMaps, class_hash_to_class: &ContractClassMapping, - _visited_pcs: &VisitedPcs, + _visited_pcs: &V, ) { self.state().apply_writes(self.tx_index, writes, class_hash_to_class) } } -impl StateReader for VersionedStateProxy { +impl StateReader for VersionedStateProxy { fn get_storage_at( &self, contract_address: ContractAddress, diff --git a/crates/blockifier/src/concurrency/versioned_state_test.rs b/crates/blockifier/src/concurrency/versioned_state_test.rs index ab99698b4e..d0a5624958 100644 --- a/crates/blockifier/src/concurrency/versioned_state_test.rs +++ b/crates/blockifier/src/concurrency/versioned_state_test.rs @@ -24,6 +24,7 @@ use crate::state::cached_state::{ }; use crate::state::errors::StateError; use crate::state::state_api::{State, StateReader, UpdatableState}; +use crate::state::visited_pcs::{VisitedPcsSet, VisitedPcsTrait}; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::deploy_account::deploy_account_tx; use crate::test_utils::dict_state_reader::DictStateReader; @@ -39,7 +40,7 @@ use crate::{compiled_class_hash, deploy_account_tx_args, nonce, storage_key}; pub fn safe_versioned_state( contract_address: ContractAddress, class_hash: ClassHash, -) -> ThreadSafeVersionedState> { +) -> ThreadSafeVersionedState> { let init_state = DictStateReader { address_to_class_hash: HashMap::from([(contract_address, class_hash)]), ..Default::default() @@ -72,8 +73,9 @@ fn test_versioned_state_proxy() { let versioned_state = Arc::new(Mutex::new(VersionedState::new(cached_state))); let safe_versioned_state = ThreadSafeVersionedState(Arc::clone(&versioned_state)); - let versioned_state_proxys: Vec>> = - (0..20).map(|i| safe_versioned_state.pin_version(i)).collect(); + let versioned_state_proxys: Vec< + VersionedStateProxy, VisitedPcsSet>, + > = (0..20).map(|i| safe_versioned_state.pin_version(i)).collect(); // Read initial data assert_eq!(versioned_state_proxys[5].get_nonce_at(contract_address).unwrap(), nonce); @@ -208,10 +210,14 @@ fn test_run_parallel_txs(max_resource_bounds: ResourceBoundsMapping) { )))); let safe_versioned_state = ThreadSafeVersionedState(Arc::clone(&versioned_state)); - let mut versioned_state_proxy_1 = safe_versioned_state.pin_version(1); - let mut state_1 = TransactionalState::create_transactional(&mut versioned_state_proxy_1); - let mut versioned_state_proxy_2 = safe_versioned_state.pin_version(2); - let mut state_2 = TransactionalState::create_transactional(&mut versioned_state_proxy_2); + let mut versioned_state_proxy_1: VersionedStateProxy<_, VisitedPcsSet> = + safe_versioned_state.pin_version(1); + let mut state_1: TransactionalState<'_, _, VisitedPcsSet> = + TransactionalState::create_transactional(&mut versioned_state_proxy_1); + let mut versioned_state_proxy_2: VersionedStateProxy<_, VisitedPcsSet> = + safe_versioned_state.pin_version(2); + let mut state_2: TransactionalState<'_, _, VisitedPcsSet> = + TransactionalState::create_transactional(&mut versioned_state_proxy_2); // Prepare transactions let deploy_account_tx_1 = deploy_account_tx( @@ -248,14 +254,29 @@ fn test_run_parallel_txs(max_resource_bounds: ResourceBoundsMapping) { let block_context_1 = block_context.clone(); let block_context_2 = block_context.clone(); + // Execute transactions thread::scope(|s| { s.spawn(move || { - let result = account_tx_1.execute(&mut state_1, &block_context_1, true, true); + let result = >::execute( + &account_tx_1, + &mut state_1, + &block_context_1, + true, + true, + ); + assert_eq!(result.is_err(), enforce_fee); }); s.spawn(move || { - account_tx_2.execute(&mut state_2, &block_context_2, true, true).unwrap(); + >::execute( + &account_tx_2, + &mut state_2, + &block_context_2, + true, + true, + ) + .unwrap(); // Check that the constructor wrote ctor_arg to the storage. let storage_key = get_storage_var_address("ctor_arg", &[]); let deployed_contract_address = calculate_contract_address( @@ -276,15 +297,19 @@ fn test_run_parallel_txs(max_resource_bounds: ResourceBoundsMapping) { fn test_validate_reads( contract_address: ContractAddress, class_hash: ClassHash, - safe_versioned_state: ThreadSafeVersionedState>, + safe_versioned_state: ThreadSafeVersionedState>, ) { let storage_key = storage_key!(0x10_u8); - let mut version_state_proxy = safe_versioned_state.pin_version(1); - let transactional_state = TransactionalState::create_transactional(&mut version_state_proxy); + let mut version_state_proxy: VersionedStateProxy<_, VisitedPcsSet> = + safe_versioned_state.pin_version(1); + let transactional_state: TransactionalState<'_, _, VisitedPcsSet> = + TransactionalState::create_transactional(&mut version_state_proxy); // Validating tx index 0 always succeeds. - assert!(safe_versioned_state.pin_version(0).validate_reads(&StateMaps::default())); + assert!( + safe_versioned_state.pin_version::(0).validate_reads(&StateMaps::default()) + ); assert!(transactional_state.cache.borrow().initial_reads.storage.is_empty()); transactional_state.get_storage_at(contract_address, storage_key).unwrap(); @@ -313,7 +338,7 @@ fn test_validate_reads( assert!( safe_versioned_state - .pin_version(1) + .pin_version::(1) .validate_reads(&transactional_state.cache.borrow().initial_reads) ); } @@ -366,16 +391,17 @@ fn test_validate_reads( fn test_false_validate_reads( #[case] tx_1_reads: StateMaps, #[case] tx_0_writes: StateMaps, - safe_versioned_state: ThreadSafeVersionedState>, + safe_versioned_state: ThreadSafeVersionedState>, ) { - let version_state_proxy = safe_versioned_state.pin_version(0); + let version_state_proxy: VersionedStateProxy<_, VisitedPcsSet> = + safe_versioned_state.pin_version(0); version_state_proxy.state().apply_writes(0, &tx_0_writes, &HashMap::default()); - assert!(!safe_versioned_state.pin_version(1).validate_reads(&tx_1_reads)); + assert!(!safe_versioned_state.pin_version::(1).validate_reads(&tx_1_reads)); } #[rstest] fn test_false_validate_reads_declared_contracts( - safe_versioned_state: ThreadSafeVersionedState>, + safe_versioned_state: ThreadSafeVersionedState>, ) { let tx_1_reads = StateMaps { declared_contracts: HashMap::from([(class_hash!(1_u8), false)]), @@ -385,24 +411,24 @@ fn test_false_validate_reads_declared_contracts( declared_contracts: HashMap::from([(class_hash!(1_u8), true)]), ..Default::default() }; - let version_state_proxy = safe_versioned_state.pin_version(0); + let version_state_proxy: VersionedStateProxy<_, VisitedPcsSet> = + safe_versioned_state.pin_version(0); let compiled_contract_calss = FeatureContract::TestContract(CairoVersion::Cairo1).get_class(); let class_hash_to_class = HashMap::from([(class_hash!(1_u8), compiled_contract_calss)]); version_state_proxy.state().apply_writes(0, &tx_0_writes, &class_hash_to_class); - assert!(!safe_versioned_state.pin_version(1).validate_reads(&tx_1_reads)); + assert!(!safe_versioned_state.pin_version::(1).validate_reads(&tx_1_reads)); } #[rstest] fn test_apply_writes( contract_address: ContractAddress, class_hash: ClassHash, - safe_versioned_state: ThreadSafeVersionedState>, + safe_versioned_state: ThreadSafeVersionedState>, ) { - let mut versioned_proxy_states: Vec>> = + let mut versioned_proxy_states: Vec> = (0..2).map(|i| safe_versioned_state.pin_version(i)).collect(); - let mut transactional_states: Vec< - TransactionalState<'_, VersionedStateProxy>>, - > = versioned_proxy_states.iter_mut().map(TransactionalState::create_transactional).collect(); + let mut transactional_states: Vec> = + versioned_proxy_states.iter_mut().map(TransactionalState::create_transactional).collect(); // Transaction 0 class hash. let class_hash_0 = class_hash!(76_u8); @@ -419,7 +445,7 @@ fn test_apply_writes( safe_versioned_state.pin_version(0).apply_writes( &transactional_states[0].cache.borrow().writes, &transactional_states[0].class_hash_to_class.borrow().clone(), - &HashMap::default(), + &VisitedPcsSet::default(), ); assert!(transactional_states[1].get_class_hash_at(contract_address).unwrap() == class_hash_0); assert!( @@ -432,13 +458,12 @@ fn test_apply_writes( fn test_apply_writes_reexecute_scenario( contract_address: ContractAddress, class_hash: ClassHash, - safe_versioned_state: ThreadSafeVersionedState>, + safe_versioned_state: ThreadSafeVersionedState>, ) { - let mut versioned_proxy_states: Vec>> = + let mut versioned_proxy_states: Vec> = (0..2).map(|i| safe_versioned_state.pin_version(i)).collect(); - let mut transactional_states: Vec< - TransactionalState<'_, VersionedStateProxy>>, - > = versioned_proxy_states.iter_mut().map(TransactionalState::create_transactional).collect(); + let mut transactional_states: Vec> = + versioned_proxy_states.iter_mut().map(TransactionalState::create_transactional).collect(); // Transaction 0 class hash. let class_hash_0 = class_hash!(76_u8); @@ -451,7 +476,7 @@ fn test_apply_writes_reexecute_scenario( safe_versioned_state.pin_version(0).apply_writes( &transactional_states[0].cache.borrow().writes, &transactional_states[0].class_hash_to_class.borrow().clone(), - &HashMap::default(), + &VisitedPcsSet::default(), ); // Although transaction 0 wrote to the shared state, version 1 needs to be re-executed to see // the new value (its read value has already been cached). @@ -468,14 +493,13 @@ fn test_apply_writes_reexecute_scenario( #[rstest] fn test_delete_writes( #[values(0, 1, 2)] tx_index_to_delete_writes: TxIndex, - safe_versioned_state: ThreadSafeVersionedState>, + safe_versioned_state: ThreadSafeVersionedState>, ) { let num_of_txs = 3; - let mut versioned_proxy_states: Vec>> = + let mut versioned_proxy_states: Vec> = (0..num_of_txs).map(|i| safe_versioned_state.pin_version(i)).collect(); - let mut transactional_states: Vec< - TransactionalState<'_, VersionedStateProxy>>, - > = versioned_proxy_states.iter_mut().map(TransactionalState::create_transactional).collect(); + let mut transactional_states: Vec> = + versioned_proxy_states.iter_mut().map(TransactionalState::create_transactional).collect(); // Setting 2 instances of the contract to ensure `delete_writes` removes information from // multiple keys. Class hash values are not checked in this test. @@ -496,11 +520,11 @@ fn test_delete_writes( safe_versioned_state.pin_version(i).apply_writes( &tx_state.cache.borrow().writes, &tx_state.class_hash_to_class.borrow(), - &HashMap::default(), + &VisitedPcsSet::default(), ); } - safe_versioned_state.pin_version(tx_index_to_delete_writes).delete_writes( + safe_versioned_state.pin_version::(tx_index_to_delete_writes).delete_writes( &transactional_states[tx_index_to_delete_writes].cache.borrow().writes, &transactional_states[tx_index_to_delete_writes].class_hash_to_class.borrow(), ); @@ -533,7 +557,7 @@ fn test_delete_writes( #[rstest] fn test_delete_writes_completeness( - safe_versioned_state: ThreadSafeVersionedState>, + safe_versioned_state: ThreadSafeVersionedState>, ) { let feature_contract = FeatureContract::TestContract(CairoVersion::Cairo1); let state_maps_writes = StateMaps { @@ -558,7 +582,7 @@ fn test_delete_writes_completeness( versioned_state_proxy.apply_writes( &state_maps_writes, &class_hash_to_class_writes, - &HashMap::default(), + &VisitedPcsSet::default(), ); assert_eq!( safe_versioned_state.0.lock().unwrap().get_writes_of_index(tx_index), @@ -592,15 +616,16 @@ fn test_delete_writes_completeness( #[rstest] fn test_versioned_proxy_state_flow( - safe_versioned_state: ThreadSafeVersionedState>, + safe_versioned_state: ThreadSafeVersionedState>, ) { let contract_address = contract_address!("0x1"); let class_hash = ClassHash(felt!(27_u8)); - let mut versioned_proxy_states: Vec>> = + let mut versioned_proxy_states: Vec> = (0..4).map(|i| safe_versioned_state.pin_version(i)).collect(); - let mut transactional_states = Vec::with_capacity(4); + let mut transactional_states: Vec> = + Vec::with_capacity(4); for proxy_state in &mut versioned_proxy_states { transactional_states.push(TransactionalState::create_transactional(proxy_state)); } @@ -635,7 +660,7 @@ fn test_versioned_proxy_state_flow( } let modified_block_state = safe_versioned_state .into_inner_state() - .commit_chunk_and_recover_block_state(4, HashMap::new()); + .commit_chunk_and_recover_block_state(4, VisitedPcsSet::new()); assert!(modified_block_state.get_class_hash_at(contract_address).unwrap() == class_hash_3); assert!( diff --git a/crates/blockifier/src/concurrency/worker_logic.rs b/crates/blockifier/src/concurrency/worker_logic.rs index 150af07e9e..d1ad3b2614 100644 --- a/crates/blockifier/src/concurrency/worker_logic.rs +++ b/crates/blockifier/src/concurrency/worker_logic.rs @@ -14,9 +14,10 @@ use crate::concurrency::versioned_state::ThreadSafeVersionedState; use crate::concurrency::TxIndex; use crate::context::BlockContext; use crate::state::cached_state::{ - ContractClassMapping, StateChanges, StateMaps, TransactionalState, VisitedPcs, + ContractClassMapping, StateChanges, StateMaps, TransactionalState, }; use crate::state::state_api::{StateReader, UpdatableState}; +use crate::state::visited_pcs::VisitedPcsTrait; use crate::transaction::objects::{TransactionExecutionInfo, TransactionExecutionResult}; use crate::transaction::transaction_execution::Transaction; use crate::transaction::transactions::{ExecutableTransaction, ExecutionFlags}; @@ -28,23 +29,23 @@ pub mod test; const EXECUTION_OUTPUTS_UNWRAP_ERROR: &str = "Execution task outputs should not be None."; #[derive(Debug)] -pub struct ExecutionTaskOutput { +pub struct ExecutionTaskOutput { pub reads: StateMaps, pub writes: StateMaps, pub contract_classes: ContractClassMapping, - pub visited_pcs: VisitedPcs, + pub visited_pcs: V, pub result: TransactionExecutionResult, } -pub struct WorkerExecutor<'a, S: StateReader> { +pub struct WorkerExecutor<'a, S: StateReader, V: VisitedPcsTrait> { pub scheduler: Scheduler, pub state: ThreadSafeVersionedState, pub chunk: &'a [Transaction], - pub execution_outputs: Box<[Mutex>]>, + pub execution_outputs: Box<[Mutex>>]>, pub block_context: &'a BlockContext, pub bouncer: Mutex<&'a mut Bouncer>, } -impl<'a, S: StateReader> WorkerExecutor<'a, S> { +impl<'a, S: StateReader, V: VisitedPcsTrait + Default + Debug> WorkerExecutor<'a, S, V> { pub fn new( state: ThreadSafeVersionedState, chunk: &'a [Transaction], @@ -135,7 +136,7 @@ impl<'a, S: StateReader> WorkerExecutor<'a, S> { self.state.pin_version(tx_index).apply_writes( &transactional_state.cache.borrow().writes, &transactional_state.class_hash_to_class.borrow(), - &HashMap::default(), + &V::default(), ); } @@ -145,7 +146,7 @@ impl<'a, S: StateReader> WorkerExecutor<'a, S> { // In case of a failed transaction, we don't record its writes and visited pcs. let (writes, contract_classes, visited_pcs) = match execution_result { Ok(_) => (tx_reads_writes.writes, class_hash_to_class, transactional_state.visited_pcs), - Err(_) => (StateMaps::default(), HashMap::default(), HashMap::default()), + Err(_) => (StateMaps::default(), HashMap::default(), V::default()), }; let mut execution_output = lock_mutex_in_array(&self.execution_outputs, tx_index); *execution_output = Some(ExecutionTaskOutput { @@ -158,7 +159,7 @@ impl<'a, S: StateReader> WorkerExecutor<'a, S> { } fn validate(&self, tx_index: TxIndex) -> Task { - let tx_versioned_state = self.state.pin_version(tx_index); + let tx_versioned_state = self.state.pin_version::(tx_index); let execution_output = lock_mutex_in_array(&self.execution_outputs, tx_index); let execution_output = execution_output.as_ref().expect(EXECUTION_OUTPUTS_UNWRAP_ERROR); let reads = &execution_output.reads; @@ -191,7 +192,7 @@ impl<'a, S: StateReader> WorkerExecutor<'a, S> { let execution_output_ref = execution_output.as_ref().expect(EXECUTION_OUTPUTS_UNWRAP_ERROR); let reads = &execution_output_ref.reads; - let mut tx_versioned_state = self.state.pin_version(tx_index); + let mut tx_versioned_state = self.state.pin_version::(tx_index); let reads_valid = tx_versioned_state.validate_reads(reads); // First, re-validate the transaction. @@ -258,12 +259,8 @@ impl<'a, S: StateReader> WorkerExecutor<'a, S> { } } -impl<'a, U: UpdatableState> WorkerExecutor<'a, U> { - pub fn commit_chunk_and_recover_block_state( - self, - n_committed_txs: usize, - visited_pcs: VisitedPcs, - ) -> U { +impl<'a, V: VisitedPcsTrait, U: UpdatableState> WorkerExecutor<'a, U, V> { + pub fn commit_chunk_and_recover_block_state(self, n_committed_txs: usize, visited_pcs: V) -> U { self.state .into_inner_state() .commit_chunk_and_recover_block_state(n_committed_txs, visited_pcs) diff --git a/crates/blockifier/src/concurrency/worker_logic_test.rs b/crates/blockifier/src/concurrency/worker_logic_test.rs index 1f28b1ee4f..da61c09aab 100644 --- a/crates/blockifier/src/concurrency/worker_logic_test.rs +++ b/crates/blockifier/src/concurrency/worker_logic_test.rs @@ -22,6 +22,7 @@ use crate::context::{BlockContext, TransactionContext}; use crate::fee::fee_utils::get_sequencer_balance_keys; use crate::state::cached_state::StateMaps; use crate::state::state_api::StateReader; +use crate::state::visited_pcs::VisitedPcsSet; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::declare::declare_tx; use crate::test_utils::initial_test_state::test_state; @@ -61,7 +62,7 @@ fn verify_sequencer_balance_update( expected_sequencer_balance_low: u128, ) { let TransactionContext { block_context, tx_info } = tx_context; - let tx_version_state = state.pin_version(tx_index); + let tx_version_state = state.pin_version::(tx_index); let (sequencer_balance_key_low, sequencer_balance_key_high) = get_sequencer_balance_keys(block_context); for (expected_balance, storage_key) in [ @@ -105,7 +106,7 @@ pub fn test_commit_tx() { let cached_state = test_state(&block_context.chain_info, BALANCE, &[(account, 1), (test_contract, 1)]); let versioned_state = safe_versioned_state_for_testing(cached_state); - let executor = + let executor: WorkerExecutor<'_, _, VisitedPcsSet> = WorkerExecutor::new(versioned_state, &txs, &block_context, Mutex::new(&mut bouncer)); // Execute transactions. @@ -205,14 +206,14 @@ fn test_commit_tx_when_sender_is_sequencer() { let state = test_state(&block_context.chain_info, BALANCE, &[(account, 1), (test_contract, 1)]); let versioned_state = safe_versioned_state_for_testing(state); - let executor = WorkerExecutor::new( + let executor: WorkerExecutor<'_, _, VisitedPcsSet> = WorkerExecutor::new( versioned_state, &sequencer_tx, &block_context, Mutex::new(&mut bouncer), ); let tx_index = 0; - let tx_versioned_state = executor.state.pin_version(tx_index); + let tx_versioned_state = executor.state.pin_version::(tx_index); // Execute and save the execution result. executor.execute_tx(tx_index); @@ -312,7 +313,7 @@ fn test_worker_execute(max_resource_bounds: ResourceBoundsMapping) { .collect::>(); let mut bouncer = Bouncer::new(block_context.bouncer_config.clone()); - let worker_executor = WorkerExecutor::new( + let worker_executor: WorkerExecutor<'_, _, VisitedPcsSet> = WorkerExecutor::new( safe_versioned_state.clone(), &txs, &block_context, @@ -330,7 +331,7 @@ fn test_worker_execute(max_resource_bounds: ResourceBoundsMapping) { // Read a write made by the transaction. assert_eq!( safe_versioned_state - .pin_version(tx_index) + .pin_version::(tx_index) .get_storage_at(test_contract_address, storage_key) .unwrap(), storage_value @@ -383,14 +384,17 @@ fn test_worker_execute(max_resource_bounds: ResourceBoundsMapping) { assert_eq!(execution_output.writes, writes); assert_eq!(execution_output.reads, reads); - assert_ne!(execution_output.visited_pcs, HashMap::default()); + assert_ne!(execution_output.visited_pcs, VisitedPcsSet::default()); // Failed execution. let tx_index = 1; worker_executor.execute(tx_index); // No write was made by the transaction. assert_eq!( - safe_versioned_state.pin_version(tx_index).get_nonce_at(account_address).unwrap(), + safe_versioned_state + .pin_version::(tx_index) + .get_nonce_at(account_address) + .unwrap(), nonce!(1_u8) ); let execution_output = worker_executor.execution_outputs[tx_index].lock().unwrap(); @@ -402,21 +406,24 @@ fn test_worker_execute(max_resource_bounds: ResourceBoundsMapping) { }; assert_eq!(execution_output.reads, reads); assert_eq!(execution_output.writes, StateMaps::default()); - assert_eq!(execution_output.visited_pcs, HashMap::default()); + assert_eq!(execution_output.visited_pcs, VisitedPcsSet::default()); // Reverted execution. let tx_index = 2; worker_executor.execute(tx_index); // Read a write made by the transaction. assert_eq!( - safe_versioned_state.pin_version(tx_index).get_nonce_at(account_address).unwrap(), + safe_versioned_state + .pin_version::(tx_index) + .get_nonce_at(account_address) + .unwrap(), nonce!(2_u8) ); let execution_output = worker_executor.execution_outputs[tx_index].lock().unwrap(); let execution_output = execution_output.as_ref().unwrap(); assert!(execution_output.result.as_ref().unwrap().is_reverted()); assert_ne!(execution_output.writes, StateMaps::default()); - assert_ne!(execution_output.visited_pcs, HashMap::default()); + assert_ne!(execution_output.visited_pcs, VisitedPcsSet::default()); // Validate status change. for tx_index in 0..3 { @@ -474,7 +481,7 @@ fn test_worker_validate(max_resource_bounds: ResourceBoundsMapping) { .collect::>(); let mut bouncer = Bouncer::new(block_context.bouncer_config.clone()); - let worker_executor = WorkerExecutor::new( + let worker_executor: WorkerExecutor<'_, _, VisitedPcsSet> = WorkerExecutor::new( safe_versioned_state.clone(), &txs, &block_context, @@ -500,7 +507,7 @@ fn test_worker_validate(max_resource_bounds: ResourceBoundsMapping) { // Verify writes exist in state. assert_eq!( safe_versioned_state - .pin_version(tx_index) + .pin_version::(tx_index) .get_storage_at(test_contract_address, storage_key) .unwrap(), storage_value0 @@ -515,7 +522,7 @@ fn test_worker_validate(max_resource_bounds: ResourceBoundsMapping) { // Verify writes were removed. assert_eq!( safe_versioned_state - .pin_version(tx_index) + .pin_version::(tx_index) .get_storage_at(test_contract_address, storage_key) .unwrap(), storage_value0 @@ -587,7 +594,7 @@ fn test_deploy_before_declare( .collect::>(); let mut bouncer = Bouncer::new(block_context.bouncer_config.clone()); - let worker_executor = + let worker_executor: WorkerExecutor<'_, _, VisitedPcsSet> = WorkerExecutor::new(safe_versioned_state, &txs, &block_context, Mutex::new(&mut bouncer)); // Creates 2 active tasks. @@ -659,7 +666,7 @@ fn test_worker_commit_phase(max_resource_bounds: ResourceBoundsMapping) { .collect::>(); let mut bouncer = Bouncer::new(block_context.bouncer_config.clone()); - let worker_executor = + let worker_executor: WorkerExecutor<'_, _, VisitedPcsSet> = WorkerExecutor::new(safe_versioned_state, &txs, &block_context, Mutex::new(&mut bouncer)); // Try to commit before any transaction is ready. @@ -749,7 +756,7 @@ fn test_worker_commit_phase_with_halt() { .collect::>(); let mut bouncer = Bouncer::new(block_context.bouncer_config.clone()); - let worker_executor = + let worker_executor: WorkerExecutor<'_, _, VisitedPcsSet> = WorkerExecutor::new(safe_versioned_state, &txs, &block_context, Mutex::new(&mut bouncer)); // Creates 2 active tasks. diff --git a/crates/blockifier/src/execution/contract_address_test.rs b/crates/blockifier/src/execution/contract_address_test.rs index 405360da1e..55b65cc3c5 100644 --- a/crates/blockifier/src/execution/contract_address_test.rs +++ b/crates/blockifier/src/execution/contract_address_test.rs @@ -9,6 +9,7 @@ use crate::execution::call_info::{CallExecution, Retdata}; use crate::execution::entry_point::CallEntryPoint; use crate::retdata; use crate::state::cached_state::CachedState; +use crate::state::visited_pcs::VisitedPcsSet; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::dict_state_reader::DictStateReader; use crate::test_utils::initial_test_state::test_state; @@ -27,7 +28,7 @@ fn test_calculate_contract_address() { constructor_calldata: &Calldata, calldata: Calldata, deployer_address: ContractAddress, - state: &mut CachedState, + state: &mut CachedState, ) { let versioned_constants = VersionedConstants::create_for_testing(); let entry_point_call = CallEntryPoint { diff --git a/crates/blockifier/src/execution/entry_point_execution.rs b/crates/blockifier/src/execution/entry_point_execution.rs index e63b4085d5..147d2caef9 100644 --- a/crates/blockifier/src/execution/entry_point_execution.rs +++ b/crates/blockifier/src/execution/entry_point_execution.rs @@ -1,5 +1,3 @@ -use std::collections::HashSet; - use cairo_vm::types::builtin_name::BuiltinName; use cairo_vm::types::layout_name::LayoutName; use cairo_vm::types::relocatable::{MaybeRelocatable, Relocatable}; @@ -114,7 +112,7 @@ fn register_visited_pcs( program_segment_size: usize, bytecode_length: usize, ) -> EntryPointExecutionResult<()> { - let mut class_visited_pcs = HashSet::new(); + let mut class_visited_pcs = Vec::new(); // Relocate the trace, putting the program segment at address 1 and the execution segment right // after it. // TODO(lior): Avoid unnecessary relocation once the VM has a non-relocated `get_trace()` @@ -131,7 +129,7 @@ fn register_visited_pcs( // Jumping to a PC that is not inside the bytecode is possible. For example, to obtain // the builtin costs. Filter out these values. if real_pc < bytecode_length { - class_visited_pcs.insert(real_pc); + class_visited_pcs.push(real_pc); } } state.add_visited_pcs(class_hash, &class_visited_pcs); diff --git a/crates/blockifier/src/execution/entry_point_test.rs b/crates/blockifier/src/execution/entry_point_test.rs index 07abce7eb9..a9610f9b15 100644 --- a/crates/blockifier/src/execution/entry_point_test.rs +++ b/crates/blockifier/src/execution/entry_point_test.rs @@ -12,6 +12,7 @@ use crate::context::ChainInfo; use crate::execution::call_info::{CallExecution, CallInfo, Retdata}; use crate::execution::entry_point::CallEntryPoint; use crate::state::cached_state::CachedState; +use crate::state::visited_pcs::VisitedPcsSet; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::dict_state_reader::DictStateReader; use crate::test_utils::initial_test_state::test_state; @@ -187,7 +188,7 @@ fn test_storage_var() { /// Runs test scenarios that could fail the OS run and therefore must be caught in the Blockifier. fn run_security_test( - state: &mut CachedState, + state: &mut CachedState, security_contract: FeatureContract, expected_error: &str, entry_point_name: &str, diff --git a/crates/blockifier/src/execution/stack_trace_test.rs b/crates/blockifier/src/execution/stack_trace_test.rs index 1fcb2106c1..5217be082f 100644 --- a/crates/blockifier/src/execution/stack_trace_test.rs +++ b/crates/blockifier/src/execution/stack_trace_test.rs @@ -513,7 +513,14 @@ Execution failed. Failure reason: 0x496e76616c6964207363656e6172696f ('Invalid s // Clean pc locations from the trace. let re = Regex::new(r"pc=0:[0-9]+").unwrap(); let cleaned_expected_error = &re.replace_all(&expected_error, "pc=0:*"); - let actual_error = account_tx.execute(state, block_context, true, true).unwrap_err(); + let actual_error = >::execute( + &account_tx, + state, + block_context, + true, + true, + ) + .unwrap_err(); let actual_error_str = actual_error.to_string(); let cleaned_actual_error = &re.replace_all(&actual_error_str, "pc=0:*"); // Compare actual trace to the expected trace (sans pc locations). @@ -576,7 +583,14 @@ An ASSERT_EQ instruction failed: 1 != 0. }; // Compare expected and actual error. - let error = deploy_account_tx.execute(state, &block_context, true, true).unwrap_err(); + let error = >::execute( + &deploy_account_tx, + state, + &block_context, + true, + true, + ) + .unwrap_err(); assert_eq!(error.to_string(), expected_error); } @@ -708,7 +722,15 @@ Execution failed. Failure reason: 0x496e76616c6964207363656e6172696f ('Invalid s }; // Compare expected and actual error. - let error = - invoke_deploy_tx.execute(state, &block_context, true, true).unwrap().revert_error.unwrap(); + let error = >::execute( + &invoke_deploy_tx, + state, + &block_context, + true, + true, + ) + .unwrap() + .revert_error + .unwrap(); assert_eq!(error.to_string(), expected_error); } diff --git a/crates/blockifier/src/fee/actual_cost_test.rs b/crates/blockifier/src/fee/actual_cost_test.rs index 1db2f14c56..9fa42b1cfb 100644 --- a/crates/blockifier/src/fee/actual_cost_test.rs +++ b/crates/blockifier/src/fee/actual_cost_test.rs @@ -284,6 +284,8 @@ fn test_calculate_tx_gas_usage( max_resource_bounds: ResourceBoundsMapping, #[values(false, true)] use_kzg_da: bool, ) { + use crate::transaction::account_transaction::AccountTransaction; + let account_cairo_version = CairoVersion::Cairo0; let test_contract_cairo_version = CairoVersion::Cairo0; let block_context = &BlockContext::create_for_account_testing_with_kzg(use_kzg_da); @@ -302,7 +304,14 @@ fn test_calculate_tx_gas_usage( let calldata_length = account_tx.calldata_length(); let signature_length = account_tx.signature_length(); let fee_token_address = chain_info.fee_token_address(&account_tx.fee_type()); - let tx_execution_info = account_tx.execute(state, block_context, true, true).unwrap(); + let tx_execution_info = >::execute( + &account_tx, + state, + block_context, + true, + true, + ) + .unwrap(); let n_storage_updates = 1; // For the account balance update. let n_modified_contracts = 1; @@ -351,7 +360,14 @@ fn test_calculate_tx_gas_usage( let calldata_length = account_tx.calldata_length(); let signature_length = account_tx.signature_length(); - let tx_execution_info = account_tx.execute(state, block_context, true, true).unwrap(); + let tx_execution_info = >::execute( + &account_tx, + state, + block_context, + true, + true, + ) + .unwrap(); // For the balance update of the sender and the recipient. let n_storage_updates = 2; // Only the account contract modification (nonce update) excluding the fee token contract. diff --git a/crates/blockifier/src/state.rs b/crates/blockifier/src/state.rs index e027d2b301..3bef337429 100644 --- a/crates/blockifier/src/state.rs +++ b/crates/blockifier/src/state.rs @@ -4,3 +4,4 @@ pub mod error_format_test; pub mod errors; pub mod global_cache; pub mod state_api; +pub mod visited_pcs; diff --git a/crates/blockifier/src/state/cached_state.rs b/crates/blockifier/src/state/cached_state.rs index 74ba043d63..6dcdcc62b1 100644 --- a/crates/blockifier/src/state/cached_state.rs +++ b/crates/blockifier/src/state/cached_state.rs @@ -7,6 +7,7 @@ use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce}; use starknet_api::state::StorageKey; use starknet_types_core::felt::Felt; +use super::visited_pcs::{VisitedPcsSet, VisitedPcsTrait}; use crate::abi::abi_utils::get_fee_token_var_address; use crate::context::TransactionContext; use crate::execution::contract_class::ContractClass; @@ -21,30 +22,28 @@ mod test; pub type ContractClassMapping = HashMap; -pub type VisitedPcs = HashMap>; - /// Caches read and write requests. /// /// Writer functionality is builtin, whereas Reader functionality is injected through /// initialization. #[derive(Debug)] -pub struct CachedState { +pub struct CachedState { pub state: S, // Invariant: read/write access is managed by CachedState. // Using interior mutability to update caches during `State`'s immutable getters. pub(crate) cache: RefCell, pub(crate) class_hash_to_class: RefCell, /// A map from class hash to the set of PC values that were visited in the class. - pub visited_pcs: VisitedPcs, + pub visited_pcs: V, } -impl CachedState { +impl CachedState { pub fn new(state: S) -> Self { Self { state, cache: RefCell::new(StateCache::default()), class_hash_to_class: RefCell::new(HashMap::default()), - visited_pcs: VisitedPcs::default(), + visited_pcs: V::default(), } } @@ -75,9 +74,10 @@ impl CachedState { self.class_hash_to_class.get_mut().extend(local_contract_cache_updates); } - pub fn update_visited_pcs_cache(&mut self, visited_pcs: &VisitedPcs) { - for (class_hash, class_visited_pcs) in visited_pcs { - self.add_visited_pcs(*class_hash, class_visited_pcs); + pub fn update_visited_pcs_cache(&mut self, visited_pcs: &V) { + for (class_hash, class_visited_pcs) in visited_pcs.iter() { + let vec_visited_pcs = V::to_vec(class_visited_pcs.clone()); //Vec::from_iter(*class_visited_pcs.clone()); + self.add_visited_pcs(*class_hash, &vec_visited_pcs); } } @@ -109,12 +109,14 @@ impl CachedState { } } -impl UpdatableState for CachedState { +impl UpdatableState for CachedState { + type T = V; + fn apply_writes( &mut self, writes: &StateMaps, class_hash_to_class: &ContractClassMapping, - visited_pcs: &VisitedPcs, + visited_pcs: &V, ) { // TODO(Noa,15/5/24): Reconsider the clone. self.update_cache(writes, class_hash_to_class.clone()); @@ -123,13 +125,13 @@ impl UpdatableState for CachedState { } #[cfg(any(feature = "testing", test))] -impl From for CachedState { +impl From for CachedState { fn from(state_reader: S) -> Self { CachedState::new(state_reader) } } -impl StateReader for CachedState { +impl StateReader for CachedState { fn get_storage_at( &self, contract_address: ContractAddress, @@ -224,7 +226,7 @@ impl StateReader for CachedState { } } -impl State for CachedState { +impl State for CachedState { fn set_storage_at( &mut self, contract_address: ContractAddress, @@ -277,13 +279,13 @@ impl State for CachedState { Ok(()) } - fn add_visited_pcs(&mut self, class_hash: ClassHash, pcs: &HashSet) { - self.visited_pcs.entry(class_hash).or_default().extend(pcs); + fn add_visited_pcs(&mut self, class_hash: ClassHash, pcs: &Vec) { + self.visited_pcs.insert(&class_hash, pcs); } } #[cfg(any(feature = "testing", test))] -impl Default for CachedState { +impl Default for CachedState { fn default() -> Self { Self { state: Default::default(), @@ -506,14 +508,14 @@ impl<'a, S: StateReader + ?Sized> StateReader for MutRefState<'a, S> { } } -pub type TransactionalState<'a, U> = CachedState>; +pub type TransactionalState<'a, U, V> = CachedState, V>; -impl<'a, S: StateReader> TransactionalState<'a, S> { +impl<'a, S: StateReader, V: VisitedPcsTrait> TransactionalState<'a, S, V> { /// Creates a transactional instance from the given updatable state. /// It allows performing buffered modifying actions on the given state, which /// will either all happen (will be updated in the state and committed) /// or none of them (will be discarded). - pub fn create_transactional(state: &mut S) -> TransactionalState<'_, S> { + pub fn create_transactional(state: &mut S) -> TransactionalState<'_, S, V> { CachedState::new(MutRefState::new(state)) } @@ -522,7 +524,7 @@ impl<'a, S: StateReader> TransactionalState<'a, S> { } /// Adds the ability to perform a transactional execution. -impl<'a, U: UpdatableState> TransactionalState<'a, U> { +impl<'a, V: VisitedPcsTrait, U: UpdatableState> TransactionalState<'a, U, V> { /// Commits changes in the child (wrapping) state to its parent. pub fn commit(self) { let state = self.state.0; diff --git a/crates/blockifier/src/state/cached_state_test.rs b/crates/blockifier/src/state/cached_state_test.rs index 37c8e72ba5..ba9f3b9302 100644 --- a/crates/blockifier/src/state/cached_state_test.rs +++ b/crates/blockifier/src/state/cached_state_test.rs @@ -17,7 +17,7 @@ use crate::{compiled_class_hash, nonce, storage_key}; const CONTRACT_ADDRESS: &str = "0x100"; fn set_initial_state_values( - state: &mut CachedState, + state: &mut CachedState, class_hash_to_class: ContractClassMapping, nonce_initial_values: HashMap, class_hash_initial_values: HashMap, @@ -33,7 +33,7 @@ fn set_initial_state_values( #[test] fn get_uninitialized_storage_value() { - let state: CachedState = CachedState::default(); + let state: CachedState = CachedState::default(); let contract_address = contract_address!("0x1"); let key = storage_key!(0x10_u16); @@ -49,13 +49,14 @@ fn get_and_set_storage_value() { let storage_val0: Felt = felt!("0x1"); let storage_val1: Felt = felt!("0x5"); - let mut state = CachedState::from(DictStateReader { - storage_view: HashMap::from([ - ((contract_address0, key0), storage_val0), - ((contract_address1, key1), storage_val1), - ]), - ..Default::default() - }); + let mut state: CachedState = + CachedState::from(DictStateReader { + storage_view: HashMap::from([ + ((contract_address0, key0), storage_val0), + ((contract_address1, key1), storage_val1), + ]), + ..Default::default() + }); assert_eq!(state.get_storage_at(contract_address0, key0).unwrap(), storage_val0); assert_eq!(state.get_storage_at(contract_address1, key1).unwrap(), storage_val1); @@ -98,7 +99,7 @@ fn cast_between_storage_mapping_types() { #[test] fn get_uninitialized_value() { - let state: CachedState = CachedState::default(); + let state: CachedState = CachedState::default(); let contract_address = contract_address!("0x1"); assert_eq!(state.get_nonce_at(contract_address).unwrap(), Nonce::default()); @@ -106,7 +107,8 @@ fn get_uninitialized_value() { #[test] fn declare_contract() { - let mut state = CachedState::from(DictStateReader { ..Default::default() }); + let mut state: CachedState = + CachedState::from(DictStateReader { ..Default::default() }); let test_contract = FeatureContract::TestContract(CairoVersion::Cairo0); let class_hash = test_contract.get_class_hash(); let contract_class = test_contract.get_class(); @@ -135,13 +137,14 @@ fn get_and_increment_nonce() { let contract_address2 = contract_address!("0x200"); let initial_nonce = Nonce(felt!(1_u8)); - let mut state = CachedState::from(DictStateReader { - address_to_nonce: HashMap::from([ - (contract_address1, initial_nonce), - (contract_address2, initial_nonce), - ]), - ..Default::default() - }); + let mut state: CachedState = + CachedState::from(DictStateReader { + address_to_nonce: HashMap::from([ + (contract_address1, initial_nonce), + (contract_address2, initial_nonce), + ]), + ..Default::default() + }); assert_eq!(state.get_nonce_at(contract_address1).unwrap(), initial_nonce); assert_eq!(state.get_nonce_at(contract_address2).unwrap(), initial_nonce); @@ -181,7 +184,7 @@ fn get_contract_class() { #[test] fn get_uninitialized_class_hash_value() { - let state: CachedState = CachedState::default(); + let state: CachedState = CachedState::default(); let valid_contract_address = contract_address!("0x1"); assert_eq!(state.get_class_hash_at(valid_contract_address).unwrap(), ClassHash::default()); @@ -190,7 +193,7 @@ fn get_uninitialized_class_hash_value() { #[test] fn set_and_get_contract_hash() { let contract_address = contract_address!("0x1"); - let mut state: CachedState = CachedState::default(); + let mut state: CachedState = CachedState::default(); let class_hash = class_hash!("0x10"); assert!(state.set_class_hash_at(contract_address, class_hash).is_ok()); @@ -199,7 +202,7 @@ fn set_and_get_contract_hash() { #[test] fn cannot_set_class_hash_to_uninitialized_contract() { - let mut state: CachedState = CachedState::default(); + let mut state: CachedState = CachedState::default(); let uninitialized_contract_address = ContractAddress::default(); let class_hash = class_hash!("0x100"); @@ -289,8 +292,8 @@ fn cached_state_state_diff_conversion() { assert_eq!(expected_state_diff, state.to_state_diff().unwrap().into()); } -fn create_state_changes_for_test( - state: &mut CachedState, +fn create_state_changes_for_test( + state: &mut CachedState, sender_address: Option, fee_token_address: ContractAddress, ) -> StateChanges { @@ -331,7 +334,7 @@ fn create_state_changes_for_test( fn test_from_state_changes_for_fee_charge( #[values(Some(contract_address!("0x102")), None)] sender_address: Option, ) { - let mut state: CachedState = CachedState::default(); + let mut state: CachedState = CachedState::default(); let fee_token_address = contract_address!("0x17"); let state_changes = create_state_changes_for_test(&mut state, sender_address, fee_token_address); @@ -352,8 +355,9 @@ fn test_state_changes_merge( ) { // Create a transactional state containing the `create_state_changes_for_test` logic, get the // state changes and then commit. - let mut state: CachedState = CachedState::default(); - let mut transactional_state = TransactionalState::create_transactional(&mut state); + let mut state: CachedState = CachedState::default(); + let mut transactional_state: TransactionalState<'_, _, _> = + TransactionalState::create_transactional(&mut state); let block_context = BlockContext::create_for_testing(); let fee_token_address = block_context.chain_info.fee_token_addresses.eth_fee_token_address; let state_changes1 = @@ -362,7 +366,8 @@ fn test_state_changes_merge( // After performing `commit`, the transactional state is moved (into state). We need to create // a new transactional state that wraps `state` to continue. - let mut transactional_state = TransactionalState::create_transactional(&mut state); + let mut transactional_state: TransactionalState<'_, _, _> = + TransactionalState::create_transactional(&mut state); // Make sure that `get_actual_state_changes` on a newly created transactional state returns null // state changes and that merging null state changes with non-null state changes results in the // non-null state changes, no matter the order. @@ -422,7 +427,7 @@ fn test_contract_cache_is_used() { let contract_class = test_contract.get_class(); let mut reader = DictStateReader::default(); reader.class_hash_to_class.insert(class_hash, contract_class.clone()); - let state = CachedState::new(reader); + let state: CachedState = CachedState::new(reader); // Assert local cache is initialized empty. assert!(state.class_hash_to_class.borrow().get(&class_hash).is_none()); diff --git a/crates/blockifier/src/state/state_api.rs b/crates/blockifier/src/state/state_api.rs index 8c6a1db559..3ca626aa72 100644 --- a/crates/blockifier/src/state/state_api.rs +++ b/crates/blockifier/src/state/state_api.rs @@ -1,10 +1,8 @@ -use std::collections::HashSet; - use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce}; use starknet_api::state::StorageKey; use starknet_types_core::felt::Felt; -use super::cached_state::{ContractClassMapping, StateMaps, VisitedPcs}; +use super::cached_state::{ContractClassMapping, StateMaps}; use crate::abi::abi_utils::get_fee_token_var_address; use crate::abi::sierra_types::next_storage_key; use crate::execution::contract_class::ContractClass; @@ -107,15 +105,28 @@ pub trait State: StateReader { /// Marks the given set of PC values as visited for the given class hash. // TODO(lior): Once we have a BlockResources object, move this logic there. Make sure reverted // entry points do not affect the final set of PCs. - fn add_visited_pcs(&mut self, class_hash: ClassHash, pcs: &HashSet); + fn add_visited_pcs(&mut self, class_hash: ClassHash, pcs: &Vec); } /// A class defining the API for updating a state with transactions writes. pub trait UpdatableState: StateReader { + type T; + + fn apply_writes( + &mut self, + writes: &StateMaps, + class_hash_to_class: &ContractClassMapping, + visited_pcs: &Self::T, + ); +} + +pub trait UpdatableStatetTest: StateReader { + type T; + fn apply_writes( &mut self, writes: &StateMaps, class_hash_to_class: &ContractClassMapping, - visited_pcs: &VisitedPcs, + visited_pcs: &Self::T, ); } diff --git a/crates/blockifier/src/state/visited_pcs.rs b/crates/blockifier/src/state/visited_pcs.rs new file mode 100644 index 0000000000..548793849b --- /dev/null +++ b/crates/blockifier/src/state/visited_pcs.rs @@ -0,0 +1,119 @@ +use std::collections::hash_map::{Entry, IntoIter, Iter}; +use std::collections::{HashMap, HashSet}; + +use starknet_api::core::ClassHash; + +use crate::state::state_api::StateReader; + +// pub trait VisitedPcsTraitSecond +// where +// Self: std::default::Default + IntoIterator + FromIterator, +// { +// type Collection; + +// /// `pcs` type is matching the output from `runner.relocated_trace` +// fn insert(&mut self, class_hash: &ClassHash, pcs: &Vec); + +// fn iter(&self) -> Iter<'_, ClassHash, Self::Collection>; + +// fn iter_second(&self) -> impl Iterator; + +// fn entry(&mut self, class_hash: ClassHash) -> Entry<'_, ClassHash, Self::Collection>; +// } + +pub trait VisitedPcsTrait +where + Self: std::default::Default + Sized, +{ + type T: Clone + Default; + + fn new() -> Self; + + /// `pcs` type is matching the output from `runner.relocated_trace` + fn insert(&mut self, class_hash: &ClassHash, pcs: &Vec); + + fn extend(&mut self, class_hash: &ClassHash, pcs: &Self::T); + + fn iter(&self) -> Iter<'_, ClassHash, Self::T>; + + fn entry(&mut self, class_hash: ClassHash) -> Entry<'_, ClassHash, Self::T>; + + fn to_vec(pcs: Self::T) -> Vec; + + fn to_set(pcs: Self::T) -> HashSet; +} + +#[derive(Debug, Default, PartialEq, Eq)] +pub struct VisitedPcsSet(HashMap>); +impl VisitedPcsTrait for VisitedPcsSet { + type T = HashSet; + + fn new() -> Self { + VisitedPcsSet(HashMap::default()) + } + + fn insert(&mut self, class_hash: &ClassHash, pcs: &Vec) { + self.0.entry(*class_hash).or_default().extend(pcs); + } + + fn iter(&self) -> Iter<'_, ClassHash, HashSet> { + self.0.iter() + } + + fn entry(&mut self, class_hash: ClassHash) -> Entry<'_, ClassHash, HashSet> { + self.0.entry(class_hash) + } + + fn to_vec(pcs: Self::T) -> Vec { + Vec::from_iter(pcs) + } + + fn extend(&mut self, class_hash: &ClassHash, pcs: &Self::T) { + self.0.entry(*class_hash).or_default().extend(pcs); + } + + fn to_set(pcs: Self::T) -> HashSet { + pcs + } +} +impl IntoIterator for VisitedPcsSet { + type Item = (ClassHash, HashSet); + type IntoIter = IntoIter>; + + fn into_iter(self) -> IntoIter> { + self.0.into_iter() + } +} +impl<'a> IntoIterator for &'a VisitedPcsSet { + type Item = (&'a ClassHash, &'a HashSet); + type IntoIter = Iter<'a, ClassHash, HashSet>; + + fn into_iter(self) -> Iter<'a, ClassHash, HashSet> { + self.0.iter() + } +} + +#[derive(Debug, Default)] +pub struct CachedStateTest { + pub state: S, + /// A map from class hash to the set of PC values that were visited in the class. + pub visited_pcs: V, +} +impl CachedStateTest { + pub fn new(state: S) -> Self { + Self { state, visited_pcs: V::default() } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_utils::dict_state_reader::DictStateReader; + + #[test] + fn test_cached_state() { + let state = DictStateReader::default(); + let _cached_state: CachedStateTest = + CachedStateTest::new(state); + } +} diff --git a/crates/blockifier/src/test_utils/initial_test_state.rs b/crates/blockifier/src/test_utils/initial_test_state.rs index 6e0268cb29..6bf2fc56b3 100644 --- a/crates/blockifier/src/test_utils/initial_test_state.rs +++ b/crates/blockifier/src/test_utils/initial_test_state.rs @@ -7,6 +7,7 @@ use strum::IntoEnumIterator; use crate::abi::abi_utils::get_fee_token_var_address; use crate::context::ChainInfo; use crate::state::cached_state::CachedState; +use crate::state::visited_pcs::VisitedPcsSet; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::dict_state_reader::DictStateReader; use crate::test_utils::CairoVersion; @@ -40,7 +41,7 @@ pub fn test_state_inner( initial_balances: u128, contract_instances: &[(FeatureContract, u16)], erc20_contract_version: CairoVersion, -) -> CachedState { +) -> CachedState { let mut class_hash_to_class = HashMap::new(); let mut address_to_class_hash = HashMap::new(); @@ -87,6 +88,6 @@ pub fn test_state( chain_info: &ChainInfo, initial_balances: u128, contract_instances: &[(FeatureContract, u16)], -) -> CachedState { +) -> CachedState { test_state_inner(chain_info, initial_balances, contract_instances, CairoVersion::Cairo0) } diff --git a/crates/blockifier/src/test_utils/transfers_generator.rs b/crates/blockifier/src/test_utils/transfers_generator.rs index 3d3a6a911d..1bd11e587f 100644 --- a/crates/blockifier/src/test_utils/transfers_generator.rs +++ b/crates/blockifier/src/test_utils/transfers_generator.rs @@ -10,6 +10,7 @@ use crate::blockifier::config::{ConcurrencyConfig, TransactionExecutorConfig}; use crate::blockifier::transaction_executor::TransactionExecutor; use crate::context::{BlockContext, ChainInfo}; use crate::invoke_tx_args; +use crate::state::visited_pcs::VisitedPcsSet; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::dict_state_reader::DictStateReader; use crate::test_utils::initial_test_state::test_state; @@ -73,7 +74,7 @@ pub enum RecipientGeneratorType { pub struct TransfersGenerator { account_addresses: Vec, chain_info: ChainInfo, - executor: TransactionExecutor, + executor: TransactionExecutor, nonce_manager: NonceManager, sender_index: usize, random_recipient_generator: Option, diff --git a/crates/blockifier/src/transaction/account_transaction.rs b/crates/blockifier/src/transaction/account_transaction.rs index 51ee47b660..a514c029b8 100644 --- a/crates/blockifier/src/transaction/account_transaction.rs +++ b/crates/blockifier/src/transaction/account_transaction.rs @@ -21,6 +21,7 @@ use crate::fee::gas_usage::{compute_discounted_gas_from_gas_vector, estimate_min use crate::retdata; use crate::state::cached_state::{StateChanges, TransactionalState}; use crate::state::state_api::{State, StateReader, UpdatableState}; +use crate::state::visited_pcs::VisitedPcsTrait; use crate::transaction::constants; use crate::transaction::errors::{ TransactionExecutionError, TransactionFeeError, TransactionPreValidationError, @@ -303,9 +304,9 @@ impl AccountTransaction { Ok(()) } - fn handle_fee( + fn handle_fee( &self, - state: &mut TransactionalState<'_, S>, + state: &mut TransactionalState<'_, S, V>, tx_context: Arc, actual_fee: Fee, charge_fee: bool, @@ -370,8 +371,8 @@ impl AccountTransaction { /// manipulates the state to avoid that part. /// Note: the returned transfer call info is partial, and should be completed at the commit /// stage, as well as the actual sequencer balance. - fn concurrency_execute_fee_transfer( - state: &mut TransactionalState<'_, S>, + fn concurrency_execute_fee_transfer( + state: &mut TransactionalState<'_, S, V>, tx_context: Arc, actual_fee: Fee, ) -> TransactionExecutionResult { @@ -379,7 +380,8 @@ impl AccountTransaction { let fee_address = block_context.chain_info.fee_token_address(&tx_info.fee_type()); let (sequencer_balance_key_low, sequencer_balance_key_high) = get_sequencer_balance_keys(block_context); - let mut transfer_state = TransactionalState::create_transactional(state); + let mut transfer_state: TransactionalState<'_, _, V> = + TransactionalState::create_transactional(state); // Set the initial sequencer balance to avoid tarnishing the read-set of the transaction. let cache = transfer_state.cache.get_mut(); @@ -411,9 +413,9 @@ impl AccountTransaction { } } - fn run_non_revertible( + fn run_non_revertible( &self, - state: &mut TransactionalState<'_, S>, + state: &mut TransactionalState<'_, S, V>, tx_context: Arc, remaining_gas: &mut u64, validate: bool, @@ -474,9 +476,9 @@ impl AccountTransaction { } } - fn run_revertible( + fn run_revertible( &self, - state: &mut TransactionalState<'_, S>, + state: &mut TransactionalState<'_, S, V>, tx_context: Arc, remaining_gas: &mut u64, validate: bool, @@ -508,7 +510,8 @@ impl AccountTransaction { // Create copies of state and resources for the execution. // Both will be rolled back if the execution is reverted or committed upon success. let mut execution_resources = resources.clone(); - let mut execution_state = TransactionalState::create_transactional(state); + let mut execution_state: TransactionalState<'_, _, V> = + TransactionalState::create_transactional(state); let execution_result = self.run_execute( &mut execution_state, @@ -615,9 +618,9 @@ impl AccountTransaction { } /// Runs validation and execution. - fn run_or_revert( + fn run_or_revert( &self, - state: &mut TransactionalState<'_, S>, + state: &mut TransactionalState<'_, S, V>, remaining_gas: &mut u64, tx_context: Arc, validate: bool, @@ -631,10 +634,12 @@ impl AccountTransaction { } } -impl ExecutableTransaction for AccountTransaction { +impl> ExecutableTransaction + for AccountTransaction +{ fn execute_raw( &self, - state: &mut TransactionalState<'_, U>, + state: &mut TransactionalState<'_, U, V>, block_context: &BlockContext, execution_flags: ExecutionFlags, ) -> TransactionExecutionResult { diff --git a/crates/blockifier/src/transaction/account_transactions_test.rs b/crates/blockifier/src/transaction/account_transactions_test.rs index 89eb2fb9ca..68361fc43e 100644 --- a/crates/blockifier/src/transaction/account_transactions_test.rs +++ b/crates/blockifier/src/transaction/account_transactions_test.rs @@ -144,7 +144,13 @@ fn test_fee_enforcement( let account_tx = AccountTransaction::DeployAccount(deploy_account_tx); let enforce_fee = account_tx.create_tx_info().enforce_fee().unwrap(); - let result = account_tx.execute(state, &block_context, true, true); + let result = >::execute( + &account_tx, + state, + &block_context, + true, + true, + ); assert_eq!(result.is_err(), enforce_fee); } @@ -333,7 +339,14 @@ fn test_max_fee_limit_validate( }, class_info, ); - account_tx.execute(&mut state, &block_context, true, true).unwrap(); + >::execute( + &account_tx, + &mut state, + &block_context, + true, + true, + ) + .unwrap(); // Deploy grindy account with a lot of grind in the constructor. // Expect this to fail without bumping nonce, so pass a temporary nonce manager. @@ -349,8 +362,15 @@ fn test_max_fee_limit_validate( constructor_calldata: calldata![ctor_grind_arg, ctor_storage_arg], }, ); - let error_trace = - deploy_account_tx.execute(&mut state, &block_context, true, true).unwrap_err().to_string(); + let error_trace = >::execute( + &deploy_account_tx, + &mut state, + &block_context, + true, + true, + ) + .unwrap_err() + .to_string(); assert!(error_trace.contains("no remaining steps")); // Deploy grindy account successfully this time. @@ -365,7 +385,14 @@ fn test_max_fee_limit_validate( constructor_calldata: calldata![ctor_grind_arg, ctor_storage_arg], }, ); - deploy_account_tx.execute(&mut state, &block_context, true, true).unwrap(); + >::execute( + &deploy_account_tx, + &mut state, + &block_context, + true, + true, + ) + .unwrap(); // Invoke a function that grinds validate (any function will do); set bounds low enough to fail // on this grind. @@ -578,7 +605,14 @@ fn test_fail_deploy_account( let initial_balance = state.get_fee_token_balance(deploy_address, fee_token_address).unwrap(); - let error = deploy_account_tx.execute(state, &block_context, true, true).unwrap_err(); + let error = >::execute( + &deploy_account_tx, + state, + &block_context, + true, + true, + ) + .unwrap_err(); // Check the error is as expected. Assure the error message is not nonce or fee related. check_transaction_execution_error_for_invalid_scenario!(cairo_version, error, false); @@ -628,7 +662,14 @@ fn test_fail_declare(block_context: BlockContext, max_fee: Fee) { let initial_balance = state .get_fee_token_balance(account_address, chain_info.fee_token_address(&tx_info.fee_type())) .unwrap(); - declare_account_tx.execute(&mut state, &block_context, true, true).unwrap_err(); + >::execute( + &declare_account_tx, + &mut state, + &block_context, + true, + true, + ) + .unwrap_err(); assert_eq!(state.get_nonce_at(account_address).unwrap(), next_nonce); assert_eq!( @@ -904,7 +945,14 @@ fn test_max_fee_to_max_steps_conversion( let tx_context1 = Arc::new(block_context.to_tx_context(&account_tx1)); let execution_context1 = EntryPointExecutionContext::new_invoke(tx_context1, true).unwrap(); let max_steps_limit1 = execution_context1.vm_run_resources.get_n_steps(); - let tx_execution_info1 = account_tx1.execute(&mut state, &block_context, true, true).unwrap(); + let tx_execution_info1 = >::execute( + &account_tx1, + &mut state, + &block_context, + true, + true, + ) + .unwrap(); let n_steps1 = tx_execution_info1.transaction_receipt.resources.vm_resources.n_steps; let gas_used_vector1 = tx_execution_info1 .transaction_receipt @@ -924,7 +972,14 @@ fn test_max_fee_to_max_steps_conversion( let tx_context2 = Arc::new(block_context.to_tx_context(&account_tx2)); let execution_context2 = EntryPointExecutionContext::new_invoke(tx_context2, true).unwrap(); let max_steps_limit2 = execution_context2.vm_run_resources.get_n_steps(); - let tx_execution_info2 = account_tx2.execute(&mut state, &block_context, true, true).unwrap(); + let tx_execution_info2 = >::execute( + &account_tx2, + &mut state, + &block_context, + true, + true, + ) + .unwrap(); let n_steps2 = tx_execution_info2.transaction_receipt.resources.vm_resources.n_steps; let gas_used_vector2 = tx_execution_info2 .transaction_receipt @@ -1043,7 +1098,14 @@ fn test_deploy_account_constructor_storage_write( constructor_calldata: constructor_calldata.clone(), }, ); - deploy_account_tx.execute(state, &block_context, true, true).unwrap(); + >::execute( + &deploy_account_tx, + state, + &block_context, + true, + true, + ) + .unwrap(); // Check that the constructor wrote ctor_arg to the storage. let storage_key = get_storage_var_address("ctor_arg", &[]); @@ -1105,7 +1167,8 @@ fn test_count_actual_storage_changes( // Run transactions; using transactional state to count only storage changes of the current // transaction. // First transaction: storage cell value changes from 0 to 1. - let mut state = TransactionalState::create_transactional(&mut state); + let mut state: TransactionalState<'_, _, _> = + TransactionalState::create_transactional(&mut state); let invoke_args = invoke_tx_args! { max_fee, resource_bounds: max_resource_bounds, @@ -1155,7 +1218,8 @@ fn test_count_actual_storage_changes( assert_eq!(state_changes_count_1, expected_state_changes_count_1); // Second transaction: storage cell starts and ends with value 1. - let mut state = TransactionalState::create_transactional(&mut state); + let mut state: TransactionalState<'_, _, _> = + TransactionalState::create_transactional(&mut state); let account_tx = account_invoke_tx(InvokeTxArgs { nonce: nonce_manager.next(account_address), ..invoke_args.clone() @@ -1192,7 +1256,8 @@ fn test_count_actual_storage_changes( assert_eq!(state_changes_count_2, expected_state_changes_count_2); // Transfer transaction: transfer 1 ETH to recepient. - let mut state = TransactionalState::create_transactional(&mut state); + let mut state: TransactionalState<'_, _, _> = + TransactionalState::create_transactional(&mut state); let account_tx = account_invoke_tx(InvokeTxArgs { nonce: nonce_manager.next(account_address), calldata: transfer_calldata, @@ -1273,7 +1338,8 @@ fn test_concurrency_execute_fee_transfer( // Case 1: The transaction did not read form/ write to the sequenser balance before executing // fee transfer. - let mut transactional_state = TransactionalState::create_transactional(state); + let mut transactional_state: TransactionalState<'_, _, _> = + TransactionalState::create_transactional(state); let execution_flags = ExecutionFlags { charge_fee: true, validate: true, concurrency_mode: true }; let result = @@ -1305,7 +1371,8 @@ fn test_concurrency_execute_fee_transfer( SEQUENCER_BALANCE_LOW_INITIAL, &mut state.state, ); - let mut transactional_state = TransactionalState::create_transactional(state); + let mut transactional_state: TransactionalState<'_, _, _> = + TransactionalState::create_transactional(state); // Invokes transfer to the sequencer. let account_tx = account_invoke_tx(invoke_tx_args! { @@ -1369,7 +1436,8 @@ fn test_concurrent_fee_transfer_when_sender_is_sequencer( let fee_type = &account_tx.fee_type(); let fee_token_address = block_context.chain_info.fee_token_address(fee_type); - let mut transactional_state = TransactionalState::create_transactional(state); + let mut transactional_state: TransactionalState<'_, _, _> = + TransactionalState::create_transactional(state); let execution_flags = ExecutionFlags { charge_fee: true, validate: true, concurrency_mode: true }; let result = diff --git a/crates/blockifier/src/transaction/execution_flavors_test.rs b/crates/blockifier/src/transaction/execution_flavors_test.rs index e186ad8151..4e7c9ee71e 100644 --- a/crates/blockifier/src/transaction/execution_flavors_test.rs +++ b/crates/blockifier/src/transaction/execution_flavors_test.rs @@ -13,6 +13,7 @@ use crate::execution::syscalls::SyscallSelector; use crate::fee::fee_utils::get_fee_by_gas_vector; use crate::state::cached_state::CachedState; use crate::state::state_api::StateReader; +use crate::state::visited_pcs::{VisitedPcsSet, VisitedPcsTrait}; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::dict_state_reader::DictStateReader; use crate::test_utils::initial_test_state::test_state; @@ -34,7 +35,7 @@ use crate::{invoke_tx_args, nonce}; const VALIDATE_GAS_OVERHEAD: u64 = 21; struct FlavorTestInitialState { - pub state: CachedState, + pub state: CachedState, pub account_address: ContractAddress, pub faulty_account_address: ContractAddress, pub test_contract_address: ContractAddress, @@ -64,9 +65,9 @@ fn create_flavors_test_state( /// Checks that balance of the account decreased if and only if `charge_fee` is true. /// Returns the new balance. -fn check_balance( +fn check_balance( current_balance: Felt, - state: &mut CachedState, + state: &mut CachedState, account_address: ContractAddress, chain_info: &ChainInfo, fee_type: &FeeType, @@ -153,6 +154,8 @@ fn test_simulate_validate_charge_fee_pre_validate( #[case] fee_type: FeeType, #[case] is_deprecated: bool, ) { + use crate::transaction::account_transaction::AccountTransaction; + let block_context = BlockContext::create_for_account_testing(); let max_fee = Fee(MAX_FEE); // The max resource bounds fixture is not used here because this function already has the @@ -185,10 +188,16 @@ fn test_simulate_validate_charge_fee_pre_validate( // First scenario: invalid nonce. Regardless of flags, should fail. let invalid_nonce = nonce!(7_u8); let account_nonce = state.get_nonce_at(account_address).unwrap(); - let result = account_invoke_tx( + let account_tx = account_invoke_tx( invoke_tx_args! {nonce: invalid_nonce, ..pre_validation_base_args.clone()}, - ) - .execute(&mut state, &block_context, charge_fee, validate); + ); + let result = >::execute( + &account_tx, + &mut state, + &block_context, + charge_fee, + validate, + ); assert_matches!( result.unwrap_err(), TransactionExecutionError::TransactionPreValidationError( @@ -210,13 +219,19 @@ fn test_simulate_validate_charge_fee_pre_validate( validate, &fee_type, ); - let result = account_invoke_tx(invoke_tx_args! { + let account_tx = account_invoke_tx(invoke_tx_args! { max_fee: Fee(10), resource_bounds: l1_resource_bounds(10, 10), nonce: nonce_manager.next(account_address), ..pre_validation_base_args.clone() - }) - .execute(&mut state, &block_context, charge_fee, validate); + }); + let result = >::execute( + &account_tx, + &mut state, + &block_context, + charge_fee, + validate, + ); if !charge_fee { check_gas_and_fee( &block_context, @@ -254,13 +269,19 @@ fn test_simulate_validate_charge_fee_pre_validate( // TODO(Ori, 1/2/2024): Write an indicative expect message explaining why the conversion works. let balance_over_gas_price: u64 = (BALANCE / gas_price).try_into().expect("Failed to convert u128 to u64."); - let result = account_invoke_tx(invoke_tx_args! { + let account_tx = account_invoke_tx(invoke_tx_args! { max_fee: Fee(BALANCE + 1), resource_bounds: l1_resource_bounds(balance_over_gas_price + 10, gas_price.into()), nonce: nonce_manager.next(account_address), ..pre_validation_base_args.clone() - }) - .execute(&mut state, &block_context, charge_fee, validate); + }); + let result = >::execute( + &account_tx, + &mut state, + &block_context, + charge_fee, + validate, + ); if !charge_fee { check_gas_and_fee( &block_context, @@ -295,12 +316,18 @@ fn test_simulate_validate_charge_fee_pre_validate( // Fourth scenario: L1 gas price bound lower than the price on the block. if !is_deprecated { - let result = account_invoke_tx(invoke_tx_args! { + let account_tx = account_invoke_tx(invoke_tx_args! { resource_bounds: l1_resource_bounds(MAX_L1_GAS_AMOUNT, u128::from(gas_price) - 1), nonce: nonce_manager.next(account_address), ..pre_validation_base_args - }) - .execute(&mut state, &block_context, charge_fee, validate); + }); + let result = >::execute( + &account_tx, + &mut state, + &block_context, + charge_fee, + validate, + ); if !charge_fee { check_gas_and_fee( &block_context, @@ -338,6 +365,8 @@ fn test_simulate_validate_charge_fee_fail_validate( #[case] fee_type: FeeType, max_resource_bounds: ResourceBoundsMapping, ) { + use super::AccountTransaction; + let block_context = BlockContext::create_for_account_testing(); let max_fee = Fee(MAX_FEE); @@ -355,7 +384,7 @@ fn test_simulate_validate_charge_fee_fail_validate( validate, &fee_type, ); - let result = account_invoke_tx(invoke_tx_args! { + let account_tx = account_invoke_tx(invoke_tx_args! { max_fee, resource_bounds: max_resource_bounds, signature: TransactionSignature(vec![ @@ -367,8 +396,14 @@ fn test_simulate_validate_charge_fee_fail_validate( version, nonce: nonce_manager.next(faulty_account_address), only_query, - }) - .execute(&mut falliable_state, &block_context, charge_fee, validate); + }); + let result = >::execute( + &account_tx, + &mut falliable_state, + &block_context, + charge_fee, + validate, + ); if !validate { // The reported fee should be the actual cost, regardless of whether or not fee is charged. check_gas_and_fee( @@ -400,6 +435,8 @@ fn test_simulate_validate_charge_fee_mid_execution( #[case] fee_type: FeeType, max_resource_bounds: ResourceBoundsMapping, ) { + use crate::transaction::account_transaction::AccountTransaction; + let block_context = BlockContext::create_for_account_testing(); let chain_info = &block_context.chain_info; let gas_price = block_context.block_info.gas_prices.get_gas_price_by_fee_type(&fee_type); @@ -434,12 +471,18 @@ fn test_simulate_validate_charge_fee_mid_execution( validate, &fee_type, ); - let tx_execution_info = account_invoke_tx(invoke_tx_args! { + let account_tx = account_invoke_tx(invoke_tx_args! { calldata: recurse_calldata(test_contract_address, true, 3), nonce: nonce_manager.next(account_address), ..execution_base_args.clone() - }) - .execute(&mut state, &block_context, charge_fee, validate) + }); + let tx_execution_info = >::execute( + &account_tx, + &mut state, + &block_context, + charge_fee, + validate, + ) .unwrap(); assert!(tx_execution_info.is_reverted()); check_gas_and_fee( @@ -474,14 +517,20 @@ fn test_simulate_validate_charge_fee_mid_execution( validate, &fee_type, ); - let tx_execution_info = account_invoke_tx(invoke_tx_args! { + let account_tx = account_invoke_tx(invoke_tx_args! { max_fee: fee_bound, resource_bounds: l1_resource_bounds(gas_bound, gas_price.into()), calldata: recurse_calldata(test_contract_address, false, 1000), nonce: nonce_manager.next(account_address), ..execution_base_args.clone() - }) - .execute(&mut state, &block_context, charge_fee, validate) + }); + let tx_execution_info = >::execute( + &account_tx, + &mut state, + &block_context, + charge_fee, + validate, + ) .unwrap(); assert_eq!(tx_execution_info.is_reverted(), charge_fee); if charge_fee { @@ -526,14 +575,20 @@ fn test_simulate_validate_charge_fee_mid_execution( GasVector::from_l1_gas(block_limit_gas.into()), &fee_type, ); - let tx_execution_info = account_invoke_tx(invoke_tx_args! { + let account_tx = account_invoke_tx(invoke_tx_args! { max_fee: huge_fee, resource_bounds: l1_resource_bounds(huge_gas_limit, gas_price.into()), calldata: recurse_calldata(test_contract_address, false, 10000), nonce: nonce_manager.next(account_address), ..execution_base_args - }) - .execute(&mut state, &low_step_block_context, charge_fee, validate) + }); + let tx_execution_info = >::execute( + &account_tx, + &mut state, + &low_step_block_context, + charge_fee, + validate, + ) .unwrap(); assert!(tx_execution_info.revert_error.clone().unwrap().contains("no remaining steps")); // Complete resources used are reported as transaction_receipt.resources; but only the charged @@ -563,6 +618,8 @@ fn test_simulate_validate_charge_fee_post_execution( #[case] fee_type: FeeType, #[case] is_deprecated: bool, ) { + use crate::transaction::account_transaction::AccountTransaction; + let block_context = BlockContext::create_for_account_testing(); let gas_price = block_context.block_info.gas_prices.get_gas_price_by_fee_type(&fee_type); let chain_info = &block_context.chain_info; @@ -607,7 +664,7 @@ fn test_simulate_validate_charge_fee_post_execution( validate, &fee_type, ); - let tx_execution_info = account_invoke_tx(invoke_tx_args! { + let account_tx = account_invoke_tx(invoke_tx_args! { max_fee: just_not_enough_fee_bound, resource_bounds: l1_resource_bounds(just_not_enough_gas_bound, gas_price.into()), calldata: recurse_calldata(test_contract_address, false, 1000), @@ -615,8 +672,14 @@ fn test_simulate_validate_charge_fee_post_execution( sender_address: account_address, version, only_query, - }) - .execute(&mut state, &block_context, charge_fee, validate) + }); + let tx_execution_info = >::execute( + &account_tx, + &mut state, + &block_context, + charge_fee, + validate, + ) .unwrap(); assert_eq!(tx_execution_info.is_reverted(), charge_fee); if charge_fee { @@ -672,7 +735,7 @@ fn test_simulate_validate_charge_fee_post_execution( felt!(0_u8), ], ); - let tx_execution_info = account_invoke_tx(invoke_tx_args! { + let account_tx = account_invoke_tx(invoke_tx_args! { max_fee: actual_fee, resource_bounds: l1_resource_bounds(success_actual_gas, gas_price.into()), calldata: transfer_calldata, @@ -680,8 +743,14 @@ fn test_simulate_validate_charge_fee_post_execution( sender_address: account_address, version, only_query, - }) - .execute(&mut state, &block_context, charge_fee, validate) + }); + let tx_execution_info = >::execute( + &account_tx, + &mut state, + &block_context, + charge_fee, + validate, + ) .unwrap(); assert_eq!(tx_execution_info.is_reverted(), charge_fee); if charge_fee { diff --git a/crates/blockifier/src/transaction/post_execution_test.rs b/crates/blockifier/src/transaction/post_execution_test.rs index a9b9e67321..cdc850f37e 100644 --- a/crates/blockifier/src/transaction/post_execution_test.rs +++ b/crates/blockifier/src/transaction/post_execution_test.rs @@ -108,8 +108,14 @@ fn test_revert_on_overdraft( nonce: nonce_manager.next(account_address), }); let tx_info = approve_tx.create_tx_info(); - let approval_execution_info = - approve_tx.execute(&mut state, &block_context, true, true).unwrap(); + let approval_execution_info = >::execute( + &approve_tx, + &mut state, + &block_context, + true, + true, + ) + .unwrap(); assert!(!approval_execution_info.is_reverted()); // Transfer a valid amount of funds to compute the cost of a successful diff --git a/crates/blockifier/src/transaction/test_utils.rs b/crates/blockifier/src/transaction/test_utils.rs index a5a353a056..b79e5b5098 100644 --- a/crates/blockifier/src/transaction/test_utils.rs +++ b/crates/blockifier/src/transaction/test_utils.rs @@ -14,6 +14,7 @@ use crate::context::{BlockContext, ChainInfo}; use crate::execution::contract_class::{ClassInfo, ContractClass}; use crate::state::cached_state::CachedState; use crate::state::state_api::State; +use crate::state::visited_pcs::VisitedPcsSet; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::declare::declare_tx; use crate::test_utils::deploy_account::{deploy_account_tx, DeployAccountTxArgs}; @@ -79,7 +80,7 @@ pub fn block_context() -> BlockContext { /// Struct containing the data usually needed to initialize a test. pub struct TestInitData { - pub state: CachedState, + pub state: CachedState, pub account_address: ContractAddress, pub contract_address: ContractAddress, pub nonce_manager: NonceManager, @@ -88,7 +89,7 @@ pub struct TestInitData { /// Deploys a new account with the given class hash, funds with both fee tokens, and returns the /// deploy tx and address. pub fn deploy_and_fund_account( - state: &mut CachedState, + state: &mut CachedState, nonce_manager: &mut NonceManager, chain_info: &ChainInfo, deploy_tx_args: DeployAccountTxArgs, @@ -268,11 +269,18 @@ pub fn account_invoke_tx(invoke_args: InvokeTxArgs) -> AccountTransaction { } pub fn run_invoke_tx( - state: &mut CachedState, + state: &mut CachedState, block_context: &BlockContext, invoke_args: InvokeTxArgs, ) -> TransactionExecutionResult { - account_invoke_tx(invoke_args).execute(state, block_context, true, true) + let account_tx = account_invoke_tx(invoke_args); + >::execute( + &account_tx, + state, + block_context, + true, + true, + ) } /// Creates a `ResourceBoundsMapping` with the given `max_amount` and `max_price` for L1 gas limits. diff --git a/crates/blockifier/src/transaction/transaction_execution.rs b/crates/blockifier/src/transaction/transaction_execution.rs index e617ad1c5f..ad8b2681ce 100644 --- a/crates/blockifier/src/transaction/transaction_execution.rs +++ b/crates/blockifier/src/transaction/transaction_execution.rs @@ -11,6 +11,7 @@ use crate::execution::entry_point::EntryPointExecutionContext; use crate::fee::actual_cost::TransactionReceipt; use crate::state::cached_state::TransactionalState; use crate::state::state_api::UpdatableState; +use crate::state::visited_pcs::VisitedPcsTrait; use crate::transaction::account_transaction::AccountTransaction; use crate::transaction::errors::TransactionFeeError; use crate::transaction::objects::{ @@ -100,10 +101,12 @@ impl TransactionInfoCreator for Transaction { } } -impl ExecutableTransaction for L1HandlerTransaction { +impl> ExecutableTransaction + for L1HandlerTransaction +{ fn execute_raw( &self, - state: &mut TransactionalState<'_, U>, + state: &mut TransactionalState<'_, U, V>, block_context: &BlockContext, _execution_flags: ExecutionFlags, ) -> TransactionExecutionResult { @@ -151,10 +154,10 @@ impl ExecutableTransaction for L1HandlerTransaction { } } -impl ExecutableTransaction for Transaction { +impl> ExecutableTransaction for Transaction { fn execute_raw( &self, - state: &mut TransactionalState<'_, U>, + state: &mut TransactionalState<'_, U, V>, block_context: &BlockContext, execution_flags: ExecutionFlags, ) -> TransactionExecutionResult { diff --git a/crates/blockifier/src/transaction/transactions.rs b/crates/blockifier/src/transaction/transactions.rs index 4e3188c150..ad6abdab12 100644 --- a/crates/blockifier/src/transaction/transactions.rs +++ b/crates/blockifier/src/transaction/transactions.rs @@ -21,6 +21,7 @@ use crate::execution::execution_utils::execute_deployment; use crate::state::cached_state::TransactionalState; use crate::state::errors::StateError; use crate::state::state_api::{State, UpdatableState}; +use crate::state::visited_pcs::VisitedPcsTrait; use crate::transaction::constants; use crate::transaction::errors::TransactionExecutionError; use crate::transaction::objects::{ @@ -48,7 +49,7 @@ pub struct ExecutionFlags { pub concurrency_mode: bool, } -pub trait ExecutableTransaction: Sized { +pub trait ExecutableTransaction>: Sized { /// Executes the transaction in a transactional manner /// (if it fails, given state does not modify). fn execute( @@ -84,7 +85,7 @@ pub trait ExecutableTransaction: Sized { /// for automatic handling of such cases. fn execute_raw( &self, - state: &mut TransactionalState<'_, U>, + state: &mut TransactionalState<'_, U, V>, block_context: &BlockContext, execution_flags: ExecutionFlags, ) -> TransactionExecutionResult; diff --git a/crates/blockifier/src/transaction/transactions_test.rs b/crates/blockifier/src/transaction/transactions_test.rs index e21c90de2e..953f7872b5 100644 --- a/crates/blockifier/src/transaction/transactions_test.rs +++ b/crates/blockifier/src/transaction/transactions_test.rs @@ -41,6 +41,7 @@ use crate::fee::gas_usage::{ use crate::state::cached_state::{CachedState, StateChangesCount, TransactionalState}; use crate::state::errors::StateError; use crate::state::state_api::{State, StateReader}; +use crate::state::visited_pcs::VisitedPcsSet; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::declare::declare_tx; use crate::test_utils::deploy_account::deploy_account_tx; @@ -248,7 +249,7 @@ fn get_expected_cairo_resources( /// and the sequencer (in both fee types) are as expected (assuming the initial sequencer balances /// are zero). fn validate_final_balances( - state: &mut CachedState, + state: &mut CachedState, chain_info: &ChainInfo, expected_actual_fee: Fee, erc20_account_balance_key: StorageKey, @@ -373,7 +374,14 @@ fn test_invoke_tx( let account_tx = AccountTransaction::Invoke(invoke_tx); let tx_context = block_context.to_tx_context(&account_tx); - let actual_execution_info = account_tx.execute(state, block_context, true, true).unwrap(); + let actual_execution_info = >::execute( + &account_tx, + state, + block_context, + true, + true, + ) + .unwrap(); // Build expected validate call info. let expected_account_class_hash = account_contract.get_class_hash(); @@ -496,7 +504,7 @@ fn test_invoke_tx( // Verifies the storage after each invoke execution in test_invoke_tx_advanced_operations. fn verify_storage_after_invoke_advanced_operations( - state: &mut CachedState, + state: &mut CachedState, contract_address: ContractAddress, account_address: ContractAddress, index: Felt, @@ -556,7 +564,14 @@ fn test_invoke_tx_advanced_operations( create_calldata(contract_address, "advance_counter", &calldata_args), ..base_tx_args.clone() }); - account_tx.execute(state, block_context, true, true).unwrap(); + >::execute( + &account_tx, + state, + block_context, + true, + true, + ) + .unwrap(); let next_nonce = nonce_manager.next(account_address); let initial_ec_point = [Felt::ZERO, Felt::ZERO]; @@ -581,7 +596,14 @@ fn test_invoke_tx_advanced_operations( create_calldata(contract_address, "call_xor_counters", &calldata_args), ..base_tx_args.clone() }); - account_tx.execute(state, block_context, true, true).unwrap(); + >::execute( + &account_tx, + state, + block_context, + true, + true, + ) + .unwrap(); let expected_counters = [felt!(counter_diffs[0] ^ xor_values[0]), felt!(counter_diffs[1] ^ xor_values[1])]; @@ -603,7 +625,14 @@ fn test_invoke_tx_advanced_operations( create_calldata(contract_address, "test_ec_op", &[]), ..base_tx_args.clone() }); - account_tx.execute(state, block_context, true, true).unwrap(); + >::execute( + &account_tx, + state, + block_context, + true, + true, + ) + .unwrap(); let expected_ec_point = [ Felt::from_bytes_be(&[ @@ -641,7 +670,14 @@ fn test_invoke_tx_advanced_operations( create_calldata(contract_address, "add_signature_to_counters", &[index]), ..base_tx_args.clone() }); - account_tx.execute(state, block_context, true, true).unwrap(); + >::execute( + &account_tx, + state, + block_context, + true, + true, + ) + .unwrap(); let expected_counters = [ (expected_counters[0] + signature_values[0]), @@ -666,7 +702,14 @@ fn test_invoke_tx_advanced_operations( create_calldata(contract_address, "send_message", &[to_address]), ..base_tx_args }); - let execution_info = account_tx.execute(state, block_context, true, true).unwrap(); + let execution_info = >::execute( + &account_tx, + state, + block_context, + true, + true, + ) + .unwrap(); let next_nonce = nonce_manager.next(account_address); verify_storage_after_invoke_advanced_operations( state, @@ -729,7 +772,14 @@ fn test_state_get_fee_token_balance( version: tx_version, nonce: Nonce::default(), }); - account_tx.execute(state, block_context, true, true).unwrap(); + >::execute( + &account_tx, + state, + block_context, + true, + true, + ) + .unwrap(); // Get balance from state, and validate. let (low, high) = @@ -740,29 +790,31 @@ fn test_state_get_fee_token_balance( } fn assert_failure_if_resource_bounds_exceed_balance( - state: &mut CachedState, + state: &mut CachedState, block_context: &BlockContext, invalid_tx: AccountTransaction, ) { match block_context.to_tx_context(&invalid_tx).tx_info { TransactionInfo::Deprecated(context) => { assert_matches!( - invalid_tx.execute(state, block_context, true, true).unwrap_err(), - TransactionExecutionError::TransactionPreValidationError( - TransactionPreValidationError::TransactionFeeError( - TransactionFeeError::MaxFeeExceedsBalance{ max_fee, .. })) - if max_fee == context.max_fee - ); + >::execute( + &invalid_tx, state, block_context, true, true).unwrap_err(), + TransactionExecutionError::TransactionPreValidationError( + TransactionPreValidationError::TransactionFeeError( + TransactionFeeError::MaxFeeExceedsBalance{ max_fee, .. })) + if max_fee == context.max_fee + ); } TransactionInfo::Current(context) => { let l1_bounds = context.l1_resource_bounds().unwrap(); assert_matches!( - invalid_tx.execute(state, block_context, true, true).unwrap_err(), - TransactionExecutionError::TransactionPreValidationError( - TransactionPreValidationError::TransactionFeeError( - TransactionFeeError::L1GasBoundsExceedBalance{ max_amount, max_price, .. })) - if max_amount == l1_bounds.max_amount && max_price == l1_bounds.max_price_per_unit - ); + >::execute( + &invalid_tx, state, block_context, true, true).unwrap_err(), + TransactionExecutionError::TransactionPreValidationError( + TransactionPreValidationError::TransactionFeeError( + TransactionFeeError::L1GasBoundsExceedBalance{ max_amount, max_price, .. })) + if max_amount == l1_bounds.max_amount && max_price == l1_bounds.max_price_per_unit + ); } }; } @@ -867,7 +919,14 @@ fn test_insufficient_resource_bounds( let invalid_v1_tx = account_invoke_tx( invoke_tx_args! { max_fee: invalid_max_fee, version: TransactionVersion::ONE, ..valid_invoke_tx_args.clone() }, ); - let execution_error = invalid_v1_tx.execute(state, block_context, true, true).unwrap_err(); + let execution_error = >::execute( + &invalid_v1_tx, + state, + block_context, + true, + true, + ) + .unwrap_err(); // Test error. assert_matches!( @@ -889,7 +948,14 @@ fn test_insufficient_resource_bounds( resource_bounds: l1_resource_bounds(insufficient_max_l1_gas_amount, actual_strk_l1_gas_price.into()), ..valid_invoke_tx_args.clone() }); - let execution_error = invalid_v3_tx.execute(state, block_context, true, true).unwrap_err(); + let execution_error = >::execute( + &invalid_v3_tx, + state, + block_context, + true, + true, + ) + .unwrap_err(); // TODO(Ori, 1/2/2024): Write an indicative expect message explaining why the conversion works. let minimal_l1_gas_as_u64 = u64::try_from(minimal_l1_gas).expect("Failed to convert u128 to u64."); @@ -911,7 +977,14 @@ fn test_insufficient_resource_bounds( resource_bounds: l1_resource_bounds(minimal_l1_gas.try_into().expect("Failed to convert u128 to u64."), insufficient_max_l1_gas_price), ..valid_invoke_tx_args }); - let execution_error = invalid_v3_tx.execute(state, block_context, true, true).unwrap_err(); + let execution_error = >::execute( + &invalid_v3_tx, + state, + block_context, + true, + true, + ) + .unwrap_err(); assert_matches!( execution_error, TransactionExecutionError::TransactionPreValidationError( @@ -953,7 +1026,14 @@ fn test_actual_fee_gt_resource_bounds( invoke_tx_args! { resource_bounds: minimal_resource_bounds, ..invoke_tx_args }, ); - let execution_result = invalid_tx.execute(state, block_context, true, true).unwrap(); + let execution_result = >::execute( + &invalid_tx, + state, + block_context, + true, + true, + ) + .unwrap(); let execution_error = execution_result.revert_error.unwrap(); // Test error. assert!(execution_error.starts_with("Insufficient max L1 gas:")); @@ -981,7 +1061,8 @@ fn test_invalid_nonce( calldata: create_trivial_calldata(test_contract.get_instance_address(0)), resource_bounds: max_resource_bounds, }; - let mut transactional_state = TransactionalState::create_transactional(state); + let mut transactional_state: TransactionalState<'_, _, VisitedPcsSet> = + TransactionalState::create_transactional(state); // Strict, negative flow: account nonce = 0, incoming tx nonce = 1. let invalid_nonce = nonce!(1_u8); @@ -1133,7 +1214,14 @@ fn test_declare_tx( ); let fee_type = &account_tx.fee_type(); let tx_context = &block_context.to_tx_context(&account_tx); - let actual_execution_info = account_tx.execute(state, block_context, true, true).unwrap(); + let actual_execution_info = >::execute( + &account_tx, + state, + block_context, + true, + true, + ) + .unwrap(); // Build expected validate call info. let expected_validate_call_info = declare_validate_callinfo( @@ -1230,7 +1318,13 @@ fn test_declare_tx( }, class_info.clone(), ); - let result = account_tx2.execute(state, block_context, true, true); + let result = >::execute( + &account_tx2, + state, + block_context, + true, + true, + ); assert_matches!( result.unwrap_err(), TransactionExecutionError::DeclareTransactionError{ class_hash:already_declared_class_hash } if @@ -1279,7 +1373,14 @@ fn test_deploy_account_tx( let account_tx = AccountTransaction::DeployAccount(deploy_account); let fee_type = &account_tx.fee_type(); let tx_context = &block_context.to_tx_context(&account_tx); - let actual_execution_info = account_tx.execute(state, block_context, true, true).unwrap(); + let actual_execution_info = >::execute( + &account_tx, + state, + block_context, + true, + true, + ) + .unwrap(); // Build expected validate call info. let validate_calldata = @@ -1396,7 +1497,14 @@ fn test_deploy_account_tx( &mut nonce_manager, ); let account_tx = AccountTransaction::DeployAccount(deploy_account); - let error = account_tx.execute(state, block_context, true, true).unwrap_err(); + let error = >::execute( + &account_tx, + state, + block_context, + true, + true, + ) + .unwrap_err(); assert_matches!( error, TransactionExecutionError::ContractConstructorExecutionFailed( @@ -1437,7 +1545,14 @@ fn test_fail_deploy_account_undeclared_class_hash( .unwrap(); let account_tx = AccountTransaction::DeployAccount(deploy_account); - let error = account_tx.execute(state, block_context, true, true).unwrap_err(); + let error = >::execute( + &account_tx, + state, + block_context, + true, + true, + ) + .unwrap_err(); assert_matches!( error, TransactionExecutionError::ContractConstructorExecutionFailed( @@ -1497,7 +1612,14 @@ fn test_validate_accounts_tx( additional_data: None, ..default_args }); - let error = account_tx.execute(state, block_context, true, true).unwrap_err(); + let error = >::execute( + &account_tx, + state, + block_context, + true, + true, + ) + .unwrap_err(); check_transaction_execution_error_for_invalid_scenario!( cairo_version, error, @@ -1513,7 +1635,14 @@ fn test_validate_accounts_tx( contract_address_salt: salt_manager.next_salt(), ..default_args }); - let error = account_tx.execute(state, block_context, true, true).unwrap_err(); + let error = >::execute( + &account_tx, + state, + block_context, + true, + true, + ) + .unwrap_err(); check_transaction_execution_error_for_custom_hint!( &error, "Unauthorized syscall call_contract in execution mode Validate.", @@ -1528,7 +1657,14 @@ fn test_validate_accounts_tx( additional_data: None, ..default_args }); - let error = account_tx.execute(state, block_context, true, true).unwrap_err(); + let error = >::execute( + &account_tx, + state, + block_context, + true, + true, + ) + .unwrap_err(); check_transaction_execution_error_for_custom_hint!( &error, "Unauthorized syscall get_block_hash in execution mode Validate.", @@ -1542,7 +1678,14 @@ fn test_validate_accounts_tx( contract_address_salt: salt_manager.next_salt(), ..default_args }); - let error = account_tx.execute(state, block_context, true, true).unwrap_err(); + let error = >::execute( + &account_tx, + state, + block_context, + true, + true, + ) + .unwrap_err(); check_transaction_execution_error_for_custom_hint!( &error, "Unauthorized syscall get_sequencer_address in execution mode Validate.", @@ -1565,7 +1708,13 @@ fn test_validate_accounts_tx( ..default_args }, ); - let result = account_tx.execute(state, block_context, true, true); + let result = >::execute( + &account_tx, + state, + block_context, + true, + true, + ); assert!(result.is_ok(), "Execution failed: {:?}", result.unwrap_err()); if tx_type != TransactionType::DeployAccount { @@ -1581,7 +1730,13 @@ fn test_validate_accounts_tx( ..default_args }, ); - let result = account_tx.execute(state, block_context, true, true); + let result = >::execute( + &account_tx, + state, + block_context, + true, + true, + ); assert!(result.is_ok(), "Execution failed: {:?}", result.unwrap_err()); } @@ -1600,7 +1755,13 @@ fn test_validate_accounts_tx( ..default_args }, ); - let result = account_tx.execute(state, block_context, true, true); + let result = >::execute( + &account_tx, + state, + block_context, + true, + true, + ); assert!(result.is_ok(), "Execution failed: {:?}", result.unwrap_err()); // Call the syscall get_block_timestamp and assert the returned timestamp was modified @@ -1615,7 +1776,13 @@ fn test_validate_accounts_tx( ..default_args }, ); - let result = account_tx.execute(state, block_context, true, true); + let result = >::execute( + &account_tx, + state, + block_context, + true, + true, + ); assert!(result.is_ok(), "Execution failed: {:?}", result.unwrap_err()); } @@ -1636,7 +1803,13 @@ fn test_validate_accounts_tx( ..default_args }, ); - let result = account_tx.execute(state, block_context, true, true); + let result = >::execute( + &account_tx, + state, + block_context, + true, + true, + ); assert!(result.is_ok(), "Execution failed: {:?}", result.unwrap_err()); } } @@ -1663,7 +1836,14 @@ fn test_valid_flag( resource_bounds: max_resource_bounds, }); - let actual_execution_info = account_tx.execute(state, block_context, true, false).unwrap(); + let actual_execution_info = >::execute( + &account_tx, + state, + block_context, + true, + false, + ) + .unwrap(); assert!(actual_execution_info.validate_call_info.is_none()); } @@ -1762,7 +1942,14 @@ fn test_only_query_flag( }); let account_tx = AccountTransaction::Invoke(invoke_tx); - let tx_execution_info = account_tx.execute(state, block_context, true, true).unwrap(); + let tx_execution_info = >::execute( + &account_tx, + state, + block_context, + true, + true, + ) + .unwrap(); assert!(!tx_execution_info.is_reverted()) } @@ -1781,7 +1968,14 @@ fn test_l1_handler(#[values(false, true)] use_kzg_da: bool) { let value = calldata.0[2]; let payload_size = tx.payload_size(); - let actual_execution_info = tx.execute(state, block_context, true, true).unwrap(); + let actual_execution_info = >::execute( + &tx, + state, + block_context, + true, + true, + ) + .unwrap(); // Build the expected call info. let accessed_storage_key = StorageKey::try_from(key).unwrap(); @@ -1902,7 +2096,14 @@ fn test_l1_handler(#[values(false, true)] use_kzg_da: bool) { // always uptade the storage instad. state.set_storage_at(contract_address, StorageKey::try_from(key).unwrap(), Felt::ZERO).unwrap(); let tx_no_fee = L1HandlerTransaction::create_for_testing(Fee(0), contract_address); - let error = tx_no_fee.execute(state, block_context, true, true).unwrap_err(); + let error = >::execute( + &tx_no_fee, + state, + block_context, + true, + true, + ) + .unwrap_err(); // Today, we check that the paid_fee is positive, no matter what was the actual fee. let expected_actual_fee = (expected_execution_info .transaction_receipt @@ -1940,7 +2141,14 @@ fn test_execute_tx_with_invalid_transaction_version( calldata, }); - let execution_info = account_tx.execute(state, block_context, true, true).unwrap(); + let execution_info = >::execute( + &account_tx, + state, + block_context, + true, + true, + ) + .unwrap(); assert!( execution_info .revert_error @@ -2035,7 +2243,14 @@ fn test_emit_event_exceeds_limit( resource_bounds: max_resource_bounds, nonce: nonce!(0_u8), }); - let execution_info = account_tx.execute(state, block_context, true, true).unwrap(); + let execution_info = >::execute( + &account_tx, + state, + block_context, + true, + true, + ) + .unwrap(); match &expected_error { Some(expected_error) => { let error_string = execution_info.revert_error.unwrap(); diff --git a/crates/native_blockifier/src/py_block_executor.rs b/crates/native_blockifier/src/py_block_executor.rs index bf3c244eba..c29b7f5f00 100644 --- a/crates/native_blockifier/src/py_block_executor.rs +++ b/crates/native_blockifier/src/py_block_executor.rs @@ -8,6 +8,7 @@ use blockifier::context::{BlockContext, ChainInfo, FeeTokenAddresses}; use blockifier::execution::call_info::CallInfo; use blockifier::state::cached_state::CachedState; use blockifier::state::global_cache::GlobalContractCache; +use blockifier::state::visited_pcs::VisitedPcsSet; use blockifier::transaction::objects::{GasVector, ResourcesMapping, TransactionExecutionInfo}; use blockifier::transaction::transaction_execution::Transaction; use blockifier::versioned_constants::VersionedConstants; @@ -82,7 +83,7 @@ pub struct PyBlockExecutor { pub tx_executor_config: TransactionExecutorConfig, pub chain_info: ChainInfo, pub versioned_constants: VersionedConstants, - pub tx_executor: Option>, + pub tx_executor: Option>, /// `Send` trait is required for `pyclass` compatibility as Python objects must be threadsafe. pub storage: Box, pub global_contract_cache: GlobalContractCache, @@ -370,7 +371,7 @@ impl PyBlockExecutor { } impl PyBlockExecutor { - pub fn tx_executor(&mut self) -> &mut TransactionExecutor { + pub fn tx_executor(&mut self) -> &mut TransactionExecutor { self.tx_executor.as_mut().expect("Transaction executor should be initialized") } diff --git a/crates/native_blockifier/src/py_test_utils.rs b/crates/native_blockifier/src/py_test_utils.rs index 0e66423790..e5c7fccc4a 100644 --- a/crates/native_blockifier/src/py_test_utils.rs +++ b/crates/native_blockifier/src/py_test_utils.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use blockifier::execution::contract_class::ContractClassV0; use blockifier::state::cached_state::CachedState; +use blockifier::state::visited_pcs::VisitedPcsSet; use blockifier::test_utils::dict_state_reader::DictStateReader; use starknet_api::core::ClassHash; use starknet_api::{class_hash, felt}; @@ -12,7 +13,7 @@ pub const TOKEN_FOR_TESTING_CONTRACT_PATH: &str = "./src/starkware/starknet/core/test_contract/starknet_compiled_contracts_lib/starkware/\ starknet/core/test_contract/token_for_testing.json"; -pub fn create_py_test_state() -> CachedState { +pub fn create_py_test_state() -> CachedState { let class_hash_to_class = HashMap::from([( class_hash!(TOKEN_FOR_TESTING_CLASS_HASH), ContractClassV0::from_file(TOKEN_FOR_TESTING_CONTRACT_PATH).into(), diff --git a/crates/native_blockifier/src/py_validator.rs b/crates/native_blockifier/src/py_validator.rs index 8398e48c1a..3b547c3e3a 100644 --- a/crates/native_blockifier/src/py_validator.rs +++ b/crates/native_blockifier/src/py_validator.rs @@ -2,6 +2,7 @@ use blockifier::blockifier::stateful_validator::{StatefulValidator, StatefulVali use blockifier::bouncer::BouncerConfig; use blockifier::context::BlockContext; use blockifier::state::cached_state::CachedState; +use blockifier::state::visited_pcs::VisitedPcsSet; use blockifier::transaction::account_transaction::AccountTransaction; use blockifier::transaction::objects::TransactionInfoCreator; use blockifier::transaction::transaction_types::TransactionType; @@ -21,7 +22,7 @@ use crate::state_readers::py_state_reader::PyStateReader; #[pyclass] pub struct PyValidator { - pub stateful_validator: StatefulValidator, + pub stateful_validator: StatefulValidator, pub max_nonce_for_validation_skip: Nonce, } diff --git a/crates/native_blockifier/src/state_readers/papyrus_state_test.rs b/crates/native_blockifier/src/state_readers/papyrus_state_test.rs index e999276084..89a5ba133c 100644 --- a/crates/native_blockifier/src/state_readers/papyrus_state_test.rs +++ b/crates/native_blockifier/src/state_readers/papyrus_state_test.rs @@ -7,6 +7,7 @@ use blockifier::retdata; use blockifier::state::cached_state::CachedState; use blockifier::state::global_cache::{GlobalContractCache, GLOBAL_CONTRACT_CACHE_SIZE_FOR_TEST}; use blockifier::state::state_api::StateReader; +use blockifier::state::visited_pcs::VisitedPcsSet; use blockifier::test_utils::contracts::FeatureContract; use blockifier::test_utils::{trivial_external_entry_point_new, CairoVersion}; use indexmap::IndexMap; @@ -56,7 +57,7 @@ fn test_entry_point_with_papyrus_state() -> papyrus_storage::StorageResult<()> { block_number, GlobalContractCache::new(GLOBAL_CONTRACT_CACHE_SIZE_FOR_TEST), ); - let mut state = CachedState::from(papyrus_reader); + let mut state: CachedState<_, VisitedPcsSet> = CachedState::from(papyrus_reader); // Call entrypoint that want to write to storage, which updates the cached state's write cache. let key = felt!(1234_u16);