diff --git a/crates/katana/storage/provider/src/providers/fork/backend.rs b/crates/katana/storage/provider/src/providers/fork/backend.rs index 5fe359b8a1..f1c065c444 100644 --- a/crates/katana/storage/provider/src/providers/fork/backend.rs +++ b/crates/katana/storage/provider/src/providers/fork/backend.rs @@ -372,7 +372,7 @@ impl BackendHandle { /// cache to avoid fetching it again. This is shared across multiple instances of /// [`ForkedStateDb`](super::state::ForkedStateDb). #[derive(Clone)] -pub struct SharedStateProvider(Arc>); +pub struct SharedStateProvider(pub(crate) Arc>); impl SharedStateProvider { pub(crate) fn new_with_backend(backend: BackendHandle) -> Self { @@ -599,14 +599,11 @@ fn handle_not_found_err(result: Result) -> Result, } #[cfg(test)] -mod tests { +pub(crate) mod test_utils { use std::sync::mpsc::sync_channel; - use std::time::Duration; use katana_primitives::block::BlockNumber; - use katana_primitives::contract::GenericContractInfo; - use starknet::macros::felt; use starknet::providers::jsonrpc::HttpTransport; use starknet::providers::JsonRpcClient; use tokio::net::TcpListener; @@ -614,23 +611,14 @@ mod tests { use super::*; - const LOCAL_RPC_URL: &str = "http://localhost:5050"; - - const STORAGE_KEY: StorageKey = felt!("0x1"); - const ADDR_1: ContractAddress = ContractAddress(felt!("0xADD1")); - const ADDR_1_NONCE: Nonce = felt!("0x1"); - const ADDR_1_STORAGE_VALUE: StorageKey = felt!("0x8080"); - const ADDR_1_CLASS_HASH: StorageKey = felt!("0x1"); - - fn create_forked_backend(rpc_url: String, block_num: BlockNumber) -> BackendHandle { - let url = Url::parse(&rpc_url).expect("valid url"); + pub fn create_forked_backend(rpc_url: &str, block_num: BlockNumber) -> BackendHandle { + let url = Url::parse(rpc_url).expect("valid url"); let provider = Arc::new(JsonRpcClient::new(HttpTransport::new(url))); - let block_id = BlockHashOrNumber::Num(block_num); - Backend::new(provider, block_id).unwrap() + Backend::new(provider, block_num.into()).unwrap() } // Starts a TCP server that never close the connection. - fn start_tcp_server() { + pub fn start_tcp_server() { use tokio::runtime::Builder; let (tx, rx) = sync_channel::<()>(1); @@ -650,22 +638,36 @@ mod tests { rx.recv().unwrap(); } +} + +#[cfg(test)] +mod tests { + + use std::time::Duration; + + use katana_primitives::contract::GenericContractInfo; + use starknet::macros::felt; + + use super::test_utils::*; + use super::*; + + const LOCAL_RPC_URL: &str = "http://localhost:5050"; + + const STORAGE_KEY: StorageKey = felt!("0x1"); + const ADDR_1: ContractAddress = ContractAddress(felt!("0xADD1")); + const ADDR_1_NONCE: Nonce = felt!("0x1"); + const ADDR_1_STORAGE_VALUE: StorageKey = felt!("0x8080"); + const ADDR_1_CLASS_HASH: StorageKey = felt!("0x1"); - const ERROR_INIT_BACKEND: &str = "Failed to create backend"; const ERROR_SEND_REQUEST: &str = "Failed to send request to backend"; const ERROR_STATS: &str = "Failed to get stats"; #[test] fn handle_incoming_requests() { - let url = Url::try_from("http://127.0.0.1:8080").unwrap(); - let provider = JsonRpcClient::new(HttpTransport::new(url)); - let block_id = BlockHashOrNumber::Num(1); - // start a mock remote network start_tcp_server(); - // start backend - let handle = Backend::new(Arc::new(provider), block_id).expect(ERROR_INIT_BACKEND); + let handle = create_forked_backend("http://127.0.0.1:8080", 1); // check no pending requests let stats = handle.stats().expect(ERROR_STATS); @@ -704,7 +706,7 @@ mod tests { #[test] fn get_from_cache_if_exist() { // setup - let backend = create_forked_backend(LOCAL_RPC_URL.into(), 1); + let backend = create_forked_backend(LOCAL_RPC_URL, 1); let state_db = CacheStateDb::new(backend); state_db @@ -734,7 +736,7 @@ mod tests { #[test] fn fetch_from_fork_will_err_if_backend_thread_not_running() { - let backend = create_forked_backend(LOCAL_RPC_URL.into(), 1); + let backend = create_forked_backend(LOCAL_RPC_URL, 1); let provider = SharedStateProvider(Arc::new(CacheStateDb::new(backend))); assert!(StateProvider::nonce(&provider, ADDR_1).is_err()) } @@ -751,7 +753,7 @@ mod tests { #[test] #[ignore] fn fetch_from_fork_if_not_in_cache() { - let backend = create_forked_backend(FORKED_URL.into(), 908622); + let backend = create_forked_backend(FORKED_URL, 908622); let provider = SharedStateProvider(Arc::new(CacheStateDb::new(backend))); // fetch from remote diff --git a/crates/katana/storage/provider/src/providers/fork/state.rs b/crates/katana/storage/provider/src/providers/fork/state.rs index fb69c6f9d2..793dc540d2 100644 --- a/crates/katana/storage/provider/src/providers/fork/state.rs +++ b/crates/katana/storage/provider/src/providers/fork/state.rs @@ -39,9 +39,23 @@ impl StateProvider for ForkedStateDb { StateProvider::class_hash_of_contract(&self.db, address) } + // When reading from local storage, we only consider entries that have non-zero nonce + // values OR non-zero class hashes. + // + // Nonce == 0 && ClassHash == 0 + // - Contract does not exist locally (so try find from remote state) + // Nonce != 0 && ClassHash == 0 + // - Contract exists and was deployed remotely but new nonce was set locally (so no need to read + // from remote state anymore) + // Nonce == 0 && ClassHash != 0 + // - Contract exists and was deployed locally (always read from local state) fn nonce(&self, address: ContractAddress) -> ProviderResult> { - if let nonce @ Some(_) = - self.contract_state.read().get(&address).map(|i| i.nonce).filter(|n| n != &Nonce::ZERO) + if let nonce @ Some(_) = self + .contract_state + .read() + .get(&address) + .filter(|c| c.nonce != Nonce::default() || c.class_hash != ClassHash::default()) + .map(|c| c.nonce) { return Ok(nonce); } @@ -134,8 +148,8 @@ impl StateProvider for ForkedSnapshot { .inner .contract_state .get(&address) - .map(|info| info.nonce) - .filter(|n| n != &Nonce::ZERO) + .filter(|c| c.nonce != Nonce::default() || c.class_hash != ClassHash::default()) + .map(|c| c.nonce) { return Ok(nonce); } @@ -199,3 +213,135 @@ impl ContractClassProvider for ForkedSnapshot { } } } + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use katana_primitives::state::{StateUpdates, StateUpdatesWithDeclaredClasses}; + use starknet::macros::felt; + + use super::*; + use crate::providers::fork::backend::test_utils::create_forked_backend; + + #[test] + fn test_get_nonce() { + let backend = create_forked_backend("http://localhost:8080", 1); + + let address: ContractAddress = felt!("1").into(); + let class_hash = felt!("11"); + let remote_nonce = felt!("111"); + let local_nonce = felt!("1111"); + + // Case: contract doesn't exist at all + { + let remote = SharedStateProvider::new_with_backend(backend.clone()); + let local = ForkedStateDb::new(remote.clone()); + + // asserts that its error for now + assert!(local.nonce(address).is_err()); + assert!(remote.nonce(address).is_err()); + + // make sure the snapshot maintains the same behavior + let snapshot = local.create_snapshot(); + assert!(snapshot.nonce(address).is_err()); + } + + // Case: contract exist remotely + { + let remote = SharedStateProvider::new_with_backend(backend.clone()); + let local = ForkedStateDb::new(remote.clone()); + + let nonce_updates = HashMap::from([(address, remote_nonce)]); + let updates = StateUpdatesWithDeclaredClasses { + state_updates: StateUpdates { nonce_updates, ..Default::default() }, + ..Default::default() + }; + remote.0.insert_updates(updates); + + assert_eq!(local.nonce(address).unwrap(), Some(remote_nonce)); + assert_eq!(remote.nonce(address).unwrap(), Some(remote_nonce)); + + // make sure the snapshot maintains the same behavior + let snapshot = local.create_snapshot(); + assert_eq!(snapshot.nonce(address).unwrap(), Some(remote_nonce)); + } + + // Case: contract exist remotely but nonce was updated locally + { + let remote = SharedStateProvider::new_with_backend(backend.clone()); + let local = ForkedStateDb::new(remote.clone()); + + let nonce_updates = HashMap::from([(address, remote_nonce)]); + let contract_updates = HashMap::from([(address, class_hash)]); + let updates = StateUpdatesWithDeclaredClasses { + state_updates: StateUpdates { + nonce_updates, + contract_updates, + ..Default::default() + }, + ..Default::default() + }; + remote.0.insert_updates(updates); + + let nonce_updates = HashMap::from([(address, local_nonce)]); + let updates = StateUpdatesWithDeclaredClasses { + state_updates: StateUpdates { nonce_updates, ..Default::default() }, + ..Default::default() + }; + local.insert_updates(updates); + + assert_eq!(local.nonce(address).unwrap(), Some(local_nonce)); + assert_eq!(remote.nonce(address).unwrap(), Some(remote_nonce)); + + // make sure the snapshot maintains the same behavior + let snapshot = local.create_snapshot(); + assert_eq!(snapshot.nonce(address).unwrap(), Some(local_nonce)); + } + + // Case: contract was deployed locally only and has non-zero nonce + { + let remote = SharedStateProvider::new_with_backend(backend.clone()); + let local = ForkedStateDb::new(remote.clone()); + + let contract_updates = HashMap::from([(address, class_hash)]); + let nonce_updates = HashMap::from([(address, local_nonce)]); + let updates = StateUpdatesWithDeclaredClasses { + state_updates: StateUpdates { + nonce_updates, + contract_updates, + ..Default::default() + }, + ..Default::default() + }; + local.insert_updates(updates); + + assert_eq!(local.nonce(address).unwrap(), Some(local_nonce)); + assert!(remote.nonce(address).is_err()); + + // make sure the snapshot maintains the same behavior + let snapshot = local.create_snapshot(); + assert_eq!(snapshot.nonce(address).unwrap(), Some(local_nonce)); + } + + // Case: contract was deployed locally only and has zero nonce + { + let remote = SharedStateProvider::new_with_backend(backend.clone()); + let local = ForkedStateDb::new(remote.clone()); + + let contract_updates = HashMap::from([(address, class_hash)]); + let updates = StateUpdatesWithDeclaredClasses { + state_updates: StateUpdates { contract_updates, ..Default::default() }, + ..Default::default() + }; + local.insert_updates(updates); + + assert_eq!(local.nonce(address).unwrap(), Some(Default::default())); + assert!(remote.nonce(address).is_err()); + + // make sure the snapshot maintains the same behavior + let snapshot = local.create_snapshot(); + assert_eq!(snapshot.nonce(address).unwrap(), Some(Default::default())); + } + } +}