diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..b591d387 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +*.cairo linguist-language=python linguist-detectable=false diff --git a/CMakeLists.txt b/CMakeLists.txt index 9dc30886..c896351f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required (VERSION 3.5) +cmake_minimum_required (VERSION 3.22) project(CairoLang VERSION 0.1.0) include(CTest) diff --git a/Dockerfile b/Dockerfile index f9bff081..1fe62e71 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,7 +1,9 @@ FROM ciimage/python:3.7 RUN apt update -RUN apt install -y cmake libgmp3-dev g++ python3-pip python3.7-dev python3.7-venv npm +RUN apt install -y make libgmp3-dev g++ python3-pip python3.7-dev python3.7-venv npm +# Installing cmake via apt doesn't bring the most up-to-date version. +RUN pip install cmake==3.22 # Install solc and ganache RUN curl https://binaries.soliditylang.org/linux-amd64/solc-linux-amd64-v0.6.12+commit.27d51765 -o /usr/local/bin/solc-0.6.12 @@ -19,7 +21,7 @@ WORKDIR /app/build/Release RUN make all -j8 # Run tests. -RUN ctest -V +RUN ctest -V -j8 WORKDIR /app/ RUN src/starkware/cairo/lang/package_test/run_test.sh diff --git a/README.md b/README.md index b4d51645..8eb0d7ad 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ Once the docker image is built, you can fetch the python package zip file using: ```bash > container_id=$(docker create cairo) -> docker cp ${container_id}:/app/cairo-lang-0.7.1.zip . +> docker cp ${container_id}:/app/cairo-lang-0.8.0.zip . > docker rm -v ${container_id} ``` diff --git a/src/cmake_utils/pip_rules.cmake b/src/cmake_utils/pip_rules.cmake index 3f4e98f3..3b5064f3 100644 --- a/src/cmake_utils/pip_rules.cmake +++ b/src/cmake_utils/pip_rules.cmake @@ -12,10 +12,13 @@ function(python_pip TARGET) cmake_parse_arguments(ARGS "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) # Create a list of all dependencies regardless of python's version. - execute_process( - COMMAND ${UNITE_LIBS_EXECUTABLE} ${ARGS_LIBS} - OUTPUT_VARIABLE UNITED_LIBS - ) + set(UNITED_LIBS ${ARGS_LIBS}) + if("${UNITED_LIBS}" MATCHES ":") + execute_process( + COMMAND ${UNITE_LIBS_EXECUTABLE} ${UNITED_LIBS} + OUTPUT_VARIABLE UNITED_LIBS + ) + endif() separate_arguments(UNITED_LIBS) set(ALL_STAMPS) @@ -32,7 +35,7 @@ function(python_pip TARGET) set(STAMP_FILE ${CMAKE_BINARY_DIR}/python_pip/${TARGET}_${INTERPRETER}_${REQ}.stamp) # Creating library directory. - if (${REQ} MATCHES "==local$") + if (${REQ} MATCHES "\\+local$") string(REPLACE "==" "-" PACKAGE_NAME ${REQ}) set(ZIP_FILE "${PROJECT_SOURCE_DIR}/${PACKAGE_NAME}.zip") add_custom_command( @@ -40,9 +43,7 @@ function(python_pip TARGET) COMMENT "Building ${REQ} from a local copy." COMMAND rm -rf ${LIB_DIR}/* COMMAND unzip ${ZIP_FILE} -d ${LIB_DIR} > /dev/null - # We don't know if the directory in the zip has the same name as the package. - COMMAND ls ${LIB_DIR} | grep -v -x ${PACKAGE_NAME} | xargs -r -I {} mv ${LIB_DIR}/{} ${LIB_DIR}/${PACKAGE_NAME} - COMMAND mv ${LIB_DIR}/${PACKAGE_NAME}/* ${LIB_DIR} + COMMAND mv ${LIB_DIR}/${PACKAGE_NAME}/* ${LIB_DIR}/ COMMAND rm -rf ${LIB_DIR}/${PACKAGE_NAME}/ COMMAND ${CMAKE_COMMAND} -E touch ${STAMP_FILE} DEPENDS ${ZIP_FILE} @@ -70,8 +71,8 @@ function(python_pip TARGET) ) endif() - set(ALL_STAMPS ${ALL_STAMPS} ${STAMP_FILE}) - set(ALL_LIB_DIRS ${ALL_LIB_DIRS} "${INTERPRETER}:${LIB_DIR}") + list(APPEND ALL_STAMPS ${STAMP_FILE}) + list(APPEND ALL_LIB_DIRS "${INTERPRETER}:${LIB_DIR}") endforeach() # Info target. @@ -108,10 +109,13 @@ function(python_get_pip_deps TARGET) set(CMAKE_FILE "${CMAKE_CURRENT_BINARY_DIR}/${TARGET}_generated_rules.cmake") # Create a list of all dependency files. - execute_process( - COMMAND ${UNITE_LIBS_EXECUTABLE} ${ARGN} - OUTPUT_VARIABLE UNITED_DEP_FILES - ) + set(UNITED_DEP_FILES ${ARGN}) + if("${UNITED_DEP_FILES}" MATCHES ":") + execute_process( + COMMAND ${UNITE_LIBS_EXECUTABLE} ${UNITED_DEP_FILES} + OUTPUT_VARIABLE UNITED_DEP_FILES + ) + endif() separate_arguments(UNITED_DEP_FILES) # Add as a reconfigure dependency, so that CMake will reconfigure on change. diff --git a/src/cmake_utils/python_rules.cmake b/src/cmake_utils/python_rules.cmake index b45a0449..27e3efce 100644 --- a/src/cmake_utils/python_rules.cmake +++ b/src/cmake_utils/python_rules.cmake @@ -58,7 +58,7 @@ function(python_lib LIB) # Copy files. copy_files(${LIB}_copy_files ${CMAKE_CURRENT_SOURCE_DIR} ${LIB_DIR} ${ARGS_FILES}) get_target_property(COPY_STAMP ${LIB}_copy_files STAMP_FILE) - set(ALL_FILE_DEPS ${ALL_FILE_DEPS} ${COPY_STAMP}) + list(APPEND ALL_FILE_DEPS ${COPY_STAMP}) # Copy artifacts. foreach(ARTIFACT ${ARGS_ARTIFACTS}) @@ -73,22 +73,25 @@ function(python_lib LIB) DEPENDS ${ARTIFACT_SRC} COMMENT "Copying artifact ${ARTIFACT_SRC} to ${LIB_DIR}/${ARTIFACT_DEST}" ) - set(ALL_FILE_DEPS ${ALL_FILE_DEPS} ${LIB_DIR}/${ARTIFACT_DEST}) - set(LIB_FILES ${LIB_FILES} ${ARGS_PREFIX}${ARTIFACT_DEST}) + list(APPEND ALL_FILE_DEPS ${LIB_DIR}/${ARTIFACT_DEST}) + list(APPEND LIB_FILES ${ARGS_PREFIX}${ARTIFACT_DEST}) endforeach() # Create a list of all dependencies regardless of python's version. - execute_process( - COMMAND ${UNITE_LIBS_EXECUTABLE} ${ARGS_LIBS} - OUTPUT_VARIABLE UNITED_LIBS - ) + set(UNITED_LIBS ${ARGS_LIBS}) + if("${UNITED_LIBS}" MATCHES ":") + execute_process( + COMMAND ${UNITE_LIBS_EXECUTABLE} ${UNITED_LIBS} + OUTPUT_VARIABLE UNITED_LIBS + ) + endif() separate_arguments(UNITED_LIBS) # Info target. set(DEP_INFO) foreach(DEP_LIB ${UNITED_LIBS} ${ARGS_PY_EXE_DEPENDENCIES}) get_lib_info_file(DEP_INFO_FILE ${DEP_LIB}) - set(DEP_INFO ${DEP_INFO} ${DEP_INFO_FILE}) + LIST(APPEND DEP_INFO ${DEP_INFO_FILE}) endforeach() get_lib_info_file(INFO_FILE ${LIB}) @@ -140,7 +143,7 @@ function(python_venv VENV_NAME) set(DEP_INFO) foreach(DEP_LIB ${ARGS_LIBS}) get_lib_info_file(DEP_INFO_FILE ${DEP_LIB}) - set(DEP_INFO ${DEP_INFO} ${DEP_INFO_FILE}) + list(APPEND DEP_INFO ${DEP_INFO_FILE}) endforeach() add_custom_command( diff --git a/src/demo/amm_demo/amm.cairo b/src/demo/amm_demo/amm.cairo index 98278bfc..ba6feff4 100644 --- a/src/demo/amm_demo/amm.cairo +++ b/src/demo/amm_demo/amm.cairo @@ -6,7 +6,7 @@ from starkware.cairo.common.dict_access import DictAccess from starkware.cairo.common.hash import hash2 from starkware.cairo.common.math import assert_nn_le, unsigned_div_rem from starkware.cairo.common.registers import get_fp_and_pc -from starkware.cairo.common.small_merkle_tree import small_merkle_tree +from starkware.cairo.common.small_merkle_tree import small_merkle_tree_update struct Account: member public_key : felt @@ -213,7 +213,7 @@ func compute_merkle_roots{pedersen_ptr : HashBuiltin*, range_check_ptr}(state : hash_dict_start=hash_dict_start) # Compute the two Merkle roots. - let (root_before, root_after) = small_merkle_tree{hash_ptr=pedersen_ptr}( + let (root_before, root_after) = small_merkle_tree_update{hash_ptr=pedersen_ptr}( squashed_dict_start=hash_dict_start, squashed_dict_end=hash_dict_end, height=LOG_N_ACCOUNTS) diff --git a/src/demo/amm_demo/demo.py b/src/demo/amm_demo/demo.py index cb9fdd97..de36f4a3 100644 --- a/src/demo/amm_demo/demo.py +++ b/src/demo/amm_demo/demo.py @@ -13,7 +13,7 @@ from web3 import HTTPProvider, Web3, eth from demo.amm_demo.prove_batch import Account, Balance, BatchProver, SwapTransaction -from starkware.cairo.bootloader.hash_program import compute_program_hash_chain +from starkware.cairo.bootloaders.hash_program import compute_program_hash_chain from starkware.cairo.common.small_merkle_tree import MerkleTree from starkware.cairo.lang.vm.crypto import get_crypto_lib_context_manager, pedersen_hash from starkware.cairo.sharp.sharp_client import init_client diff --git a/src/demo/amm_demo/prove_batch.py b/src/demo/amm_demo/prove_batch.py index 12ecd78a..7dbdcdc8 100644 --- a/src/demo/amm_demo/prove_batch.py +++ b/src/demo/amm_demo/prove_batch.py @@ -3,7 +3,7 @@ import tempfile from typing import Any, Dict, List, Tuple -from starkware.cairo.bootloader.generate_fact import get_program_output +from starkware.cairo.bootloaders.generate_fact import get_program_output from starkware.cairo.sharp.sharp_client import Program, SharpClient diff --git a/src/services/everest/api/feeder_gateway/feeder_gateway_client.py b/src/services/everest/api/feeder_gateway/feeder_gateway_client.py index 737f6658..e661c918 100644 --- a/src/services/everest/api/feeder_gateway/feeder_gateway_client.py +++ b/src/services/everest/api/feeder_gateway/feeder_gateway_client.py @@ -14,3 +14,7 @@ class EverestFeederGatewayClient(BaseClient): async def get_last_batch_id(self) -> int: raw_response = await self._send_request(send_method="GET", uri="/get_last_batch_id") return json.loads(raw_response) + + async def get_l1_blockchain_id(self) -> int: + raw_response = await self._send_request(send_method="GET", uri="/get_l1_blockchain_id") + return json.loads(raw_response) diff --git a/src/services/everest/business_logic/internal_transaction.py b/src/services/everest/business_logic/internal_transaction.py index 5fbdc184..7ae21c49 100644 --- a/src/services/everest/business_logic/internal_transaction.py +++ b/src/services/everest/business_logic/internal_transaction.py @@ -49,7 +49,7 @@ def add_class(self, cls: type): self.classes[cls_name] = cls -class EverestTransactionExecutionInfo: +class EverestTransactionExecutionInfo(ValidatedMarshmallowDataclass): """ Base class of classes containing information generated from an execution of a transaction on the state. Each Everest application may implement it specifically. diff --git a/src/services/everest/business_logic/transaction_execution_objects.py b/src/services/everest/business_logic/transaction_execution_objects.py index 946f449c..b4e5cf58 100644 --- a/src/services/everest/business_logic/transaction_execution_objects.py +++ b/src/services/everest/business_logic/transaction_execution_objects.py @@ -13,10 +13,14 @@ class TransactionFailureReason(ValidatedMarshmallowDataclass): transaction. """ - tx_id: int code: str error_message: Optional[str] + @marshmallow.decorators.pre_load + def remove_tx_id(self, data: Dict[str, Any], many: bool, **kwargs) -> Dict[str, Any]: + data.pop("tx_id", None) + return data + @marshmallow.decorators.post_dump def truncate_error_message(self, data: Dict[str, Any], many: bool, **kwargs) -> Dict[str, Any]: error_message = data["error_message"] diff --git a/src/services/everest/definitions/CMakeLists.txt b/src/services/everest/definitions/CMakeLists.txt index 788a4158..95139736 100644 --- a/src/services/everest/definitions/CMakeLists.txt +++ b/src/services/everest/definitions/CMakeLists.txt @@ -12,3 +12,14 @@ python_lib(everest_definitions_lib pip_marshmallow pip_web3 ) + +python_lib(everest_general_config_lib + PREFIX services/everest/definitions + + FILES + general_config.py + + LIBS + starkware_config_utils_lib + pip_marshmallow_dataclass +) diff --git a/src/services/everest/definitions/fields.py b/src/services/everest/definitions/fields.py index 1388a47a..64bb471b 100644 --- a/src/services/everest/definitions/fields.py +++ b/src/services/everest/definitions/fields.py @@ -93,5 +93,9 @@ def format(self, value: str) -> str: ) +def felt(name_in_error_message: str) -> RangeValidatedField: + return dataclasses.replace(FeltField, name=name_in_error_message) + + def felt_metadata(name_in_error_message: str) -> Dict[str, Any]: - return dataclasses.replace(FeltField, name=name_in_error_message).metadata() + return felt(name_in_error_message=name_in_error_message).metadata() diff --git a/src/services/everest/definitions/general_config.py b/src/services/everest/definitions/general_config.py new file mode 100644 index 00000000..b396ba0b --- /dev/null +++ b/src/services/everest/definitions/general_config.py @@ -0,0 +1,8 @@ +import marshmallow_dataclass + +from starkware.starkware_utils.config_base import Config + + +@marshmallow_dataclass.dataclass(frozen=True) +class EverestGeneralConfig(Config): + pass diff --git a/src/services/external_api/CMakeLists.txt b/src/services/external_api/CMakeLists.txt index 053aa39e..15c3ecda 100644 --- a/src/services/external_api/CMakeLists.txt +++ b/src/services/external_api/CMakeLists.txt @@ -1,3 +1,10 @@ +python_lib(services_external_api_utils_lib + PREFIX services/external_api + + FILES + utils.py +) + python_lib(services_external_api_lib PREFIX services/external_api @@ -7,6 +14,7 @@ python_lib(services_external_api_lib ${SERVICES_EXTERNAL_API_LIB_ADDITIONAL_FILES} LIBS + services_external_api_utils_lib starkware_dataclasses_utils_lib pip_aiohttp ${SERVICES_EXTERNAL_API_LIB_ADDITIONAL_LIBS} diff --git a/src/services/external_api/base_client.py b/src/services/external_api/base_client.py index 43c73762..17f4ee42 100644 --- a/src/services/external_api/base_client.py +++ b/src/services/external_api/base_client.py @@ -117,15 +117,18 @@ async def _send_request( return text except aiohttp.ClientError as exception: - error_message = f"Got {type(exception).__name__}" + error_message = f"Got {type(exception).__name__} while trying to access {url}." if limited_retries and n_retries_left == 0: logger.error(error_message, exc_info=True) raise - logger.debug(f"{error_message}, retrying...", exc_info=True) + logger.debug(f"{error_message}, retrying...") except BadRequest as exception: - error_message = f"Got {type(exception).__name__}" + error_message = ( + f"Got {type(exception).__name__} while trying to access {url}. " + f"Status code: {exception.status_code}; text: {exception.text}." + ) if limited_retries and ( n_retries_left == 0 @@ -134,12 +137,7 @@ async def _send_request( logger.error(error_message, exc_info=True) raise - logger.debug( - f"{error_message} while trying to access {url}. " - f"status_code: {exception.status_code}. text: {exception.text}, " - "retrying...", - exc_info=True, - ) + logger.debug(f"{error_message}, retrying...") await asyncio.sleep(1) diff --git a/src/services/external_api/has_uri_prefix.py b/src/services/external_api/has_uri_prefix.py index 9e293e53..c47880b5 100644 --- a/src/services/external_api/has_uri_prefix.py +++ b/src/services/external_api/has_uri_prefix.py @@ -1,5 +1,7 @@ from abc import ABC, abstractmethod -from typing import cast +from typing import Optional, cast + +from services.external_api.utils import join_routes class HasUriPrefix(ABC): @@ -17,9 +19,10 @@ def prefix(cls) -> str: """ @classmethod - def format_uri(cls, name: str) -> str: + def format_uri(cls, name: str, version: Optional[str] = None) -> str: """ - Concatenates cls.prefix with given URI. + Concatenates version/cls.prefix with given URI. """ prefix = cast(str, cls.prefix) # Mypy sees the property as a callable. - return name if len(prefix) == 0 else f"{cls.prefix}{name}" + route_list = [s for s in [version, prefix, name] if s is not None and len(s) != 0] + return join_routes(route_list=route_list) diff --git a/src/services/external_api/utils.py b/src/services/external_api/utils.py new file mode 100644 index 00000000..c6b36668 --- /dev/null +++ b/src/services/external_api/utils.py @@ -0,0 +1,11 @@ +from typing import List + + +def join_routes(route_list: List[str]) -> str: + """ + Joins a list of routes where the result will start with '/' and between every two routes there + will be exactly one '/'. The reason why it is implemented and the builtin urljoin isn't being + used, is that urljoin ignores preceding strings in the path if a leading slash is encountered. + """ + assert None not in route_list and "" not in route_list + return "/" + "/".join(s.strip("/") for s in route_list) diff --git a/src/starkware/cairo/CMakeLists.txt b/src/starkware/cairo/CMakeLists.txt index 705fd453..3681cef9 100644 --- a/src/starkware/cairo/CMakeLists.txt +++ b/src/starkware/cairo/CMakeLists.txt @@ -1,4 +1,4 @@ -add_subdirectory(bootloader) +add_subdirectory(bootloaders) add_subdirectory(common) add_subdirectory(lang) add_subdirectory(sharp) diff --git a/src/starkware/cairo/bootloader/fact_topology.py b/src/starkware/cairo/bootloader/fact_topology.py deleted file mode 100644 index 65290003..00000000 --- a/src/starkware/cairo/bootloader/fact_topology.py +++ /dev/null @@ -1,32 +0,0 @@ -import dataclasses -import json -from typing import ClassVar, List, Type - -import marshmallow -import marshmallow_dataclass - -GPS_FACT_TOPOLOGY = "gps_fact_topology" - - -@dataclasses.dataclass(frozen=True) -class FactTopology: - tree_structure: List[int] - # List of page sizes, in words. - page_sizes: List[int] - - -@marshmallow_dataclass.dataclass(frozen=True) -class FactTopologiesFile: - fact_topologies: List[FactTopology] - Schema: ClassVar[Type[marshmallow.Schema]] = marshmallow.Schema - - -def load_fact_topologies(path) -> List[FactTopology]: - return FactTopologiesFile.Schema().load(json.load(open(path))).fact_topologies - - -@dataclasses.dataclass(frozen=True) -class FactInfo: - program_output: List[int] - fact_topology: FactTopology - fact: str diff --git a/src/starkware/cairo/bootloader/CMakeLists.txt b/src/starkware/cairo/bootloaders/CMakeLists.txt similarity index 75% rename from src/starkware/cairo/bootloader/CMakeLists.txt rename to src/starkware/cairo/bootloaders/CMakeLists.txt index dcf0aeb1..6d676ab0 100644 --- a/src/starkware/cairo/bootloader/CMakeLists.txt +++ b/src/starkware/cairo/bootloaders/CMakeLists.txt @@ -1,9 +1,12 @@ set(PROGRAM_HASH_TEST_UTILS_LIB_ADDITIONAL_LIBS cairo_hash_program_lib) -include(bootloader_test_utils.cmake) +add_subdirectory(bootloader) +add_subdirectory(simple_bootloader) + +include(program_hash_test_utils.cmake) python_lib(cairo_hash_program_lib - PREFIX starkware/cairo/bootloader + PREFIX starkware/cairo/bootloaders FILES hash_program.py @@ -24,11 +27,11 @@ python_venv(cairo_hash_program_venv python_exe(cairo_hash_program_exe VENV cairo_hash_program_venv - MODULE starkware.cairo.bootloader.hash_program + MODULE starkware.cairo.bootloaders.hash_program ) python_lib(cairo_bootloader_fact_topology_lib - PREFIX starkware/cairo/bootloader + PREFIX starkware/cairo/bootloaders FILES fact_topology.py @@ -38,7 +41,7 @@ python_lib(cairo_bootloader_fact_topology_lib ) python_lib(cairo_bootloader_generate_fact_lib - PREFIX starkware/cairo/bootloader + PREFIX starkware/cairo/bootloaders FILES compute_fact.py generate_fact.py diff --git a/src/starkware/cairo/bootloader/__init__.py b/src/starkware/cairo/bootloaders/__init__.py similarity index 100% rename from src/starkware/cairo/bootloader/__init__.py rename to src/starkware/cairo/bootloaders/__init__.py diff --git a/src/starkware/cairo/bootloaders/bootloader/CMakeLists.txt b/src/starkware/cairo/bootloaders/bootloader/CMakeLists.txt new file mode 100644 index 00000000..a3a650c8 --- /dev/null +++ b/src/starkware/cairo/bootloaders/bootloader/CMakeLists.txt @@ -0,0 +1,2 @@ +cairo_compile(bootloader_program + bootloader_compiled.json bootloader.cairo "--debug_info_with_source --proof_mode") diff --git a/src/starkware/cairo/bootloaders/bootloader/bootloader.cairo b/src/starkware/cairo/bootloaders/bootloader/bootloader.cairo new file mode 100644 index 00000000..047f4c49 --- /dev/null +++ b/src/starkware/cairo/bootloaders/bootloader/bootloader.cairo @@ -0,0 +1,287 @@ +%builtins output pedersen range_check ecdsa bitwise + +from starkware.cairo.bootloaders.simple_bootloader.run_simple_bootloader import ( + run_simple_bootloader) +from starkware.cairo.cairo_verifier.objects import CairoVerifierOutput +from starkware.cairo.common.alloc import alloc +from starkware.cairo.common.cairo_builtins import HashBuiltin +from starkware.cairo.common.hash_state import HashState, hash_finalize, hash_init, hash_update +from starkware.cairo.common.memcpy import memcpy + +struct BootloaderConfig: + # The hash of the simple bootloader program. + member simple_bootloader_program_hash : felt + # The hash of a (Cairo) program that verifies a STARK proof for the Cairo machine. + member cairo_verifier_program_hash : felt +end + +struct TaskOutputHeader: + member size : felt + member program_hash : felt +end + +# Runs the simple bootloader on tasks and unpacks them to the output. +# +# Hint arguments: +# program_input - Contains the inputs for the bootloader. +func main{output_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr, bitwise_ptr}( + ): + alloc_locals + local simple_bootloader_output_start : felt* + %{ + from starkware.cairo.bootloaders.bootloader.objects import BootloaderInput + bootloader_input = BootloaderInput.Schema().load(program_input) + + ids.simple_bootloader_output_start = segments.add() + + # Change output builtin state to a different segment in preparation for calling the + # simple bootloader. + output_builtin_state = output_builtin.get_state() + output_builtin.new_state(base=ids.simple_bootloader_output_start) + %} + + # Save segment's start. + let simple_bootloader_output_ptr : felt* = simple_bootloader_output_start + + # Call the simple bootloader program to execute direct subtasks. Simple bootloader input is + # contained in the bootloader input. + %{ simple_bootloader_input = bootloader_input %} + run_simple_bootloader{output_ptr=simple_bootloader_output_ptr}() + let simple_bootloader_output_end : felt* = simple_bootloader_output_ptr + + %{ + # Restore the bootloader's output builtin state. + output_builtin.set_state(output_builtin_state) + %} + + # The bootloader config appears at the beginning of the output. + let bootloader_config = cast(output_ptr, BootloaderConfig*) + let output_ptr = output_ptr + BootloaderConfig.SIZE + + %{ + segments.write_arg( + ids.bootloader_config.address_, + [ + bootloader_input.simple_bootloader_program_hash, + bootloader_input.cairo_verifier_program_hash, + ], + ) + %} + + # Increment output_ptr to save place for n_total_tasks. + let output_n_total_tasks = [output_ptr] + let output_ptr = output_ptr + 1 + %{ output_start = ids.output_ptr %} + + let simple_bootloader_output_ptr = simple_bootloader_output_start + + # Skip n_subtasks in the simple bootloader output. + let n_subtasks = [simple_bootloader_output_ptr] + let simple_bootloader_output_ptr = simple_bootloader_output_ptr + 1 + + # Parse outputs recursively and write it to the output builtin. + let n_total_tasks : felt = 0 + %{ packed_outputs = bootloader_input.packed_outputs %} + with simple_bootloader_output_ptr, n_total_tasks: + parse_tasks{subtasks_output=simple_bootloader_output_ptr}( + bootloader_config=bootloader_config, n_subtasks=n_subtasks) + end + + # Assert that parse_tasks used the entire output of the simple bootloader. + let parse_tasks_end = simple_bootloader_output_ptr + assert simple_bootloader_output_end = parse_tasks_end + + # Output the total number of tasks. + assert output_n_total_tasks = n_total_tasks + + %{ + from typing import List + + from starkware.cairo.bootloaders.bootloader.utils import compute_fact_topologies + from starkware.cairo.bootloaders.fact_topology import FactTopology + from starkware.cairo.bootloaders.simple_bootloader.utils import ( + configure_fact_topologies, write_to_fact_topologies_file) + + # Compute the fact topologies of the plain packed outputs based on packed_outputs and + # fact_topologies of the inner tasks. + plain_fact_topologies: List[FactTopology] = compute_fact_topologies( + packed_outputs=packed_outputs, fact_topologies=fact_topologies, + ) + + # Configure the memory pages in the output builtin, based on plain_fact_topologies. + configure_fact_topologies( + fact_topologies=plain_fact_topologies, output_start=output_start, + output_builtin=output_builtin, + ) + + # Dump fact topologies to a json file. + if bootloader_input.fact_topologies_path is not None: + write_to_fact_topologies_file( + fact_topologies_path=bootloader_input.fact_topologies_path, + fact_topologies=plain_fact_topologies, + ) + %} + return () +end + +# Unpacks composite packed outputs recursively and writes each task's plain output to the output +# builtin. +# +# Arguments: +# n_subtasks - Number of direct subtasks to unfold. +# bootloader_config. +# +# Hint arguments: +# packed_outputs - PackedOutput object that stores the task tree structure. +# +# Implicit arguments: +# n_total_tasks - Number of PlainPackedOutput that were unpacked. This function increments this +# value for each unpacked output. +# subtasks_output - Contains direct subtasks outputs which is used for unpacking. This is an input +# to this function and is returned for validation purposes. +func parse_tasks{ + output_ptr : felt*, pedersen_ptr : HashBuiltin*, n_total_tasks : felt, + subtasks_output : felt*}(bootloader_config : BootloaderConfig*, n_subtasks : felt): + if n_subtasks == 0: + return () + end + + alloc_locals + + %{ + from starkware.cairo.bootloaders.bootloader.objects import PackedOutput + + task_id = len(packed_outputs) - ids.n_subtasks + packed_output: PackedOutput = packed_outputs[task_id] + + vm_enter_scope(new_scope_locals=dict(packed_output=packed_output)) + %} + + %{ + from starkware.cairo.bootloaders.bootloader.objects import ( + CompositePackedOutput, PlainPackedOutput) + %} + + if nondet %{ isinstance(packed_output, PlainPackedOutput) %} != 0: + # Handle plain packed task. + unpack_plain_packed_task{task_output=subtasks_output}(bootloader_config=bootloader_config) + else: + # Handle composite packed task. + %{ assert isinstance(packed_output, CompositePackedOutput) %} + unpack_composite_packed_task{task_output=subtasks_output}( + bootloader_config=bootloader_config) + end + + %{ vm_exit_scope() %} + + # Call recursively for handling the other tasks. + return parse_tasks(bootloader_config=bootloader_config, n_subtasks=n_subtasks - 1) +end + +# Parses the task header. +# +# Implicit arguments: +# task_output - A pointer to the output of the plain packed task. Assumes that task_output is of +# the following format: (task_header, output). +func parse_task_header{task_output : felt*}() -> (task_header : TaskOutputHeader*): + let task_header = cast(task_output, TaskOutputHeader*) + let task_output = task_output + TaskOutputHeader.SIZE + return (task_header=task_header) +end + +# Unpacks a composite packed task output. +# +# Arguments: +# bootloader_config. +# +# Implicit arguments: +# task_output - A pointer to the output of the composite packed task. task_output should be of the +# following format: +# (output_len, cairo_verifier_program_hash, simple_bootloader_program_hash, output_hash). +# n_total_tasks - Number of PlainPackedOutput that were unpacked. +# +# Hint arguments: +# packed_output - CompositePackedOutput object which uses for unpacking the task. +func unpack_composite_packed_task{ + output_ptr : felt*, pedersen_ptr : HashBuiltin*, n_total_tasks : felt, task_output : felt*}( + bootloader_config : BootloaderConfig*): + alloc_locals + + # Guess the pre-image of subtasks_output_hash (subtasks_output_hash appears in task_output). + local nested_subtasks_output : felt* + local nested_subtasks_output_len + %{ + data = packed_output.elements_for_hash() + ids.nested_subtasks_output_len = len(data) + ids.nested_subtasks_output = segments.gen_arg(data) + %} + + # Compute the hash of nested_subtasks_output. + let (hash_state_ptr : HashState*) = hash_init() + let (hash_state_ptr) = hash_update{hash_ptr=pedersen_ptr}( + hash_state_ptr=hash_state_ptr, + data_ptr=nested_subtasks_output, + data_length=nested_subtasks_output_len) + let (subtasks_output_hash) = hash_finalize{hash_ptr=pedersen_ptr}(hash_state_ptr=hash_state_ptr) + + # Verify task output header. + let (task_header : TaskOutputHeader*) = parse_task_header() + assert [task_header] = TaskOutputHeader( + size=TaskOutputHeader.SIZE + CairoVerifierOutput.SIZE, + program_hash=bootloader_config.cairo_verifier_program_hash) + + # Verify task output. + assert [cast(task_output, CairoVerifierOutput*)] = CairoVerifierOutput( + program_hash=bootloader_config.simple_bootloader_program_hash, + output_hash=subtasks_output_hash) + let task_output = task_output + CairoVerifierOutput.SIZE + + # Call recursively to parse the composite task's subtasks. + local nested_subtasks_output_start : felt* = nested_subtasks_output + let n_subtasks = [nested_subtasks_output] + let nested_subtasks_output = nested_subtasks_output + 1 + %{ packed_outputs = packed_output.subtasks %} + with nested_subtasks_output: + parse_tasks{subtasks_output=nested_subtasks_output}( + bootloader_config=bootloader_config, n_subtasks=n_subtasks) + end + + # Assert that the entire subtask output was used. + assert nested_subtasks_output = nested_subtasks_output_start + nested_subtasks_output_len + return () +end + +# Unpacks a plain packed task output to the output builtin. +# +# Arguments: +# bootloader_config. +# +# Implicit arguments: +# task_output - A pointer to the output of the plain packed task. Assumes that task_output is of +# the following format: (output_len, cairo_verifier_program_hash, *output). +# n_total_tasks - Number of PlainPackedOutput that were unpacked. This function increments this +# value by 1. +func unpack_plain_packed_task{ + output_ptr : felt*, pedersen_ptr : HashBuiltin*, n_total_tasks : felt, task_output : felt*}( + bootloader_config : BootloaderConfig*): + alloc_locals + + # Parse task output header. + let (task_header : TaskOutputHeader*) = parse_task_header() + + # Copy the simple bootloader output header to the bootloader output. + assert [cast(output_ptr, TaskOutputHeader*)] = [task_header] + + # Increment output pointer. + let output_ptr = output_ptr + TaskOutputHeader.SIZE + + # Copy the program output to the bootloader output. + let output_size = task_header.size - TaskOutputHeader.SIZE + memcpy(dst=output_ptr, src=task_output, len=output_size) + + # Increment pointers. + let output_ptr = output_ptr + output_size + let task_output = task_output + output_size + let n_total_tasks = n_total_tasks + 1 + return () +end diff --git a/src/starkware/cairo/bootloader/compute_fact.py b/src/starkware/cairo/bootloaders/compute_fact.py similarity index 97% rename from src/starkware/cairo/bootloader/compute_fact.py rename to src/starkware/cairo/bootloaders/compute_fact.py index 2942e0c4..06648f2d 100644 --- a/src/starkware/cairo/bootloader/compute_fact.py +++ b/src/starkware/cairo/bootloaders/compute_fact.py @@ -4,7 +4,7 @@ from eth_hash.auto import keccak -from starkware.cairo.bootloader.fact_topology import FactTopology +from starkware.cairo.bootloaders.fact_topology import FactTopology from starkware.python.utils import to_bytes diff --git a/src/starkware/cairo/bootloader/generate_fact.py b/src/starkware/cairo/bootloaders/fact_topology.py similarity index 61% rename from src/starkware/cairo/bootloader/generate_fact.py rename to src/starkware/cairo/bootloaders/fact_topology.py index c7121186..7fcac394 100644 --- a/src/starkware/cairo/bootloader/generate_fact.py +++ b/src/starkware/cairo/bootloaders/fact_topology.py @@ -1,48 +1,35 @@ -from typing import Any, Dict, List, Optional +import dataclasses +import json +from typing import Any, ClassVar, Dict, List, Type -from starkware.cairo.bootloader.compute_fact import generate_program_fact -from starkware.cairo.bootloader.fact_topology import GPS_FACT_TOPOLOGY, FactInfo, FactTopology -from starkware.cairo.bootloader.hash_program import compute_program_hash_chain -from starkware.cairo.lang.vm.cairo_pie import CairoPie -from starkware.cairo.lang.vm.relocatable import MaybeRelocatable, RelocatableValue +import marshmallow +import marshmallow_dataclass +GPS_FACT_TOPOLOGY = "gps_fact_topology" -def get_program_output(cairo_pie: CairoPie) -> List[int]: - """ - Returns the program output. - """ - assert "output" in cairo_pie.metadata.builtin_segments, "The output builtin must be used." - output = cairo_pie.metadata.builtin_segments["output"] - - def verify_int(x: MaybeRelocatable) -> int: - assert isinstance( - x, int - ), f"Expected program output to contain absolute values, found: {x}." - return x - - return [ - verify_int(cairo_pie.memory[RelocatableValue(segment_index=output.index, offset=i)]) - for i in range(output.size) - ] +@dataclasses.dataclass(frozen=True) +class FactTopology: + tree_structure: List[int] + # List of page sizes, in words. + page_sizes: List[int] -def get_cairo_pie_fact_info(cairo_pie: CairoPie, program_hash: Optional[int] = None) -> FactInfo: - """ - Generates the fact of the Cairo program of cairo_pie. Returns the cairo-pie fact info. - """ - program_output = get_program_output(cairo_pie=cairo_pie) - fact_topology = get_fact_topology_from_additional_data( - output_size=len(program_output), - output_builtin_additional_data=cairo_pie.additional_data["output_builtin"], - ) - if program_hash is None: - program_hash = get_program_hash(cairo_pie) - fact = generate_program_fact(program_hash, program_output, fact_topology=fact_topology) - return FactInfo(program_output=program_output, fact_topology=fact_topology, fact=fact) + +@marshmallow_dataclass.dataclass(frozen=True) +class FactTopologiesFile: + fact_topologies: List[FactTopology] + Schema: ClassVar[Type[marshmallow.Schema]] = marshmallow.Schema + + +def load_fact_topologies(path) -> List[FactTopology]: + return FactTopologiesFile.Schema().load(json.load(open(path))).fact_topologies -def get_program_hash(cairo_pie: CairoPie) -> int: - return compute_program_hash_chain(cairo_pie.metadata.program) +@dataclasses.dataclass(frozen=True) +class FactInfo: + program_output: List[int] + fact_topology: FactTopology + fact: str def get_page_sizes_from_page_dict(output_size: int, pages: dict) -> List[int]: diff --git a/src/starkware/cairo/bootloaders/generate_fact.py b/src/starkware/cairo/bootloaders/generate_fact.py new file mode 100644 index 00000000..66c8b15d --- /dev/null +++ b/src/starkware/cairo/bootloaders/generate_fact.py @@ -0,0 +1,48 @@ +from typing import List, Optional + +from starkware.cairo.bootloaders.compute_fact import generate_program_fact +from starkware.cairo.bootloaders.fact_topology import ( + FactInfo, + get_fact_topology_from_additional_data, +) +from starkware.cairo.bootloaders.hash_program import compute_program_hash_chain +from starkware.cairo.lang.vm.cairo_pie import CairoPie +from starkware.cairo.lang.vm.relocatable import MaybeRelocatable, RelocatableValue + + +def get_program_output(cairo_pie: CairoPie) -> List[int]: + """ + Returns the program output. + """ + assert "output" in cairo_pie.metadata.builtin_segments, "The output builtin must be used." + output = cairo_pie.metadata.builtin_segments["output"] + + def verify_int(x: MaybeRelocatable) -> int: + assert isinstance( + x, int + ), f"Expected program output to contain absolute values, found: {x}." + return x + + return [ + verify_int(cairo_pie.memory[RelocatableValue(segment_index=output.index, offset=i)]) + for i in range(output.size) + ] + + +def get_cairo_pie_fact_info(cairo_pie: CairoPie, program_hash: Optional[int] = None) -> FactInfo: + """ + Generates the fact of the Cairo program of cairo_pie. Returns the cairo-pie fact info. + """ + program_output = get_program_output(cairo_pie=cairo_pie) + fact_topology = get_fact_topology_from_additional_data( + output_size=len(program_output), + output_builtin_additional_data=cairo_pie.additional_data["output_builtin"], + ) + if program_hash is None: + program_hash = get_program_hash(cairo_pie) + fact = generate_program_fact(program_hash, program_output, fact_topology=fact_topology) + return FactInfo(program_output=program_output, fact_topology=fact_topology, fact=fact) + + +def get_program_hash(cairo_pie: CairoPie) -> int: + return compute_program_hash_chain(cairo_pie.metadata.program) diff --git a/src/starkware/cairo/bootloader/hash_program.py b/src/starkware/cairo/bootloaders/hash_program.py similarity index 100% rename from src/starkware/cairo/bootloader/hash_program.py rename to src/starkware/cairo/bootloaders/hash_program.py diff --git a/src/starkware/cairo/bootloader/bootloader_test_utils.cmake b/src/starkware/cairo/bootloaders/program_hash_test_utils.cmake similarity index 78% rename from src/starkware/cairo/bootloader/bootloader_test_utils.cmake rename to src/starkware/cairo/bootloaders/program_hash_test_utils.cmake index 69b83034..1293006c 100644 --- a/src/starkware/cairo/bootloader/bootloader_test_utils.cmake +++ b/src/starkware/cairo/bootloaders/program_hash_test_utils.cmake @@ -1,5 +1,5 @@ python_lib(program_hash_test_utils_lib - PREFIX starkware/cairo/bootloader + PREFIX starkware/cairo/bootloaders FILES program_hash_test_utils.py diff --git a/src/starkware/cairo/bootloader/program_hash_test_utils.py b/src/starkware/cairo/bootloaders/program_hash_test_utils.py similarity index 93% rename from src/starkware/cairo/bootloader/program_hash_test_utils.py rename to src/starkware/cairo/bootloaders/program_hash_test_utils.py index 618c0082..80e3de9b 100644 --- a/src/starkware/cairo/bootloader/program_hash_test_utils.py +++ b/src/starkware/cairo/bootloaders/program_hash_test_utils.py @@ -1,6 +1,6 @@ import json -from starkware.cairo.bootloader.hash_program import compute_program_hash_chain +from starkware.cairo.bootloaders.hash_program import compute_program_hash_chain from starkware.cairo.lang.compiler.program import Program diff --git a/src/starkware/cairo/bootloaders/simple_bootloader/CMakeLists.txt b/src/starkware/cairo/bootloaders/simple_bootloader/CMakeLists.txt new file mode 100644 index 00000000..3557d452 --- /dev/null +++ b/src/starkware/cairo/bootloaders/simple_bootloader/CMakeLists.txt @@ -0,0 +1,2 @@ +cairo_compile(simple_bootloader_program + simple_bootloader_compiled.json simple_bootloader.cairo "--debug_info_with_source --proof_mode") diff --git a/src/starkware/cairo/bootloaders/simple_bootloader/execute_task.cairo b/src/starkware/cairo/bootloaders/simple_bootloader/execute_task.cairo new file mode 100644 index 00000000..f0d24f24 --- /dev/null +++ b/src/starkware/cairo/bootloaders/simple_bootloader/execute_task.cairo @@ -0,0 +1,213 @@ +from starkware.cairo.builtin_selection.inner_select_builtins import inner_select_builtins +from starkware.cairo.builtin_selection.select_input_builtins import select_input_builtins +from starkware.cairo.builtin_selection.validate_builtins import validate_builtins +from starkware.cairo.common.cairo_builtins import HashBuiltin +from starkware.cairo.common.hash_chain import hash_chain +from starkware.cairo.common.registers import get_ap, get_fp_and_pc + +const BOOTLOADER_VERSION = 0 + +# Use an empty struct to encode an arbitrary-length array. +struct BuiltinList: +end + +struct ProgramHeader: + # The data length field specifies the length of the data (i.e., program header + program) + # and guarantees unique decoding of the program hash. + member data_length : felt + member bootloader_version : felt + member program_main : felt + member n_builtins : felt + # 'builtin_list' is a continuous memory segment containing the ASCII encoding of the (ordered) + # builtins used by the program. + member builtin_list : BuiltinList +end + +struct BuiltinData: + member output : felt + member pedersen : felt + member range_check : felt + member ecdsa : felt + member bitwise : felt +end + +# Executes a single task. +# The task is passed in the 'task' hint variable. +# Outputs of the task are prefixed by: +# a. Output size (including this prefix) +# b. hash_chain(ProgramHeader || task.program.data) where ProgramHeader is defined below. +# The function returns a pointer to the updated builtin pointers after executing the task. +func execute_task{builtin_ptrs : BuiltinData*, self_range_check_ptr}( + builtin_encodings : BuiltinData*, builtin_instance_sizes : BuiltinData*): + # Allocate memory for local variables. + alloc_locals + + # Get the value of fp. + let (local __fp__, _) = get_fp_and_pc() + + # Pointer to the program data (which starts with ProgramHeader). + local program_data_ptr : felt* + %{ ids.program_data_ptr = program_data_base = segments.add() %} + + # The struct of input builtin pointers pointed by the given builtin_ptrs. + let input_builtin_ptrs : BuiltinData* = builtin_ptrs + local output_ptr = input_builtin_ptrs.output + + let program_header = cast(program_data_ptr, ProgramHeader*) + %{ + from starkware.cairo.bootloaders.simple_bootloader.utils import load_program + + # Call load_program to load the program header and code to memory. + program_address, program_data_size = load_program( + task=task, memory=memory, program_header=ids.program_header, + builtins_offset=ids.ProgramHeader.builtin_list) + segments.finalize(program_data_base.segment_index, program_data_size) + %} + + # Verify that the bootloader version is compatible with the bootloader. + assert program_header.bootloader_version = BOOTLOADER_VERSION + + # Call hash_chain, to verify the program hash. + let pedersen_ptr = cast(input_builtin_ptrs.pedersen, HashBuiltin*) + let (hash) = hash_chain{hash_ptr=pedersen_ptr}(data_ptr=program_data_ptr) + # Write hash_chain result to output_ptr + 1. + assert [output_ptr + 1] = hash + %{ + # Validate hash. + from starkware.cairo.bootloaders.hash_program import compute_program_hash_chain + + assert memory[ids.output_ptr + 1] == compute_program_hash_chain(task.get_program()), \ + 'Computed hash does not match input.' + %} + + # Set the program entry point, so the bootloader can later run the program. + local builtin_list : felt* = &program_header.builtin_list + local n_builtins = program_header.n_builtins + tempvar program_address = builtin_list + n_builtins + %{ + # Sanity check. + assert ids.program_address == program_address + %} + tempvar program_main = program_header.program_main + # The address in memory where the main function of the task is loaded. + local program_entry_point : felt* = program_address + program_main + + # Fill in all builtin pointers which may be used by the task. + # Skip the 2 slots prefix that we add to the task output. + local pre_execution_builtin_ptrs : BuiltinData = BuiltinData( + output=output_ptr + 2, + pedersen=cast(pedersen_ptr, felt), + range_check=input_builtin_ptrs.range_check, + ecdsa=input_builtin_ptrs.ecdsa, + bitwise=input_builtin_ptrs.bitwise) + + # Call select_input_builtins to get the relevant input builtin pointers for the task. + select_input_builtins( + all_encodings=builtin_encodings, + all_ptrs=&pre_execution_builtin_ptrs, + selected_encodings=builtin_list, + n_selected_builtins=n_builtins) + + call_task: + %{ + from starkware.cairo.bootloaders.simple_bootloader.objects import ( + CairoPieTask, RunProgramTask, Task) + from starkware.cairo.bootloaders.simple_bootloader.utils import ( + load_cairo_pie, prepare_output_runner) + + assert isinstance(task, Task) + n_builtins = len(task.get_program().builtins) + new_task_locals = {} + if isinstance(task, RunProgramTask): + new_task_locals['program_input'] = task.program_input + new_task_locals['WITH_BOOTLOADER'] = True + + vm_load_program(task.program, program_address) + elif isinstance(task, CairoPieTask): + ret_pc = ids.ret_pc_label.instruction_offset_ - ids.call_task.instruction_offset_ + pc + load_cairo_pie( + task=task.cairo_pie, memory=memory, segments=segments, + program_address=program_address, execution_segment_address= ap - n_builtins, + ecdsa_builtin=ecdsa_builtin, ret_fp=fp, ret_pc=ret_pc) + else: + raise NotImplementedError(f'Unexpected task type: {type(task).__name__}.') + + output_runner_data = prepare_output_runner( + task=task, + output_builtin=output_builtin, + output_ptr=ids.pre_execution_builtin_ptrs.output) + vm_enter_scope(new_task_locals) + %} + + # Call the inner program's main() function. + call abs program_entry_point + + ret_pc_label: + %{ + vm_exit_scope() + # Note that bootloader_input will only be available in the next hint. + %} + + # Note that used_builtins_addr cannot be set in a hint because doing so will allow a malicious + # prover to lie about the outputs of a valid program. + let (ap_val) = get_ap() + local used_builtins_addr : felt* = cast(ap_val - n_builtins, felt*) + + # Call inner_select_builtins to validate that the values of the builtin pointers for the next + # task are updated according to the task return builtin pointers. + + # Allocate a struct containing all builtin pointers just after the program returns. + local return_builtin_ptrs : BuiltinData + %{ + from starkware.cairo.bootloaders.simple_bootloader.utils import write_return_builtins + + # Fill the values of all builtin pointers after executing the task. + builtins = task.get_program().builtins + write_return_builtins( + memory=memory, return_builtins_addr=ids.return_builtin_ptrs.address_, + used_builtins=builtins, used_builtins_addr=ids.used_builtins_addr, + pre_execution_builtins_addr=ids.pre_execution_builtin_ptrs.address_, task=task) + + vm_enter_scope({'n_selected_builtins': n_builtins}) + %} + let select_builtins_ret = inner_select_builtins( + all_encodings=builtin_encodings, + all_ptrs=&return_builtin_ptrs, + selected_encodings=builtin_list, + selected_ptrs=used_builtins_addr, + n_builtins=BuiltinData.SIZE) + %{ vm_exit_scope() %} + + # Assert that the correct number of builtins was selected. + # Note that builtin_list is a pointer to the list containing the selected encodings. + assert n_builtins = select_builtins_ret.selected_encodings_end - builtin_list + + # Call validate_builtins to validate that the builtin pointers have advanced correctly. + validate_builtins{range_check_ptr=self_range_check_ptr}( + prev_builtin_ptrs=&pre_execution_builtin_ptrs, + new_builtin_ptrs=&return_builtin_ptrs, + builtin_instance_sizes=builtin_instance_sizes, + n_builtins=BuiltinData.SIZE) + + # Verify that [output_ptr] = return_builtin_ptrs.output - output_ptr. + # Output size should be 2 + the number of output slots that were consumed by the task. + local output_size = return_builtin_ptrs.output - output_ptr + assert [output_ptr] = output_size + + %{ + from starkware.cairo.bootloaders.simple_bootloader.utils import get_task_fact_topology + + # Add the fact topology of the current task to 'fact_topologies'. + output_start = ids.pre_execution_builtin_ptrs.output + output_end = ids.return_builtin_ptrs.output + fact_topologies.append(get_task_fact_topology( + output_size=output_end - output_start, + task=task, + output_builtin=output_builtin, + output_runner_data=output_runner_data, + )) + %} + + let builtin_ptrs = &return_builtin_ptrs + return () +end diff --git a/src/starkware/cairo/bootloaders/simple_bootloader/run_simple_bootloader.cairo b/src/starkware/cairo/bootloaders/simple_bootloader/run_simple_bootloader.cairo new file mode 100644 index 00000000..78de40b1 --- /dev/null +++ b/src/starkware/cairo/bootloaders/simple_bootloader/run_simple_bootloader.cairo @@ -0,0 +1,132 @@ +from starkware.cairo.bootloaders.simple_bootloader.execute_task import BuiltinData, execute_task +from starkware.cairo.common.cairo_builtins import HashBuiltin +from starkware.cairo.common.registers import get_fp_and_pc + +# Loads the programs and executes them. +# +# Hint Arguments: +# simple_bootloader_input - contains the tasks to execute. +# +# Returns: +# Updated builtin pointers after executing all programs. +# fact_topologies - that corresponds to the tasks (hint variable). +func run_simple_bootloader{ + output_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr, bitwise_ptr}(): + alloc_locals + local task_range_check_ptr + + %{ + n_tasks = len(simple_bootloader_input.tasks) + memory[ids.output_ptr] = n_tasks + + # Task range checks are located right after simple bootloader validation range checks, and + # this is validated later in this function. + ids.task_range_check_ptr = ids.range_check_ptr + ids.BuiltinData.SIZE * n_tasks + + # A list of fact_tooplogies that instruct how to generate the fact from the program output + # for each task. + fact_topologies = [] + %} + + let n_tasks = [output_ptr] + let output_ptr = output_ptr + 1 + + # A struct containing the pointer to each builtin. + local builtin_ptrs_before : BuiltinData = BuiltinData( + output=cast(output_ptr, felt), + pedersen=cast(pedersen_ptr, felt), + range_check=task_range_check_ptr, + ecdsa=ecdsa_ptr, + bitwise=bitwise_ptr) + + # A struct containing the encoding of each builtin. + local builtin_encodings : BuiltinData = BuiltinData( + output='output', + pedersen='pedersen', + range_check='range_check', + ecdsa='ecdsa', + bitwise='bitwise') + + local builtin_instance_sizes : BuiltinData = BuiltinData( + output=1, pedersen=3, range_check=1, ecdsa=2, bitwise=5) + + # Call execute_tasks. + let (__fp__, _) = get_fp_and_pc() + + %{ tasks = simple_bootloader_input.tasks %} + let builtin_ptrs = &builtin_ptrs_before + let self_range_check_ptr = range_check_ptr + with builtin_ptrs, self_range_check_ptr: + execute_tasks( + builtin_encodings=&builtin_encodings, + builtin_instance_sizes=&builtin_instance_sizes, + n_tasks=n_tasks) + end + + # Verify that the task range checks appear after the self range checks of execute_task. + assert self_range_check_ptr = task_range_check_ptr + + # Return the updated builtin pointers. + local builtin_ptrs : BuiltinData* = builtin_ptrs + let output_ptr = cast(builtin_ptrs.output, felt*) + let pedersen_ptr = cast(builtin_ptrs.pedersen, HashBuiltin*) + let range_check_ptr = builtin_ptrs.range_check + let ecdsa_ptr = builtin_ptrs.ecdsa + let bitwise_ptr = builtin_ptrs.bitwise + + # Verify that range_check has indeed advanced. + let additional_range_checks = range_check_ptr - self_range_check_ptr + verify_non_negative(num=additional_range_checks, n_bits=64) + + return () +end + +# Verifies that a field element is in the range [0, 2^n_bits), without relying on the range_check +# builtin. +func verify_non_negative(num : felt, n_bits : felt): + if n_bits == 0: + assert num = 0 + return () + end + + tempvar num_div2 = nondet %{ ids.num // 2 %} + tempvar bit = num - (num_div2 + num_div2) + # Check that bit is 0 or 1. + assert bit = bit * bit + return verify_non_negative(num=num_div2, n_bits=n_bits - 1) +end + +# Executes the last n_tasks from simple_bootloader_input.tasks. +# +# Arguments: +# builtin_encodings - String encodings of the builtins. +# builtin_instance_sizes - Mapping to builtin sizes. +# n_tasks - The number of tasks to execute. +# +# Implicit arguments: +# builtin_ptrs - Pointer to the builtin pointers before/after executing the tasks. +# self_range_check_ptr - range_check pointer (used for validating the builtins). +# +# Hint arguments: +# tasks - A list of tasks to execute. +func execute_tasks{builtin_ptrs : BuiltinData*, self_range_check_ptr}( + builtin_encodings : BuiltinData*, builtin_instance_sizes : BuiltinData*, n_tasks): + if n_tasks == 0: + return () + end + + %{ + from starkware.cairo.bootloaders.simple_bootloader.objects import Task + + # Pass current task to execute_task. + task_id = len(simple_bootloader_input.tasks) - ids.n_tasks + task = simple_bootloader_input.tasks[task_id].load_task() + %} + # Call execute_task to execute the current task. + execute_task(builtin_encodings=builtin_encodings, builtin_instance_sizes=builtin_instance_sizes) + + return execute_tasks( + builtin_encodings=builtin_encodings, + builtin_instance_sizes=builtin_instance_sizes, + n_tasks=n_tasks - 1) +end diff --git a/src/starkware/cairo/bootloaders/simple_bootloader/simple_bootloader.cairo b/src/starkware/cairo/bootloaders/simple_bootloader/simple_bootloader.cairo new file mode 100644 index 00000000..1cd3771e --- /dev/null +++ b/src/starkware/cairo/bootloaders/simple_bootloader/simple_bootloader.cairo @@ -0,0 +1,39 @@ +%builtins output pedersen range_check ecdsa bitwise + +from starkware.cairo.bootloaders.simple_bootloader.run_simple_bootloader import ( + run_simple_bootloader) +from starkware.cairo.common.cairo_builtins import HashBuiltin +from starkware.cairo.common.registers import get_fp_and_pc + +func main{output_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr, bitwise_ptr}( + ): + %{ + from starkware.cairo.bootloaders.simple_bootloader.objects import SimpleBootloaderInput + simple_bootloader_input = SimpleBootloaderInput.Schema().load(program_input) + %} + + # Execute tasks. + run_simple_bootloader() + + %{ + # Dump fact topologies to a json file. + from starkware.cairo.bootloaders.simple_bootloader.utils import ( + configure_fact_topologies, write_to_fact_topologies_file) + + # The task-related output is prefixed by a single word that contains the number of tasks. + tasks_output_start = output_builtin.base + 1 + + # Configure the memory pages in the output builtin, based on fact_topologies. + configure_fact_topologies( + fact_topologies=fact_topologies, output_start=tasks_output_start, + output_builtin=output_builtin, + ) + + if simple_bootloader_input.fact_topologies_path is not None: + write_to_fact_topologies_file( + fact_topologies_path=simple_bootloader_input.fact_topologies_path, + fact_topologies=fact_topologies, + ) + %} + return () +end diff --git a/src/starkware/cairo/builtin_selection/select_builtins.cairo b/src/starkware/cairo/builtin_selection/select_builtins.cairo index 7e97077c..fbc43a56 100644 --- a/src/starkware/cairo/builtin_selection/select_builtins.cairo +++ b/src/starkware/cairo/builtin_selection/select_builtins.cairo @@ -13,7 +13,7 @@ func select_builtins( n_builtins=n_builtins) %{ vm_exit_scope() %} # Assert that the correct number of builtins was selected. - n_selected_builtins = selected_encodings_end - selected_encodings + assert n_selected_builtins = selected_encodings_end - selected_encodings return () end diff --git a/src/starkware/cairo/builtin_selection/validate_builtins.cairo b/src/starkware/cairo/builtin_selection/validate_builtins.cairo index 93a0a06f..68db45f8 100644 --- a/src/starkware/cairo/builtin_selection/validate_builtins.cairo +++ b/src/starkware/cairo/builtin_selection/validate_builtins.cairo @@ -20,13 +20,16 @@ func validate_builtin{range_check_ptr}( end # Validates that the builtin pointers were advanced correctly. +# # The inputs are: # The previous list of builtin pointers. # The new list of builtin pointers. # The sizes of the builtin instances. # The number of builtins. -# For each builtin the function validates that difference between the new builtin pointer and -# the old builtin pointer is positive integer divisible by the corresponding builtin instance size. +# +# For each builtin the function validates that the difference between the new builtin pointer and +# the old builtin pointer is a nonnegative integer divisible by the corresponding builtin +# instance size. # # The function consumes n_builtins range check instances starting at range_check_ptr and returns the # updated range check pointer. diff --git a/src/starkware/cairo/cairo_verifier/objects.cairo b/src/starkware/cairo/cairo_verifier/objects.cairo new file mode 100644 index 00000000..596d6adb --- /dev/null +++ b/src/starkware/cairo/cairo_verifier/objects.cairo @@ -0,0 +1,7 @@ +struct CairoVerifierOutput: + member program_hash : felt + member output_hash : felt +end + +struct StarkProof: +end diff --git a/src/starkware/cairo/common/CMakeLists.txt b/src/starkware/cairo/common/CMakeLists.txt index 4800add2..09ffe717 100644 --- a/src/starkware/cairo/common/CMakeLists.txt +++ b/src/starkware/cairo/common/CMakeLists.txt @@ -3,8 +3,11 @@ python_lib(cairo_common_lib FILES alloc.cairo bitwise.cairo - cairo_builtins.cairo + bool.cairo + cairo_blake2s/blake2s.cairo cairo_blake2s/blake2s_utils.py + cairo_blake2s/packed_blake2s.cairo + cairo_builtins.cairo cairo_keccak/keccak_utils.py cairo_secp/secp_utils.py cairo_sha256/sha256_utils.py @@ -41,6 +44,7 @@ python_lib(cairo_common_lib squash_dict.cairo structs.py uint256.cairo + usort.cairo ${CAIRO_COMMON_LIB_ADDITIONAL_FILES} LIBS @@ -65,3 +69,20 @@ python_lib(cairo_function_runner_lib cairo_vm_crypto_lib cairo_vm_lib ) + +full_python_test(cairo_common_test + PREFIX starkware/cairo/common + PYTHON python3.7 + TESTED_MODULES starkware/cairo/common + + FILES + cairo_blake2s/blake2s_test.cairo + cairo_blake2s/blake2s_test.py + + LIBS + cairo_common_lib + cairo_constants_lib + cairo_function_runner_lib + pip_pytest + pip_pytest_asyncio +) diff --git a/src/starkware/cairo/common/bool.cairo b/src/starkware/cairo/common/bool.cairo new file mode 100644 index 00000000..76d9a2a2 --- /dev/null +++ b/src/starkware/cairo/common/bool.cairo @@ -0,0 +1,3 @@ +# Represents boolean values in Cairo. +const FALSE = 0 +const TRUE = 1 diff --git a/src/starkware/cairo/common/cairo_blake2s/blake2s.cairo b/src/starkware/cairo/common/cairo_blake2s/blake2s.cairo new file mode 100644 index 00000000..ab3cbf01 --- /dev/null +++ b/src/starkware/cairo/common/cairo_blake2s/blake2s.cairo @@ -0,0 +1,398 @@ +from starkware.cairo.common.alloc import alloc +from starkware.cairo.common.cairo_blake2s.packed_blake2s import N_PACKED_INSTANCES, blake2s_compress +from starkware.cairo.common.cairo_builtins import BitwiseBuiltin +from starkware.cairo.common.math import assert_nn_le, unsigned_div_rem +from starkware.cairo.common.math_cmp import is_le +from starkware.cairo.common.memcpy import memcpy +from starkware.cairo.common.memset import memset +from starkware.cairo.common.pow import pow +from starkware.cairo.common.registers import get_fp_and_pc, get_label_location + +const INPUT_BLOCK_FELTS = 16 +const INPUT_BLOCK_BYTES = 64 +const STATE_SIZE_FELTS = 8 +# Each instance consists of 8 words for the input state, 16 words of message, 2 words for t0 and f0, +# and 8 words for the output state. +const INSTANCE_SIZE = STATE_SIZE_FELTS + INPUT_BLOCK_FELTS + 2 + STATE_SIZE_FELTS + +# Computes blake2s of 'input'. +# To use this function, split the input into words of 32 bits (little endian). +# For example, to compute blake2s('Hello world'), use: +# input = [1819043144, 1870078063, 6581362] +# where: +# 1819043144 == int.from_bytes(b'Hell', 'little') +# 1870078063 == int.from_bytes(b'o wo', 'little') +# 6581362 == int.from_bytes(b'rld', 'little') +# +# output is an array of 8 32-bit words (little endian). +# +# Note: You must call finalize_blake2s() at the end of the program. Otherwise, this function +# is not sound and a malicious prover may return a wrong result. +# Note: the interface of this function may change in the future. +func blake2s{range_check_ptr, blake2s_ptr : felt*}(data : felt*, n_bytes : felt) -> ( + output : felt*): + # Set the initial state to IV (IV[0] is modified). + assert blake2s_ptr[0] = 0x6B08E647 # IV[0] ^ 0x01010020 (config: no key, 32 bytes output). + assert blake2s_ptr[1] = 0xBB67AE85 + assert blake2s_ptr[2] = 0x3C6EF372 + assert blake2s_ptr[3] = 0xA54FF53A + assert blake2s_ptr[4] = 0x510E527F + assert blake2s_ptr[5] = 0x9B05688C + assert blake2s_ptr[6] = 0x1F83D9AB + assert blake2s_ptr[7] = 0x5BE0CD19 + static_assert STATE_SIZE_FELTS == 8 + let blake2s_ptr = blake2s_ptr + STATE_SIZE_FELTS + + let (output) = blake2s_inner(data=data, n_bytes=n_bytes, counter=0) + return (output) +end + +# Inner loop for blake2s. blake2s_ptr points to the middle of an instance: after the initial state, +# before the message. +func blake2s_inner{range_check_ptr, blake2s_ptr : felt*}( + data : felt*, n_bytes : felt, counter : felt) -> (output : felt*): + alloc_locals + let (is_last_block) = is_le(n_bytes, INPUT_BLOCK_BYTES) + if is_last_block != 0: + return blake2s_last_block(data=data, n_bytes=n_bytes, counter=counter) + end + + memcpy(blake2s_ptr, data, INPUT_BLOCK_FELTS) + let blake2s_ptr = blake2s_ptr + INPUT_BLOCK_FELTS + + assert blake2s_ptr[0] = counter + INPUT_BLOCK_BYTES # n_bytes. + assert blake2s_ptr[1] = 0 # Is last byte = False. + let blake2s_ptr = blake2s_ptr + 2 + + # Write output. + let output = blake2s_ptr + %{ + from starkware.cairo.common.cairo_blake2s.blake2s_utils import compute_blake2s_func + compute_blake2s_func(segments=segments, output_ptr=ids.output) + %} + let blake2s_ptr = blake2s_ptr + STATE_SIZE_FELTS + + # Write the current output to the input state for the next instance. + memcpy(blake2s_ptr, output, STATE_SIZE_FELTS) + let blake2s_ptr = blake2s_ptr + STATE_SIZE_FELTS + return blake2s_inner( + data=data + INPUT_BLOCK_FELTS, + n_bytes=n_bytes - INPUT_BLOCK_BYTES, + counter=counter + INPUT_BLOCK_BYTES) +end + +func blake2s_last_block{range_check_ptr, blake2s_ptr : felt*}( + data : felt*, n_bytes : felt, counter : felt) -> (output : felt*): + alloc_locals + let (n_felts, _) = unsigned_div_rem(n_bytes + 3, 4) + memcpy(blake2s_ptr, data, n_felts) + memset(blake2s_ptr + n_felts, 0, INPUT_BLOCK_FELTS - n_felts) + let blake2s_ptr = blake2s_ptr + INPUT_BLOCK_FELTS + + assert blake2s_ptr[0] = counter + n_bytes # n_bytes. + assert blake2s_ptr[1] = 0xffffffff # Is last byte = True. + let blake2s_ptr = blake2s_ptr + 2 + + # Write output. + let output = blake2s_ptr + %{ + from starkware.cairo.common.cairo_blake2s.blake2s_utils import compute_blake2s_func + compute_blake2s_func(segments=segments, output_ptr=ids.output) + %} + let blake2s_ptr = blake2s_ptr + STATE_SIZE_FELTS + + return (output=output) +end + +# Verifies that the results of blake2s() are valid. +func finalize_blake2s{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}( + blake2s_ptr_start : felt*, blake2s_ptr_end : felt*): + alloc_locals + + let (__fp__, _) = get_fp_and_pc() + + let (sigma) = _get_sigma() + + tempvar n = (blake2s_ptr_end - blake2s_ptr_start) / INSTANCE_SIZE + if n == 0: + return () + end + + %{ + # Add dummy pairs of input and output. + from starkware.cairo.common.cairo_blake2s.blake2s_utils import IV, blake2s_compress + + _n_packed_instances = int(ids.N_PACKED_INSTANCES) + assert 0 <= _n_packed_instances < 20 + _blake2s_input_chunk_size_felts = int(ids.INPUT_BLOCK_FELTS) + assert 0 <= _blake2s_input_chunk_size_felts < 100 + + message = [0] * _blake2s_input_chunk_size_felts + modified_iv = [IV[0] ^ 0x01010020] + IV[1:] + output = blake2s_compress( + message=message, + h=modified_iv, + t0=0, + t1=0, + f0=0xffffffff, + f1=0, + ) + padding = (modified_iv + message + [0, 0xffffffff] + output) * (_n_packed_instances - 1) + segments.write_arg(ids.blake2s_ptr_end, padding) + %} + + # Compute the amount of chunks (rounded up). + let (local n_chunks, _) = unsigned_div_rem(n + N_PACKED_INSTANCES - 1, N_PACKED_INSTANCES) + let blake2s_ptr = blake2s_ptr_start + _finalize_blake2s_inner{blake2s_ptr=blake2s_ptr}(n=n_chunks, sigma=sigma) + return () +end + +func _get_sigma() -> (sigma : felt*): + alloc_locals + let (sigma_address) = get_label_location(data) + return (sigma=cast(sigma_address, felt*)) + + data: + dw 0 + dw 1 + dw 2 + dw 3 + dw 4 + dw 5 + dw 6 + dw 7 + dw 8 + dw 9 + dw 10 + dw 11 + dw 12 + dw 13 + dw 14 + dw 15 + dw 14 + dw 10 + dw 4 + dw 8 + dw 9 + dw 15 + dw 13 + dw 6 + dw 1 + dw 12 + dw 0 + dw 2 + dw 11 + dw 7 + dw 5 + dw 3 + dw 11 + dw 8 + dw 12 + dw 0 + dw 5 + dw 2 + dw 15 + dw 13 + dw 10 + dw 14 + dw 3 + dw 6 + dw 7 + dw 1 + dw 9 + dw 4 + dw 7 + dw 9 + dw 3 + dw 1 + dw 13 + dw 12 + dw 11 + dw 14 + dw 2 + dw 6 + dw 5 + dw 10 + dw 4 + dw 0 + dw 15 + dw 8 + dw 9 + dw 0 + dw 5 + dw 7 + dw 2 + dw 4 + dw 10 + dw 15 + dw 14 + dw 1 + dw 11 + dw 12 + dw 6 + dw 8 + dw 3 + dw 13 + dw 2 + dw 12 + dw 6 + dw 10 + dw 0 + dw 11 + dw 8 + dw 3 + dw 4 + dw 13 + dw 7 + dw 5 + dw 15 + dw 14 + dw 1 + dw 9 + dw 12 + dw 5 + dw 1 + dw 15 + dw 14 + dw 13 + dw 4 + dw 10 + dw 0 + dw 7 + dw 6 + dw 3 + dw 9 + dw 2 + dw 8 + dw 11 + dw 13 + dw 11 + dw 7 + dw 14 + dw 12 + dw 1 + dw 3 + dw 9 + dw 5 + dw 0 + dw 15 + dw 4 + dw 8 + dw 6 + dw 2 + dw 10 + dw 6 + dw 15 + dw 14 + dw 9 + dw 11 + dw 3 + dw 0 + dw 8 + dw 12 + dw 2 + dw 13 + dw 7 + dw 1 + dw 4 + dw 10 + dw 5 + dw 10 + dw 2 + dw 8 + dw 4 + dw 7 + dw 6 + dw 1 + dw 5 + dw 15 + dw 11 + dw 9 + dw 14 + dw 3 + dw 12 + dw 13 + dw 0 +end + +# Handles n chunks of N_PACKED_INSTANCES blake2s instances. +func _finalize_blake2s_inner{range_check_ptr, bitwise_ptr : BitwiseBuiltin*, blake2s_ptr : felt*}( + n : felt, sigma : felt*): + if n == 0: + return () + end + + alloc_locals + let blake2s_start = blake2s_ptr + + # Load instance data. + let (local data : felt*) = alloc() + _pack_ints(INSTANCE_SIZE, data) + + let input_state : felt* = data + let message : felt* = input_state + STATE_SIZE_FELTS + let t0_and_f0 : felt* = message + INPUT_BLOCK_FELTS + let output_state : felt* = t0_and_f0 + 2 + + # Run blake2s on N_PACKED_INSTANCES instances. + blake2s_compress( + h=input_state, + message=message, + t0=t0_and_f0[0], + f0=t0_and_f0[1], + sigma=sigma, + output=output_state) + let blake2s_ptr = blake2s_start + INSTANCE_SIZE * N_PACKED_INSTANCES + + return _finalize_blake2s_inner(n=n - 1, sigma=sigma) +end + +# Given N_PACKED_INSTANCES sets of m (32-bit) integers in the blake2s implicit argument, +# where each set starts at offset INSTANCE_SIZE from the previous set, +# computes m packed integers. +# blake2s_ptr is advanced m steps (just after the first set). +func _pack_ints{range_check_ptr, blake2s_ptr : felt*}(m, packed_values : felt*): + static_assert N_PACKED_INSTANCES == 7 + alloc_locals + + local MAX_VALUE = 2 ** 32 - 1 + + tempvar packed_values = packed_values + tempvar blake2s_ptr = blake2s_ptr + tempvar range_check_ptr = range_check_ptr + tempvar m = m + + loop: + tempvar x0 = blake2s_ptr[0 * INSTANCE_SIZE] + assert [range_check_ptr + 0] = x0 + assert [range_check_ptr + 1] = MAX_VALUE - x0 + tempvar x1 = blake2s_ptr[1 * INSTANCE_SIZE] + assert [range_check_ptr + 2] = x1 + assert [range_check_ptr + 3] = MAX_VALUE - x1 + tempvar x2 = blake2s_ptr[2 * INSTANCE_SIZE] + assert [range_check_ptr + 4] = x2 + assert [range_check_ptr + 5] = MAX_VALUE - x2 + tempvar x3 = blake2s_ptr[3 * INSTANCE_SIZE] + assert [range_check_ptr + 6] = x3 + assert [range_check_ptr + 7] = MAX_VALUE - x3 + tempvar x4 = blake2s_ptr[4 * INSTANCE_SIZE] + assert [range_check_ptr + 8] = x4 + assert [range_check_ptr + 9] = MAX_VALUE - x4 + tempvar x5 = blake2s_ptr[5 * INSTANCE_SIZE] + assert [range_check_ptr + 10] = x5 + assert [range_check_ptr + 11] = MAX_VALUE - x5 + tempvar x6 = blake2s_ptr[6 * INSTANCE_SIZE] + assert [range_check_ptr + 12] = x6 + assert [range_check_ptr + 13] = MAX_VALUE - x6 + assert packed_values[0] = x0 + 2 ** 35 * x1 + 2 ** (35 * 2) * x2 + 2 ** (35 * 3) * x3 + + 2 ** (35 * 4) * x4 + 2 ** (35 * 5) * x5 + 2 ** (35 * 6) * x6 + + tempvar packed_values = packed_values + 1 + tempvar blake2s_ptr = blake2s_ptr + 1 + tempvar range_check_ptr = range_check_ptr + 14 + tempvar m = m - 1 + jmp loop if m != 0 + + return () +end diff --git a/src/starkware/cairo/common/cairo_blake2s/blake2s_test.cairo b/src/starkware/cairo/common/cairo_blake2s/blake2s_test.cairo new file mode 100644 index 00000000..e4e5a084 --- /dev/null +++ b/src/starkware/cairo/common/cairo_blake2s/blake2s_test.cairo @@ -0,0 +1,25 @@ +%builtins range_check bitwise + +from starkware.cairo.common.alloc import alloc +from starkware.cairo.common.cairo_blake2s.blake2s import INSTANCE_SIZE, blake2s, finalize_blake2s +from starkware.cairo.common.cairo_builtins import BitwiseBuiltin + +func run_blake2s{range_check_ptr, blake2s_ptr : felt*}(inputs : felt**, lengths : felt*, n : felt): + if n == 0: + return () + end + + blake2s(inputs[0], lengths[0]) + return run_blake2s(inputs + 1, lengths + 1, n - 1) +end + +func run_blake2s_and_finalize{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}( + inputs : felt**, lengths : felt*, n : felt): + alloc_locals + let (local blake2s_ptr_start) = alloc() + let blake2s_ptr = blake2s_ptr_start + + run_blake2s{blake2s_ptr=blake2s_ptr}(inputs, lengths, n) + finalize_blake2s(blake2s_ptr_start=blake2s_ptr_start, blake2s_ptr_end=blake2s_ptr) + return () +end diff --git a/src/starkware/cairo/common/cairo_blake2s/blake2s_test.py b/src/starkware/cairo/common/cairo_blake2s/blake2s_test.py new file mode 100644 index 00000000..a3f06922 --- /dev/null +++ b/src/starkware/cairo/common/cairo_blake2s/blake2s_test.py @@ -0,0 +1,191 @@ +import hashlib +import itertools +import os +import random +from typing import List, Sequence + +import pytest + +from starkware.cairo.common.cairo_blake2s.blake2s_utils import ( + IV, + SIGMA, + blake2s_compress, + blake_round, +) +from starkware.cairo.common.cairo_function_runner import CairoFunctionRunner +from starkware.cairo.lang.builtins.bitwise.instance_def import CELLS_PER_BITWISE +from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME +from starkware.cairo.lang.compiler.cairo_compile import compile_cairo_files +from starkware.python.math_utils import div_ceil, safe_div +from starkware.python.utils import blockify + +CAIRO_FILE = os.path.join(os.path.dirname(__file__), "blake2s_test.cairo") + + +@pytest.fixture(scope="session") +def program(): + return compile_cairo_files([CAIRO_FILE], prime=DEFAULT_PRIME, debug_info=True) + + +def test_blake_round(program): + runner = CairoFunctionRunner(program, layout="all") + + state = [random.randrange(0, 2 ** 32) for i in range(16)] + message = [random.randrange(0, 2 ** 32) for i in range(16)] + sigma = list(range(16)) + random.shuffle(sigma) + runner.run( + "starkware.cairo.common.cairo_blake2s.packed_blake2s.blake_round", + runner.bitwise_builtin.base, + state, + message, + sigma, + use_full_name=True, + ) + bitwise_builtin_end, new_state_ptr = runner.get_return_values(2) + assert bitwise_builtin_end == runner.bitwise_builtin.base + 96 * CELLS_PER_BITWISE + new_state = runner.memory.get_range(new_state_ptr, 16) + expected_new_state = blake_round(state=state, message=message, sigma=sigma) + assert new_state == expected_new_state + + print(f"Number of steps: {runner.vm.current_step}.") + + +def test_compress(program): + N_INSTANCES = 7 + SHIFT = 35 + + runner = CairoFunctionRunner(program, layout="all") + + h = [[random.randrange(0, 2 ** 32) for _ in range(8)] for _ in range(N_INSTANCES)] + message = [[random.randrange(0, 2 ** 32) for _ in range(16)] for _ in range(N_INSTANCES)] + t0 = [random.randrange(0, 2 ** 32) for _ in range(N_INSTANCES)] + f0 = [random.randrange(0, 2 ** 32) for _ in range(N_INSTANCES)] + + def pack_value(values: Sequence[int]) -> int: + assert len(values) == N_INSTANCES + return sum(val * 2 ** (SHIFT * i) for i, val in enumerate(values)) + + def pack_array(lists: Sequence[Sequence[int]]) -> List[int]: + return [pack_value(values) for values in zip(*lists)] + + new_h_ptr = runner.segments.add() + runner.run( + "starkware.cairo.common.cairo_blake2s.packed_blake2s.blake2s_compress", + runner.bitwise_builtin.base, + pack_array(h), + pack_array(message), + pack_value(t0), + pack_value(f0), + list(itertools.chain(*SIGMA)), + new_h_ptr, + use_full_name=True, + ) + (bitwise_builtin_end,) = runner.get_return_values(1) + assert bitwise_builtin_end == runner.bitwise_builtin.base + 978 * CELLS_PER_BITWISE + new_h = [x for x in runner.memory.get_range(new_h_ptr, 8)] + expected_new_h_list = [ + blake2s_compress(h=h[i], message=message[i], t0=t0[i], t1=0, f0=f0[i], f1=0) + for i in range(N_INSTANCES) + ] + assert new_h == pack_array(expected_new_h_list) + + print(f"Number of steps: {runner.vm.current_step}.") + + +@pytest.mark.parametrize("n_bytes", list(range(70)) + [100, 200, 255, 256, 257]) +def test_blake2s(program, n_bytes): + runner = CairoFunctionRunner(program, layout="all") + + value = bytes([random.randrange(256) for i in range(n_bytes)]) + value_words = [int.from_bytes(x, "little") for x in blockify(value, 4)] + blake2s_ptr = runner.segments.add() + runner.run( + "blake2s", + range_check_ptr=runner.range_check_builtin.base, + blake2s_ptr=blake2s_ptr, + data=value_words, + n_bytes=n_bytes, + ) + range_check_builtin_end, blake2s_ptr_end, output = runner.get_return_values(3) + assert range_check_builtin_end.segment_index == runner.range_check_builtin.base.segment_index + + n_instances = max(1, div_ceil(n_bytes, 64)) + INSTANCE_SIZE = program.get_const("INSTANCE_SIZE") + assert blake2s_ptr_end == blake2s_ptr + INSTANCE_SIZE * n_instances + + h = IV[:] + h[0] ^= 0x01010020 + for i in range(n_instances): + message = (value_words[i * 16 : (i + 1) * 16] + [0] * 16)[:16] + t = min((i + 1) * 64, n_bytes) + f = 0xFFFFFFFF if n_bytes <= 64 * (i + 1) else 0 + next_state = blake2s_compress(h=h, message=message, t0=t, t1=0, f0=f, f1=0) + assert runner.memory.get_range(blake2s_ptr + i * INSTANCE_SIZE, INSTANCE_SIZE) == ( + h + message + [t, f] + next_state + ) + h = next_state + + output = "".join(x.to_bytes(4, "little").hex() for x in runner.memory.get_range(output, 8)) + expected_output = hashlib.blake2s(value).hexdigest() + assert expected_output == output + + +@pytest.mark.parametrize("n", [0, 1, 6, 7, 8, 13, 14, 15]) +def test_finalize_blake2s(program, n): + random.seed(0) + runner = CairoFunctionRunner(program, layout="all") + + values = [] + for _ in range(n): + h = [random.randrange(0, 2 ** 32) for _ in range(8)] + message = [random.randrange(0, 2 ** 32) for _ in range(16)] + t0 = random.randrange(0, 2 ** 32) + f0 = random.randrange(0, 2 ** 32) + output = blake2s_compress(h=h, message=message, t0=t0, t1=0, f0=f0, f1=0) + assert len(output) == 8 + values += h + message + [t0, f0] + output + + values_ptr = runner.segments.gen_arg(values) + runner.run( + "finalize_blake2s", + runner.range_check_builtin.base, + runner.bitwise_builtin.base, + blake2s_ptr_start=values_ptr, + blake2s_ptr_end=values_ptr + len(values), + ) + range_check_builtin_end, bitwise_ptr_end = runner.get_return_values(2) + assert range_check_builtin_end.segment_index == runner.range_check_builtin.base.segment_index + n_bitwise = safe_div(bitwise_ptr_end - runner.bitwise_builtin.base, CELLS_PER_BITWISE) + n_packed_instances = div_ceil(n, 7) + assert n_bitwise == n_packed_instances * 978 + print("Steps:", runner.vm.current_step) + print("Estimated trace cells:", 50 * runner.vm.current_step + 300 * n_bitwise) + + +@pytest.mark.parametrize("n", [0, 1, 7]) +def test_run_and_finalize_blake2s(program, n): + random.seed(0) + runner = CairoFunctionRunner(program, layout="all") + + values = [ + bytes([random.randrange(256) for _ in range(random.randrange(257))]) for _ in range(n) + ] + n_expected_instances = sum(max(1, div_ceil(len(x), 64)) for x in values) + n_expected_packed_instances = div_ceil(n_expected_instances, 7) + + runner.run( + "run_blake2s_and_finalize", + runner.range_check_builtin.base, + runner.bitwise_builtin.base, + [[int.from_bytes(x, "little") for x in blockify(value, 4)] for value in values], + list(map(len, values)), + n, + ) + range_check_builtin_end, bitwise_ptr_end = runner.get_return_values(2) + assert range_check_builtin_end.segment_index == runner.range_check_builtin.base.segment_index + n_bitwise = safe_div(bitwise_ptr_end - runner.bitwise_builtin.base, CELLS_PER_BITWISE) + assert n_bitwise == n_expected_packed_instances * 978 + + print("Steps:", runner.vm.current_step) + print("Estimated trace cells:", 50 * runner.vm.current_step + 300 * n_bitwise) diff --git a/src/starkware/cairo/common/cairo_blake2s/blake2s_utils.py b/src/starkware/cairo/common/cairo_blake2s/blake2s_utils.py index 8c4b0507..e166ed7a 100644 --- a/src/starkware/cairo/common/cairo_blake2s/blake2s_utils.py +++ b/src/starkware/cairo/common/cairo_blake2s/blake2s_utils.py @@ -1,5 +1,8 @@ from typing import List, Tuple +from starkware.cairo.lang.vm.memory_segments import MemorySegmentManager +from starkware.cairo.lang.vm.relocatable import RelocatableValue + IV = [ 0x6A09E667, 0xBB67AE85, @@ -29,6 +32,27 @@ def right_rot(value, n): return (value >> n) | ((value & (2 ** n - 1)) << (32 - n)) +# Helper function for the Cairo blake2s() implementation. +# Computes the blake2s compress function and fills the value in the right position. +# output_ptr should point to the middle of an instance, right after initial_state, message, t, f, +# which should all have a value at this point, and right before the output portion which will be +# written by this function. +def compute_blake2s_func(segments: MemorySegmentManager, output_ptr: RelocatableValue): + h = segments.memory.get_range(output_ptr - 26, 8) + message = segments.memory.get_range(output_ptr - 18, 16) + t = segments.memory[output_ptr - 2] + f = segments.memory[output_ptr - 1] + new_state = blake2s_compress( + message=message, + h=h, + t0=t, + t1=0, + f0=f, + f1=0, + ) + segments.write_arg(output_ptr, new_state) + + def blake2s_compress( h: List[int], message: List[int], t0: int, t1: int, f0: int, f1: int ) -> List[int]: diff --git a/src/starkware/cairo/common/cairo_blake2s/packed_blake2s.cairo b/src/starkware/cairo/common/cairo_blake2s/packed_blake2s.cairo new file mode 100644 index 00000000..41369610 --- /dev/null +++ b/src/starkware/cairo/common/cairo_blake2s/packed_blake2s.cairo @@ -0,0 +1,212 @@ +from starkware.cairo.common.alloc import alloc +from starkware.cairo.common.cairo_builtins import BitwiseBuiltin +from starkware.cairo.common.registers import get_fp_and_pc + +const N_PACKED_INSTANCES = 7 +const ALL_ONES = 2 ** 251 - 1 +const SHIFTS = 1 + 2 ** 35 + 2 ** (35 * 2) + 2 ** (35 * 3) + 2 ** (35 * 4) + 2 ** (35 * 5) + + 2 ** (35 * 6) + +func mix{bitwise_ptr : BitwiseBuiltin*}( + a : felt, b : felt, c : felt, d : felt, m0 : felt, m1 : felt) -> ( + a : felt, b : felt, c : felt, d : felt): + alloc_locals + + # Defining the following constant as local variables saves some instructions. + local mask32ones = SHIFTS * (2 ** 32 - 1) + + # a = (a + b + m0) % 2**32 + assert bitwise_ptr[0].x = a + b + m0 + assert bitwise_ptr[0].y = mask32ones + tempvar a = bitwise_ptr[0].x_and_y + let bitwise_ptr = bitwise_ptr + BitwiseBuiltin.SIZE + + # d = right_rot((d ^ a), 16) + assert bitwise_ptr[0].x = a + assert bitwise_ptr[0].y = d + tempvar a_xor_d = bitwise_ptr[0].x_xor_y + assert bitwise_ptr[1].x = a_xor_d + assert bitwise_ptr[1].y = SHIFTS * (2 ** 32 - 2 ** 16) + tempvar d = (2 ** (32 - 16)) * a_xor_d + (1 / 2 ** 16 - 2 ** (32 - 16)) * bitwise_ptr[1].x_and_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + # c = (c + d) % 2**32 + assert bitwise_ptr[0].x = c + d + assert bitwise_ptr[0].y = mask32ones + tempvar c = bitwise_ptr[0].x_and_y + let bitwise_ptr = bitwise_ptr + BitwiseBuiltin.SIZE + + # b = right_rot((b ^ c), 12) + assert bitwise_ptr[0].x = b + assert bitwise_ptr[0].y = c + tempvar b_xor_c = bitwise_ptr[0].x_xor_y + assert bitwise_ptr[1].x = b_xor_c + assert bitwise_ptr[1].y = SHIFTS * (2 ** 32 - 2 ** 12) + tempvar b = (2 ** (32 - 12)) * b_xor_c + (1 / 2 ** 12 - 2 ** (32 - 12)) * bitwise_ptr[1].x_and_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + # a = (a + b + m1) % 2**32 + assert bitwise_ptr[0].x = a + b + m1 + assert bitwise_ptr[0].y = mask32ones + tempvar a = bitwise_ptr[0].x_and_y + let bitwise_ptr = bitwise_ptr + BitwiseBuiltin.SIZE + + # d = right_rot((d ^ a), 8) + assert bitwise_ptr[0].x = d + assert bitwise_ptr[0].y = a + tempvar d_xor_a = bitwise_ptr[0].x_xor_y + assert bitwise_ptr[1].x = d_xor_a + assert bitwise_ptr[1].y = SHIFTS * (2 ** 32 - 2 ** 8) + tempvar d = (2 ** (32 - 8)) * d_xor_a + (1 / 2 ** 8 - 2 ** (32 - 8)) * bitwise_ptr[1].x_and_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + # c = (c + d) % 2**32 + assert bitwise_ptr[0].x = c + d + assert bitwise_ptr[0].y = mask32ones + tempvar c = bitwise_ptr[0].x_and_y + let bitwise_ptr = bitwise_ptr + BitwiseBuiltin.SIZE + + # b = right_rot((b ^ c), 7) + assert bitwise_ptr[0].x = b + assert bitwise_ptr[0].y = c + tempvar b_xor_c = bitwise_ptr[0].x_xor_y + assert bitwise_ptr[1].x = b_xor_c + assert bitwise_ptr[1].y = SHIFTS * (2 ** 32 - 2 ** 7) + tempvar b = (2 ** (32 - 7)) * b_xor_c + (1 / 2 ** 7 - 2 ** (32 - 7)) * bitwise_ptr[1].x_and_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + return (a, b, c, d) +end + +func blake_round{bitwise_ptr : BitwiseBuiltin*}(state : felt*, message : felt*, sigma : felt*) -> ( + new_state : felt*): + let state0 = state[0] + let state1 = state[1] + let state2 = state[2] + let state3 = state[3] + let state4 = state[4] + let state5 = state[5] + let state6 = state[6] + let state7 = state[7] + let state8 = state[8] + let state9 = state[9] + let state10 = state[10] + let state11 = state[11] + let state12 = state[12] + let state13 = state[13] + let state14 = state[14] + let state15 = state[15] + + let (state0, state4, state8, state12) = mix( + state0, state4, state8, state12, message[sigma[0]], message[sigma[1]]) + let (state1, state5, state9, state13) = mix( + state1, state5, state9, state13, message[sigma[2]], message[sigma[3]]) + let (state2, state6, state10, state14) = mix( + state2, state6, state10, state14, message[sigma[4]], message[sigma[5]]) + let (state3, state7, state11, state15) = mix( + state3, state7, state11, state15, message[sigma[6]], message[sigma[7]]) + + let (state0, state5, state10, state15) = mix( + state0, state5, state10, state15, message[sigma[8]], message[sigma[9]]) + let (state1, state6, state11, state12) = mix( + state1, state6, state11, state12, message[sigma[10]], message[sigma[11]]) + let (state2, state7, state8, state13) = mix( + state2, state7, state8, state13, message[sigma[12]], message[sigma[13]]) + let (state3, state4, state9, state14) = mix( + state3, state4, state9, state14, message[sigma[14]], message[sigma[15]]) + + let (new_state : felt*) = alloc() + assert new_state[0] = state0 + assert new_state[1] = state1 + assert new_state[2] = state2 + assert new_state[3] = state3 + assert new_state[4] = state4 + assert new_state[5] = state5 + assert new_state[6] = state6 + assert new_state[7] = state7 + assert new_state[8] = state8 + assert new_state[9] = state9 + assert new_state[10] = state10 + assert new_state[11] = state11 + assert new_state[12] = state12 + assert new_state[13] = state13 + assert new_state[14] = state14 + assert new_state[15] = state15 + + return (new_state) +end + +# Performs the blake compression function. +# +# h is a list of 8 32-bit words. +# message is a list of 16 32-bit words. +# t1 and f1 are assumed to be 0. +func blake2s_compress{bitwise_ptr : BitwiseBuiltin*}( + h : felt*, message : felt*, t0 : felt, f0 : felt, sigma : felt*, output : felt*): + alloc_locals + let (__fp__, _) = get_fp_and_pc() + + # Compute state[12]. + assert bitwise_ptr[0].x = 0x510e527f * SHIFTS + assert bitwise_ptr[0].y = t0 + let state12 = bitwise_ptr[0].x_xor_y + let bitwise_ptr = bitwise_ptr + BitwiseBuiltin.SIZE + + # Compute state[14]. + assert bitwise_ptr[0].x = 0x1f83d9ab * SHIFTS + assert bitwise_ptr[0].y = f0 + let state14 = bitwise_ptr[0].x_xor_y + let bitwise_ptr = bitwise_ptr + BitwiseBuiltin.SIZE + + local initial_state = h[0] + local initial_state_ = h[1] + local initial_state_ = h[2] + local initial_state_ = h[3] + local initial_state_ = h[4] + local initial_state_ = h[5] + local initial_state_ = h[6] + local initial_state_ = h[7] + local initial_state_ = 0x6a09e667 * SHIFTS + local initial_state_ = 0xbb67ae85 * SHIFTS + local initial_state_ = 0x3c6ef372 * SHIFTS + local initial_state_ = 0xa54ff53a * SHIFTS + local initial_state_ = state12 + local initial_state_ = 0x9b05688c * SHIFTS + local initial_state_ = state14 + local initial_state_ = 0x5be0cd19 * SHIFTS + + let state = &initial_state + + let (state) = blake_round(state, message, sigma + 16 * 0) + let (state) = blake_round(state, message, sigma + 16 * 1) + let (state) = blake_round(state, message, sigma + 16 * 2) + let (state) = blake_round(state, message, sigma + 16 * 3) + let (state) = blake_round(state, message, sigma + 16 * 4) + let (state) = blake_round(state, message, sigma + 16 * 5) + let (state) = blake_round(state, message, sigma + 16 * 6) + let (state) = blake_round(state, message, sigma + 16 * 7) + let (state) = blake_round(state, message, sigma + 16 * 8) + let (state) = blake_round(state, message, sigma + 16 * 9) + + tempvar old_h = h + tempvar last_state = state + tempvar new_h = output + tempvar bitwise_ptr = bitwise_ptr + tempvar n = 8 + + loop: + assert bitwise_ptr[0].x = old_h[0] + assert bitwise_ptr[0].y = last_state[0] + assert bitwise_ptr[1].x = bitwise_ptr[0].x_xor_y + assert bitwise_ptr[1].y = last_state[8] + assert new_h[0] = bitwise_ptr[1].x_xor_y + + tempvar old_h = old_h + 1 + tempvar last_state = last_state + 1 + tempvar new_h = new_h + 1 + tempvar bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + tempvar n = n - 1 + jmp loop if n != 0 + + return () +end diff --git a/src/starkware/cairo/common/cairo_function_runner.py b/src/starkware/cairo/common/cairo_function_runner.py index 42ca7f00..6b4e9f8e 100644 --- a/src/starkware/cairo/common/cairo_function_runner.py +++ b/src/starkware/cairo/common/cairo_function_runner.py @@ -32,7 +32,7 @@ def __init__(self, *args, **kwargs): ) self.builtin_runners["pedersen_builtin"] = pedersen_builtin range_check_builtin = RangeCheckBuiltinRunner( - included=True, ratio=None, inner_rc_bound=2 ** 16, n_parts=8 + included=True, ratio=1, inner_rc_bound=2 ** 16, n_parts=8 ) self.builtin_runners["range_check_builtin"] = range_check_builtin output_builtin = OutputBuiltinRunner(included=True) @@ -40,19 +40,19 @@ def __init__(self, *args, **kwargs): signature_builtin = SignatureBuiltinRunner( name="ecdsa", included=True, - ratio=None, + ratio=1, process_signature=process_ecdsa, verify_signature=verify_ecdsa_sig, ) self.builtin_runners["ecdsa_builtin"] = signature_builtin bitwise_builtin = BitwiseBuiltinRunner( - included=True, bitwise_builtin=BitwiseInstanceDef(ratio=None, total_n_bits=251) + included=True, bitwise_builtin=BitwiseInstanceDef(ratio=1, total_n_bits=251) ) self.builtin_runners["bitwise_builtin"] = bitwise_builtin ec_op_builtin = EcOpBuiltinRunner( included=True, ec_op_builtin=EcOpInstanceDef( - ratio=None, + ratio=1, scalar_height=256, scalar_bits=252, scalar_limit=None, @@ -140,7 +140,8 @@ def run( try: self.run_from_entrypoint( entrypoint, - *all_args, + all_args, + typed_args=True, hint_locals=hint_locals, static_locals=static_locals, verify_secure=verify_secure, @@ -161,6 +162,7 @@ def run_from_entrypoint( self, entrypoint: Union[str, int], *args, + typed_args: Optional[bool] = False, hint_locals: Optional[Dict[str, Any]] = None, static_locals: Optional[Dict[str, Any]] = None, run_resources: Optional[RunResources] = None, @@ -171,6 +173,8 @@ def run_from_entrypoint( Runs the program from the given entrypoint. Additional params: + typed_args - If true, the arguments are given as Cairo typed NamedTuple generated + with CairoStructFactory. verify_secure - Run verify_secure_runner to do extra verifications. apply_modulo_to_args - Apply modulo operation on integer arguments. """ @@ -183,7 +187,13 @@ def run_from_entrypoint( if apply_modulo_to_args is None: apply_modulo_to_args = True - real_args = [self.gen_arg(arg=x, apply_modulo_to_args=apply_modulo_to_args) for x in args] + if typed_args: + assert len(args) == 1, "len(args) must be 1 when using typed args." + real_args = self.segments.gen_typed_args(args=args[0]) + else: + real_args = [ + self.gen_arg(arg=x, apply_modulo_to_args=apply_modulo_to_args) for x in args + ] end = self.initialize_function_entrypoint(entrypoint=entrypoint, args=real_args) self.initialize_vm(hint_locals=hint_locals, static_locals=static_locals) diff --git a/src/starkware/cairo/common/cairo_secp/secp_utils.py b/src/starkware/cairo/common/cairo_secp/secp_utils.py index 40dd67ca..8ffe1af2 100644 --- a/src/starkware/cairo/common/cairo_secp/secp_utils.py +++ b/src/starkware/cairo/common/cairo_secp/secp_utils.py @@ -2,6 +2,7 @@ from starkware.cairo.common.math_utils import as_int +BASE = 2 ** 86 SECP_P = 2 ** 256 - 2 ** 32 - 2 ** 9 - 2 ** 8 - 2 ** 7 - 2 ** 6 - 2 ** 4 - 1 @@ -11,7 +12,6 @@ def split(num: int) -> List[int]: d0 + BASE * d1 + BASE**2 * d2, where BASE = 2**86. """ - BASE = 2 ** 86 a = [] for _ in range(3): num, residue = divmod(num, BASE) @@ -22,10 +22,10 @@ def split(num: int) -> List[int]: def pack(z, prime): """ - Takes a BigInt3 struct which represents a triple of limbs (d0, d1, d2) of field elements are - reconstruct the 256-bit integer (see split()). + Takes an UnreducedBigInt3 struct which represents a triple of limbs (d0, d1, d2) of field + elements and reconstructs the corresponding 256-bit integer (see split()). Note that the limbs do not have to be in the range [0, BASE). prime should be the Cairo field, and it is used to handle negative values of the limbs. """ limbs = z.d0, z.d1, z.d2 - return sum(as_int(limb, prime) * 2 ** (86 * i) for i, limb in enumerate(limbs)) + return sum(as_int(limb, prime) * (BASE ** i) for i, limb in enumerate(limbs)) diff --git a/src/starkware/cairo/common/dict.cairo b/src/starkware/cairo/common/dict.cairo index 679e3668..194aec83 100644 --- a/src/starkware/cairo/common/dict.cairo +++ b/src/starkware/cairo/common/dict.cairo @@ -2,6 +2,9 @@ from starkware.cairo.common.dict_access import DictAccess from starkware.cairo.common.squash_dict import squash_dict # Creates a new dict. +# +# Hint argument: +# initial_dict - A python dict containing the initial values of the new dict. func dict_new() -> (res : DictAccess*): %{ if '__dict_manager' not in globals(): @@ -24,9 +27,9 @@ func dict_read{dict_ptr : DictAccess*}(key : felt) -> (value : felt): dict_tracker.current_ptr += ids.DictAccess.SIZE ids.value = dict_tracker.data[ids.key] %} - assert dict_ptr.key = key - assert dict_ptr.prev_value = value - assert dict_ptr.new_value = value + dict_ptr.key = key + dict_ptr.prev_value = value + dict_ptr.new_value = value let dict_ptr = dict_ptr + DictAccess.SIZE return (value=value) end @@ -39,8 +42,8 @@ func dict_write{dict_ptr : DictAccess*}(key : felt, new_value : felt): ids.dict_ptr.prev_value = dict_tracker.data[ids.key] dict_tracker.data[ids.key] = ids.new_value %} - assert dict_ptr.key = key - assert dict_ptr.new_value = new_value + dict_ptr.key = key + dict_ptr.new_value = new_value let dict_ptr = dict_ptr + DictAccess.SIZE return () end diff --git a/src/starkware/cairo/common/dict_access.cairo b/src/starkware/cairo/common/dict_access.cairo index 92d41780..b1625804 100644 --- a/src/starkware/cairo/common/dict_access.cairo +++ b/src/starkware/cairo/common/dict_access.cairo @@ -1,3 +1,8 @@ +# Represents an access (read, write or modify) to the dictionary. The dictionary is represented as a +# chronological list of such accesses. The "current" value of a key is the new_value of the last +# access with that key. +# In a valid dictionary, the prev_value of each access is equal to the new_value of the previous +# access to the same key. struct DictAccess: member key : felt member prev_value : felt diff --git a/src/starkware/cairo/common/find_element.cairo b/src/starkware/cairo/common/find_element.cairo index c00c7a63..5d9d2a1b 100644 --- a/src/starkware/cairo/common/find_element.cairo +++ b/src/starkware/cairo/common/find_element.cairo @@ -22,7 +22,7 @@ const FIND_ELEMENT_RANGE_CHECK_USAGE = 2 # # Optional hint variables: # __find_element_index - the index that should be returned. If not specified, the function will -# search for it. +# return the first index that has the key. func find_element{range_check_ptr}(array_ptr : felt*, elm_size, n_elms, key) -> (elm_ptr : felt*): alloc_locals local index diff --git a/src/starkware/cairo/common/hash_chain.cairo b/src/starkware/cairo/common/hash_chain.cairo index 9078666c..b2901a1f 100644 --- a/src/starkware/cairo/common/hash_chain.cairo +++ b/src/starkware/cairo/common/hash_chain.cairo @@ -12,37 +12,34 @@ func hash_chain{hash_ptr : HashBuiltin*}(data_ptr : felt*) -> (hash : felt): member cur_hash : felt end - let data_length = ap - [data_length] = [data_ptr]; ap++ - let loop_frame = cast(ap, LoopLocals*) - + tempvar data_length = [data_ptr] + tempvar data_ptr_end = data_ptr + data_length # Prepare the loop_frame for the first iteration of the hash_loop. - loop_frame.data_ptr = data_ptr + [data_length]; ap++ - loop_frame.hash_ptr = hash_ptr; ap++ - loop_frame.cur_hash = [loop_frame.data_ptr]; ap++ + tempvar loop_frame = LoopLocals( + data_ptr=data_ptr_end, + hash_ptr=hash_ptr, + cur_hash=[data_ptr_end]) hash_loop: let curr_frame = cast(ap - LoopLocals.SIZE, LoopLocals*) let current_hash : HashBuiltin* = curr_frame.hash_ptr - let new_data_ptr = curr_frame.data_ptr - 1 - let new_data = ap - [new_data] = [new_data_ptr]; ap++ + tempvar new_data = [curr_frame.data_ptr - 1] - let n_elements_to_hash = ap + let n_elements_to_hash = [ap] # Assign current_hash inputs and allocate space for n_elements_to_hash. - [new_data] = current_hash.x; ap++ - curr_frame.cur_hash = current_hash.y + current_hash.x = new_data; ap++ + current_hash.y = curr_frame.cur_hash # Set the frame for the next loop iteration (going backwards). - let next_frame = cast(ap, LoopLocals*) - next_frame.data_ptr = new_data_ptr; ap++ - next_frame.hash_ptr = curr_frame.hash_ptr + HashBuiltin.SIZE; ap++ - next_frame.cur_hash = current_hash.result; ap++ + tempvar next_frame = LoopLocals( + data_ptr=curr_frame.data_ptr - 1, + hash_ptr=curr_frame.hash_ptr + HashBuiltin.SIZE, + cur_hash=current_hash.result) # Update n_elements_to_hash and loop accordingly. Note that the hash is calculated backwards. - [n_elements_to_hash] = next_frame.data_ptr - data_ptr - jmp hash_loop if [n_elements_to_hash] != 0 + n_elements_to_hash = next_frame.data_ptr - data_ptr + jmp hash_loop if n_elements_to_hash != 0 # Set the hash_ptr implicit argument and return the result. let hash_ptr = next_frame.hash_ptr diff --git a/src/starkware/cairo/common/hash_state.cairo b/src/starkware/cairo/common/hash_state.cairo index 3500fde8..8c1d0f13 100644 --- a/src/starkware/cairo/common/hash_state.cairo +++ b/src/starkware/cairo/common/hash_state.cairo @@ -12,7 +12,7 @@ struct HashState: member n_words : felt end -# Initializes a new HashState with no items. +# Initializes a new HashState with no items and returns it. func hash_init() -> (hash_state_ptr : HashState*): alloc_locals let (__fp__, _) = get_fp_and_pc() @@ -22,14 +22,50 @@ func hash_init() -> (hash_state_ptr : HashState*): return (hash_state_ptr=&hash_state) end +# Adds each item in an array of items to the HashState. +# Returns a new HashState with the hash of the items of the input HashState and the array of items. +# The array is represented by a pointer and a length. +func hash_update{hash_ptr : HashBuiltin*}( + hash_state_ptr : HashState*, data_ptr : felt*, data_length) -> ( + new_hash_state_ptr : HashState*): + alloc_locals + let (hash) = hash_update_inner( + data_ptr=data_ptr, data_length=data_length, hash=hash_state_ptr.current_hash) + let (__fp__, _) = get_fp_and_pc() + local new_hash_state : HashState + new_hash_state.current_hash = hash + assert new_hash_state.n_words = hash_state_ptr.n_words + data_length + return (new_hash_state_ptr=&new_hash_state) +end + +# Adds a single item to the HashState. +# Returns a new HashState with the hash of the items of the input HashState and the item. +func hash_update_single{hash_ptr : HashBuiltin*}(hash_state_ptr : HashState*, item) -> ( + new_hash_state_ptr : HashState*): + alloc_locals + let (hash) = hash2(x=hash_state_ptr.current_hash, y=item) + let (__fp__, _) = get_fp_and_pc() + local new_hash_state : HashState + new_hash_state.current_hash = hash + assert new_hash_state.n_words = hash_state_ptr.n_words + 1 + return (new_hash_state_ptr=&new_hash_state) +end + +# Returns the hash result of the HashState. +func hash_finalize{hash_ptr : HashBuiltin*}(hash_state_ptr : HashState*) -> (hash): + return hash2(x=hash_state_ptr.current_hash, y=hash_state_ptr.n_words) +end + # A helper function for 'hash_update', see its documentation. # Computes the hash of an array of items, not including its length. +# The hash is: hash(...hash(hash(data[0], data[1]), data[2])..., data[n-1]). func hash_update_inner{hash_ptr : HashBuiltin*}( data_ptr : felt*, data_length : felt, hash : felt) -> (hash : felt): if data_length == 0: return (hash=hash) end + # Compute 'data_last_ptr' before entering the loop. alloc_locals local data_last_ptr : felt* = data_ptr + data_length - 1 struct LoopLocals: @@ -67,35 +103,3 @@ func hash_update_inner{hash_ptr : HashBuiltin*}( let hash_ptr = final_locals.hash_ptr return (hash=final_locals.cur_hash) end - -# Adds each item in an array of items to the HashState. -# The array is represented by a pointer and a length. -func hash_update{hash_ptr : HashBuiltin*}( - hash_state_ptr : HashState*, data_ptr : felt*, data_length) -> ( - new_hash_state_ptr : HashState*): - alloc_locals - let (hash) = hash_update_inner( - data_ptr=data_ptr, data_length=data_length, hash=hash_state_ptr.current_hash) - let (__fp__, _) = get_fp_and_pc() - local new_hash_state : HashState - new_hash_state.current_hash = hash - assert new_hash_state.n_words = hash_state_ptr.n_words + data_length - return (new_hash_state_ptr=&new_hash_state) -end - -# Adds a single item to the HashState. -func hash_update_single{hash_ptr : HashBuiltin*}(hash_state_ptr : HashState*, item) -> ( - new_hash_state_ptr : HashState*): - alloc_locals - let (hash) = hash2(x=hash_state_ptr.current_hash, y=item) - let (__fp__, _) = get_fp_and_pc() - local new_hash_state : HashState - new_hash_state.current_hash = hash - assert new_hash_state.n_words = hash_state_ptr.n_words + 1 - return (new_hash_state_ptr=&new_hash_state) -end - -# Returns the hash result of the HashState. -func hash_finalize{hash_ptr : HashBuiltin*}(hash_state_ptr : HashState*) -> (hash): - return hash2(x=hash_state_ptr.current_hash, y=hash_state_ptr.n_words) -end diff --git a/src/starkware/cairo/common/math.cairo b/src/starkware/cairo/common/math.cairo index 2f851324..9fcab998 100644 --- a/src/starkware/cairo/common/math.cairo +++ b/src/starkware/cairo/common/math.cairo @@ -61,7 +61,10 @@ end # Verifies that 0 <= a <= b. # -# Prover assumption: a, b < RANGE_CHECK_BOUND. +# Prover assumption: b < RANGE_CHECK_BOUND. +# +# This function is still sound without the prover assumptions. In that case, it is guaranteed +# that a < RANGE_CHECK_BOUND and b < 2 * RANGE_CHECK_BOUND. func assert_nn_le{range_check_ptr}(a, b): assert_nn(a) assert_le(a, b) @@ -69,6 +72,9 @@ func assert_nn_le{range_check_ptr}(a, b): end # Asserts that value is in the range [lower, upper). +# Or more precisely: +# (0 <= value - lower < RANGE_CHECK_BOUND) and (0 <= upper - 1 - value < RANGE_CHECK_BOUND). +# # Prover assumption: 0 <= upper - lower <= RANGE_CHECK_BOUND. func assert_in_range{range_check_ptr}(value, lower, upper): assert_le(lower, value) @@ -135,7 +141,7 @@ func split_felt{range_check_ptr}(value) -> (high, low): if high == MAX_HIGH: assert_le(low, MAX_LOW) else: - assert_le(high, MAX_HIGH) + assert_le(high, MAX_HIGH - 1) end return (high=high, low=low) end @@ -269,7 +275,7 @@ func unsigned_div_rem{range_check_ptr}(value, div) -> (q, r): return (q, r) end -# Returns q and r such that. -bound <= q < bound, 0 <= r < div -1 and value = q * div + r. +# Returns q and r such that. -bound <= q < bound, 0 <= r < div and value = q * div + r. # value < PRIME / 2 is considered positive and value > PRIME / 2 is considered negative. # # Assumptions: @@ -357,3 +363,14 @@ func sqrt{range_check_ptr}(value) -> (res): return (res=root) end + +# Computes the evaluation of a polynomial on the given point. +func horner_eval(n_coefficients : felt, coefficients : felt*, point : felt) -> (res : felt): + if n_coefficients == 0: + return (res=0) + end + + let (n_minus_one_res) = horner_eval( + n_coefficients=n_coefficients - 1, coefficients=&coefficients[1], point=point) + return (res=n_minus_one_res * point + [coefficients]) +end diff --git a/src/starkware/cairo/common/math_cmp.cairo b/src/starkware/cairo/common/math_cmp.cairo index 8ecfcb58..60974dc3 100644 --- a/src/starkware/cairo/common/math_cmp.cairo +++ b/src/starkware/cairo/common/math_cmp.cairo @@ -38,8 +38,10 @@ func is_le{range_check_ptr}(a, b) -> (res): return is_nn(b - a) end -# Returns 1 of 0 <= a <= b < RANGE_CHECK_BOUND. +# Returns 1 if 0 <= a <= b < RANGE_CHECK_BOUND. # Returns 0 otherwise. +# +# Assumption: b < RANGE_CHECK_BOUND. func is_nn_le{range_check_ptr}(a, b) -> (res): let (res) = is_nn(a) if res == 0: @@ -51,7 +53,7 @@ end # Returns 1 if value is in the range [lower, upper). # Returns 0 otherwise. # Assumptions: -# upper - lower <= RC_BOUND +# upper - lower <= RANGE_CHECK_BOUND func is_in_range{range_check_ptr}(value, lower, upper) -> (res): let (res) = is_le(lower, value) if res == 0: diff --git a/src/starkware/cairo/common/merkle_multi_update.cairo b/src/starkware/cairo/common/merkle_multi_update.cairo index cde56fac..2c24e8b4 100644 --- a/src/starkware/cairo/common/merkle_multi_update.cairo +++ b/src/starkware/cairo/common/merkle_multi_update.cairo @@ -162,7 +162,9 @@ func merkle_multi_update_inner{hash_ptr : HashBuiltin*, update_ptr : DictAccess* return () update_both: - # Locals 0 and 1 are taken by non deterministic jumps. + # When the function starts we have fp=ap. + # The two nondeterministic jumps, write to [fp] and [fp + 1] and advance ap by 2. + # Thus, the next free memory cell is [fp + 2] and we need to increment ap by 1. let local_left_index = [fp + 2] %{ assert case == 'both' %} local_left_index = index * 2; ap++ @@ -180,12 +182,10 @@ func merkle_multi_update_inner{hash_ptr : HashBuiltin*, update_ptr : DictAccess* ) %} merkle_multi_update_inner( - height=height - 1, prev_root=hash0.x, new_root=hash1.x, index=index * 2) + height=height - 1, prev_root=hash0.x, new_root=hash1.x, index=local_left_index) %{ vm_exit_scope() %} # Update right. - # Push height to workaround one hint per line limitation. - tempvar height_minus_1 = height - 1 %{ vm_enter_scope(dict( node=right_child, preimage=preimage, @@ -194,7 +194,7 @@ func merkle_multi_update_inner{hash_ptr : HashBuiltin*, update_ptr : DictAccess* ) %} merkle_multi_update_inner( - height=height_minus_1, prev_root=hash0.y, new_root=hash1.y, index=local_left_index + 1) + height=height - 1, prev_root=hash0.y, new_root=hash1.y, index=local_left_index + 1) %{ vm_exit_scope() %} return () end diff --git a/src/starkware/cairo/common/merkle_update.cairo b/src/starkware/cairo/common/merkle_update.cairo index a39af85a..e781df3d 100644 --- a/src/starkware/cairo/common/merkle_update.cairo +++ b/src/starkware/cairo/common/merkle_update.cairo @@ -42,6 +42,11 @@ func merkle_update{hash_ptr : HashBuiltin*}(height, prev_leaf, new_leaf, index) [right_sibling] = new_node_hash.y; ap++ # Call merkle_update recursively. + + # Index must be even. + # We can show by induction that index is in the range [0, 2 ** height) (We know that when height + # is 0 index must be 0). If index was odd, then index / 2 is larger than PRIME / 2 which + # contradicts the fact that index / 2 is in the range [0, 2 ** (height - 1)). return merkle_update( height=height - 1, prev_leaf=prev_node_hash.result, @@ -63,6 +68,7 @@ func merkle_update{hash_ptr : HashBuiltin*}(height, prev_leaf, new_leaf, index) [left_sibling] = prev_node_hash.x [left_sibling] = new_node_hash.x; ap++ + # Similarly to the description above, index must be odd at this point. return merkle_update( height=height - 1, prev_leaf=prev_node_hash.result, diff --git a/src/starkware/cairo/common/patricia.cairo b/src/starkware/cairo/common/patricia.cairo index fc7bf260..d4271e7e 100644 --- a/src/starkware/cairo/common/patricia.cairo +++ b/src/starkware/cairo/common/patricia.cairo @@ -24,6 +24,11 @@ struct NodeEdge: member bottom : felt end +# Holds the constants needed for Patricia updates. +struct PatriciaUpdateConstants: + member globals_pow2 : felt* +end + # Given an edge node hash, opens the hash using the preimage hint, and returns a NodeEdge object. func open_edge{hash_ptr : HashBuiltin*, range_check_ptr}( globals : ParticiaGlobals*, node : felt) -> (edge : NodeEdge*): @@ -441,6 +446,29 @@ end func patricia_update{hash_ptr : HashBuiltin*, range_check_ptr}( update_ptr : DictAccess*, n_updates : felt, height : felt, prev_root : felt, new_root : felt): + let (patricia_update_constants : PatriciaUpdateConstants*) = patricia_update_constants_new() + patricia_update_using_update_constants( + patricia_update_constants=patricia_update_constants, + update_ptr=update_ptr, + n_updates=n_updates, + height=height, + prev_root=prev_root, + new_root=new_root) + + return () +end + +func patricia_update_constants_new() -> (patricia_update_constants : PatriciaUpdateConstants*): + # Compute power-of-2 array for patricia updates. + alloc_locals + let (local globals_pow2 : felt*) = alloc() + compute_pow2_array(pow2_ptr=globals_pow2, cur=1, n=MAX_LENGTH + 1) + return (patricia_update_constants=new PatriciaUpdateConstants(globals_pow2=globals_pow2)) +end + +func patricia_update_using_update_constants{hash_ptr : HashBuiltin*, range_check_ptr}( + patricia_update_constants : PatriciaUpdateConstants*, update_ptr : DictAccess*, + n_updates : felt, height : felt, prev_root : felt, new_root : felt): if n_updates == 0: prev_root = new_root return () @@ -476,16 +504,13 @@ func patricia_update{hash_ptr : HashBuiltin*, range_check_ptr}( alloc_locals local update_end : DictAccess* = update_ptr + n_updates * DictAccess.SIZE - # Compute globals. - let (local globals_pow2 : felt*) = alloc() - compute_pow2_array(globals_pow2, 1, MAX_LENGTH + 1) - # Traverse prev tree. let (local siblings) = alloc() let original_update_ptr = update_ptr let original_siblings = siblings let (local globals_prev : ParticiaGlobals*) = alloc() - assert [globals_prev] = ParticiaGlobals(pow2=globals_pow2, access_offset=DictAccess.prev_value) + assert [globals_prev] = ParticiaGlobals( + pow2=patricia_update_constants.globals_pow2, access_offset=DictAccess.prev_value) assert_le(height, MAX_LENGTH) %{ vm_enter_scope(dict(node=node, **common_args)) %} @@ -500,7 +525,8 @@ func patricia_update{hash_ptr : HashBuiltin*, range_check_ptr}( let update_ptr = original_update_ptr let siblings = original_siblings let (local globals_new : ParticiaGlobals*) = alloc() - assert [globals_new] = ParticiaGlobals(pow2=globals_pow2, access_offset=DictAccess.new_value) + assert [globals_new] = ParticiaGlobals( + pow2=patricia_update_constants.globals_pow2, access_offset=DictAccess.new_value) %{ vm_enter_scope(dict(node=node, **common_args)) %} with update_ptr, siblings: diff --git a/src/starkware/cairo/common/registers.cairo b/src/starkware/cairo/common/registers.cairo index 302a94d7..94ce1e04 100644 --- a/src/starkware/cairo/common/registers.cairo +++ b/src/starkware/cairo/common/registers.cairo @@ -1,19 +1,4 @@ -# Returns the contents of the fp and pc registers of the calling function. -# The pc register's value is the address of the instruction that follows directly after the -# invocation of get_fp_and_pc(). -func get_fp_and_pc() -> (fp_val, pc_val): - # The call instruction itself already places the old fp and the return pc at [ap - 2], [ap - 1]. - return (fp_val=[ap - 2], pc_val=[ap - 1]) -end - -# Returns the content of the ap register just before this function was invoked. -func get_ap() -> (ap_val): - # Once get_ap() is invoked, fp points to ap + 2 (since the call instruction placed the old fp - # and pc in memory, advancing ap accordingly). - # Hence, the desired ap value is fp - 2. - let (fp_val, pc_val) = get_fp_and_pc() - return (ap_val=fp_val - 2) -end +from starkware.cairo.lang.compiler.lib.registers import get_ap, get_fp_and_pc # Takes the value of a label (relative to program base) and returns the actual runtime address of # that label in the memory. @@ -33,9 +18,9 @@ end # let (callback_address) = get_label_location(do_callback) # do_thing_then_callback(callback=callback_address) # end -func get_label_location(label_value) -> (res): +func get_label_location(label_value : codeoffset) -> (res): let (_, pc_val) = get_fp_and_pc() ret_pc_label: - return (res=label_value + pc_val - ret_pc_label) + return (res=pc_val + (label_value - ret_pc_label)) end diff --git a/src/starkware/cairo/common/small_merkle_tree.cairo b/src/starkware/cairo/common/small_merkle_tree.cairo index aadfd8de..982dcaa7 100644 --- a/src/starkware/cairo/common/small_merkle_tree.cairo +++ b/src/starkware/cairo/common/small_merkle_tree.cairo @@ -20,7 +20,7 @@ from starkware.cairo.common.merkle_multi_update import merkle_multi_update # dict_accesses_start=dict_ptr_start, # dict_accesses_end=dict_ptr) # const HEIGHT = 3 -# let (prev_root, new_root) = small_merkle_tree( +# let (prev_root, new_root) = small_merkle_tree_update( # squashed_dict_start, squashed_dict_end, HEIGHT) # # In this example prev_root is the Merkle root of [0, 2, 0, 4, 0, 6, 0, 0], and new_root @@ -46,7 +46,7 @@ from starkware.cairo.common.merkle_multi_update import merkle_multi_update # * squashed_dict was created using the higher-level API dict_squash() (rather than squash_dict()). # * This function can be used for (relatively) small Merkle trees whose leaves can be loaded # to the memory. -func small_merkle_tree{hash_ptr : HashBuiltin*}( +func small_merkle_tree_update{hash_ptr : HashBuiltin*}( squashed_dict_start : DictAccess*, squashed_dict_end : DictAccess*, height : felt) -> ( prev_root : felt, new_root : felt): %{ vm_enter_scope({'__dict_manager': __dict_manager}) %} diff --git a/src/starkware/cairo/common/squash_dict.cairo b/src/starkware/cairo/common/squash_dict.cairo index 2f055c7c..197e6181 100644 --- a/src/starkware/cairo/common/squash_dict.cairo +++ b/src/starkware/cairo/common/squash_dict.cairo @@ -24,18 +24,17 @@ from starkware.cairo.common.math import assert_lt_felt func squash_dict{range_check_ptr}( dict_accesses : DictAccess*, dict_accesses_end : DictAccess*, squashed_dict : DictAccess*) -> (squashed_dict : DictAccess*): - let ptr_diff = [ap] + alloc_locals %{ vm_enter_scope() %} - ptr_diff = dict_accesses_end - dict_accesses; ap++ + local ptr_diff = dict_accesses_end - dict_accesses if ptr_diff == 0: # Access array is empty, nothing to check. %{ vm_exit_scope() %} return (squashed_dict=squashed_dict) end - let first_key = [fp + 1] - let big_keys = [fp + 2] - ap += 2 + local first_key + local big_keys tempvar n_accesses = ptr_diff / DictAccess.SIZE %{ dict_access_size = ids.DictAccess.SIZE @@ -147,11 +146,17 @@ func squash_dict_inner( assert first_value = dict_diff.prev_value # Skip loop nondeterministically if necessary. + # The verifier doesn't care if the loop is skipped or not. The only thing it checks is that + # the function iterated over remaining_accesses accesses in total + # with ascending keys and ascending indices for the same key. + # This guarantees that all the entries were visited exactly once. local should_skip_loop %{ ids.should_skip_loop = 0 if current_access_indices else 1 %} jmp skip_loop if should_skip_loop != 0 loop: + # Define references to access the values from the previous iteration, + # the temporary variables and the values for the current iteration. let prev_loop_locals = cast(ap - LoopLocals.SIZE, LoopLocals*) let loop_temps = cast(ap, LoopTemps*) let loop_locals = cast(ap + LoopTemps.SIZE, LoopLocals*) @@ -179,6 +184,7 @@ func squash_dict_inner( # Next range_check_ptr. loop_locals.range_check_ptr = prev_loop_locals.range_check_ptr + 1; ap++ + # The verifier doesn't care how many loop iterations are executed. See comment above. %{ ids.loop_temps.should_continue = 1 if current_access_indices else 0 %} jmp loop if loop_temps.should_continue != 0; ap++ @@ -189,6 +195,7 @@ func squash_dict_inner( %{ assert len(current_access_indices) == 0 %} [ap] = dict_accesses_end_minus1 - cast(last_loop_locals.access_ptr, felt) [ap] = [last_loop_locals.range_check_ptr]; ap++ + # Calculating the number of used accesses from the number of range check usages for efficiency. tempvar n_used_accesses = last_loop_locals.range_check_ptr - range_check_ptr %{ assert ids.n_used_accesses == len(access_indices[key]) %} @@ -204,6 +211,7 @@ func squash_dict_inner( return (range_check_ptr=range_check_ptr, squashed_dict=squashed_dict + DictAccess.SIZE) end + # ap points to the address of the beginning of the next access in the list. let next_key = [ap] ap += 1 # Guess next_key and check that next_key > key. diff --git a/src/starkware/cairo/common/structs.py b/src/starkware/cairo/common/structs.py index 6bc85874..dc3359f9 100644 --- a/src/starkware/cairo/common/structs.py +++ b/src/starkware/cairo/common/structs.py @@ -1,5 +1,4 @@ -from collections import namedtuple -from typing import List, MutableMapping, Optional +from typing import List, MutableMapping, NamedTuple, Optional from starkware.cairo.lang.compiler.ast.code_elements import CodeElementFunction from starkware.cairo.lang.compiler.identifier_definition import StructDefinition @@ -63,8 +62,14 @@ def build_struct(self, name: ScopedName): """ Builds and returns namedtuple from a Cairo struct. """ - sturct_def = self.get_struct_definition(name=name) - return namedtuple(sturct_def.full_name.path[-1], list(sturct_def.members.keys())) + struct_def = self.get_struct_definition(name=name) + + typed_fields = [ + (member_name, type(member_def.cairo_type)) + for member_name, member_def in struct_def.members.items() + ] + + return NamedTuple(struct_def.full_name.path[-1], typed_fields) def build_func_args(self, func: ScopedName): """ @@ -78,7 +83,13 @@ def build_func_args(self, func: ScopedName): args = get_struct_definition( full_name + CodeElementFunction.ARGUMENT_SCOPE, self.identifiers ).members - return namedtuple(f"{func[-1:]}_full_args", list({**implict_args, **args})) + + typed_fields = [ + (member_name, type(member_def.cairo_type)) + for member_name, member_def in {**implict_args, **args}.items() + ] + + return NamedTuple(f"{func[-1:]}_full_args", typed_fields) @property def structs(self): diff --git a/src/starkware/cairo/common/usort.cairo b/src/starkware/cairo/common/usort.cairo new file mode 100644 index 00000000..7bb0ef29 --- /dev/null +++ b/src/starkware/cairo/common/usort.cairo @@ -0,0 +1,99 @@ +from starkware.cairo.common.alloc import alloc +from starkware.cairo.common.math import assert_lt, assert_nn + +# Sorts an array of field elements and removes duplicates. +# Returns the sorted array and an array of multiplicities. +# multiplicities[i] is the number of times that output[i] appeared in input. +# Completeness assumption: All numbers are in [0, RANGE_CHECK_BOUND). +func usort{range_check_ptr}(input_len : felt, input : felt*) -> ( + output_len : felt, output : felt*, multiplicities : felt*): + alloc_locals + local output_len + local output : felt* + local multiplicities : felt* + %{ vm_enter_scope(dict(__usort_max_size = globals().get('__usort_max_size'))) %} + %{ + from collections import defaultdict + + input_ptr = ids.input + input_len = int(ids.input_len) + if __usort_max_size is not None: + assert input_len <= __usort_max_size, ( + f"usort() can only be used with input_len<={__usort_max_size}. " + f"Got: input_len={input_len}." + ) + + positions_dict = defaultdict(list) + for i in range(input_len): + val = memory[input_ptr + i] + positions_dict[val].append(i) + + output = sorted(positions_dict.keys()) + ids.output_len = len(output) + ids.output = segments.gen_arg(output) + ids.multiplicities = segments.gen_arg([len(positions_dict[k]) for k in output]) + %} + + let output_start = output + verify_usort{output=output}( + input_len=input_len, input=input, total_visited=0, multiplicities=multiplicities, prev=-1) + + %{ vm_exit_scope() %} + return (output_len=output - output_start, output=output_start, multiplicities=multiplicities) +end + +# Verifies that usort of input is (output, multiplicities). See usort(). +func verify_usort{range_check_ptr, output : felt*}( + input_len : felt, input : felt*, total_visited : felt, multiplicities : felt*, prev : felt): + alloc_locals + + if total_visited == input_len: + return () + end + + local value = [output] + let output = &output[1] + assert_lt(prev, value) + + local multiplicity = [multiplicities] + assert_nn(multiplicity - 1) + + %{ + last_pos = 0 + positions = positions_dict[ids.value][::-1] + %} + verify_multiplicity(multiplicity=multiplicity, input_len=input_len, input=input, value=value) + + return verify_usort( + input_len=input_len, + input=input, + total_visited=total_visited + multiplicity, + multiplicities=&multiplicities[1], + prev=value) +end + +# Verifies that value appears at least multiplicity times in input. +func verify_multiplicity{range_check_ptr}( + multiplicity : felt, input_len : felt, input : felt*, value : felt): + if multiplicity == 0: + %{ assert len(positions) == 0 %} + assert_nn(input_len) + return () + end + + alloc_locals + # Skip to the next appearance. + local next_item_index + %{ + current_pos = positions.pop() + ids.next_item_index = current_pos - last_pos + last_pos = current_pos + 1 + %} + assert_nn(next_item_index) + assert input[next_item_index] = value + return verify_multiplicity( + multiplicity=multiplicity - 1, + input_len=input_len - next_item_index - 1, + input=&input[next_item_index + 1], + value=value) +end diff --git a/src/starkware/cairo/lang/VERSION b/src/starkware/cairo/lang/VERSION index 39e898a4..a3df0a69 100644 --- a/src/starkware/cairo/lang/VERSION +++ b/src/starkware/cairo/lang/VERSION @@ -1 +1 @@ -0.7.1 +0.8.0 diff --git a/src/starkware/cairo/lang/builtins/ec/ec_op_builtin_runner.py b/src/starkware/cairo/lang/builtins/ec/ec_op_builtin_runner.py index 4a7bca26..9972b086 100644 --- a/src/starkware/cairo/lang/builtins/ec/ec_op_builtin_runner.py +++ b/src/starkware/cairo/lang/builtins/ec/ec_op_builtin_runner.py @@ -78,12 +78,12 @@ def rule(vm, addr): # Assert that if the current address is part of a point which is all set in the # memory, the point is on the curve. for pair in EC_POINT_INDICES[:2]: - ec_point = [memory[instance + i] for i in pair] + ec_point_x, ec_point_y = [memory[instance + i] for i in pair] assert point_on_curve( - *ec_point, ALPHA, BETA, FIELD_PRIME + ec_point_x, ec_point_y, ALPHA, BETA, FIELD_PRIME ), f"{self.name} builtin: point {pair} is not on the curve." - res = ec_op_impl( + res = ec_op_impl( # type: ignore *[memory[instance + i] for i in range(INPUT_CELLS_PER_EC_OP)], ALPHA, FIELD_PRIME ) # The result cannot be the point at infinity. diff --git a/src/starkware/cairo/lang/compiler/CMakeLists.txt b/src/starkware/cairo/lang/compiler/CMakeLists.txt index f6f32325..304fa723 100644 --- a/src/starkware/cairo/lang/compiler/CMakeLists.txt +++ b/src/starkware/cairo/lang/compiler/CMakeLists.txt @@ -39,6 +39,7 @@ python_lib(cairo_compile_lib injector.py instruction_builder.py instruction.py + lib/registers.cairo location_utils.py module_reader.py offset_reference.py @@ -167,6 +168,7 @@ full_python_test(cairo_compile_test LIBS cairo_compile_lib cairo_compile_test_utils_lib + starkware_python_test_utils_lib starkware_python_utils_lib pip_pytest ) diff --git a/src/starkware/cairo/lang/compiler/assembler.py b/src/starkware/cairo/lang/compiler/assembler.py index 3197e172..d35d295f 100644 --- a/src/starkware/cairo/lang/compiler/assembler.py +++ b/src/starkware/cairo/lang/compiler/assembler.py @@ -2,8 +2,10 @@ from starkware.cairo.lang.compiler.debug_info import DebugInfo, HintLocation, InstructionLocation from starkware.cairo.lang.compiler.encode import encode_instruction +from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager from starkware.cairo.lang.compiler.instruction_builder import build_instruction from starkware.cairo.lang.compiler.preprocessor.preprocessor import PreprocessedProgram +from starkware.cairo.lang.compiler.preprocessor.unique_labels import is_anonymous_label from starkware.cairo.lang.compiler.program import CairoHint, Program from starkware.cairo.lang.compiler.scoped_name import ScopedName @@ -59,12 +61,21 @@ def assemble( if debug_info is not None: debug_info.add_autogen_file_contents() + # Filter anonymous labels. + identifiers = IdentifierManager.from_dict( + { + name: identifier_definition + for name, identifier_definition in preprocessed_program.identifiers.as_dict().items() + if not is_anonymous_label(name.path[-1]) + } + ) + return Program( prime=preprocessed_program.prime, data=data, hints=hints, main_scope=main_scope, - identifiers=preprocessed_program.identifiers, + identifiers=identifiers, attributes=preprocessed_program.attributes, builtins=preprocessed_program.builtins, reference_manager=preprocessed_program.reference_manager, diff --git a/src/starkware/cairo/lang/compiler/ast/ast_objects_test_utils.py b/src/starkware/cairo/lang/compiler/ast/ast_objects_test_utils.py index ed4133ee..1652cdd3 100644 --- a/src/starkware/cairo/lang/compiler/ast/ast_objects_test_utils.py +++ b/src/starkware/cairo/lang/compiler/ast/ast_objects_test_utils.py @@ -3,6 +3,7 @@ ExprDeref, ExprDot, ExprNeg, + ExprNewOperator, ExprOperator, ExprParentheses, ExprSubscript, @@ -25,6 +26,8 @@ def remove_parentheses(expr): return ExprDeref(addr=remove_parentheses(expr.addr)) if isinstance(expr, ExprDot): return ExprDot(expr=remove_parentheses(expr.expr), member=expr.member) + if isinstance(expr, ExprNewOperator): + return ExprNewOperator(expr=remove_parentheses(expr.expr), is_typed=expr.is_typed) if isinstance(expr, ExprSubscript): return ExprSubscript( expr=remove_parentheses(expr.expr), offset=remove_parentheses(expr.offset) diff --git a/src/starkware/cairo/lang/compiler/ast/cairo_types.py b/src/starkware/cairo/lang/compiler/ast/cairo_types.py index a97ec4ec..af0dd9d1 100644 --- a/src/starkware/cairo/lang/compiler/ast/cairo_types.py +++ b/src/starkware/cairo/lang/compiler/ast/cairo_types.py @@ -5,6 +5,7 @@ from starkware.cairo.lang.compiler.ast.formatting_utils import LocationField from starkware.cairo.lang.compiler.ast.node import AstNode +from starkware.cairo.lang.compiler.ast.notes import Notes from starkware.cairo.lang.compiler.error_handling import Location from starkware.cairo.lang.compiler.scoped_name import ScopedName @@ -36,6 +37,17 @@ def get_children(self) -> Sequence[Optional[AstNode]]: return [] +@dataclasses.dataclass +class TypeCodeoffset(CairoType): + location: Optional[Location] = LocationField + + def format(self): + return "codeoffset" + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [] + + @dataclasses.dataclass class TypePointer(CairoType): pointee: CairoType @@ -73,19 +85,82 @@ def get_children(self) -> Sequence[Optional[AstNode]]: @dataclasses.dataclass class TypeTuple(CairoType): """ - Type for a tuple. + Represents a type of a named or unnamed tuple. + For example, "(felt, felt*)" or "(a : felt, b : felt*)". """ - members: List[CairoType] + @dataclasses.dataclass + class Item(AstNode): + """ + Represents a possibly named type item of a TypeTuple. + For example: "felt" or "a : felt". + """ + + name: Optional[str] + typ: CairoType + location: Optional[Location] = LocationField + + def format(self): + if self.name is None: + return self.typ.format() + return f"{self.name} : {self.typ.format()}" + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.typ] + + members: List["TypeTuple.Item"] + notes: List[Notes] = dataclasses.field(hash=False, compare=False) + has_trailing_comma: bool = dataclasses.field(hash=False, compare=False) location: Optional[Location] = LocationField + def __post_init__(self): + assert len(self.notes) == len(self.members) + 1 + + def assert_no_comments(self): + for note in self.notes: + note.assert_no_comments() + def format(self): + self.assert_no_comments() member_formats = [member.format() for member in self.members] return f"({', '.join(member_formats)})" def get_children(self) -> Sequence[Optional[AstNode]]: return self.members + @property + def types(self) -> List[CairoType]: + """ + Returns the unnamed types of the tuple. + """ + return [x.typ for x in self.members] + + @classmethod + def unnamed(cls, types: List[CairoType], location: Optional[Location] = None): + """ + Creates an unnamed tuple type from the given types. + """ + return cls.from_members( + members=[TypeTuple.Item(name=None, typ=typ) for typ in types], + location=location, + ) + + @classmethod + def from_members(cls, members: List["TypeTuple.Item"], location: Optional[Location]): + """ + Creates a tuple (with no notes) from the given members. + """ + return cls( + members=members, + notes=[Notes() for _ in range(len(members) + 1)], + has_trailing_comma=False, + location=location, + ) + + @property + def is_named(self) -> bool: + return all(member.name is not None for member in self.members) + class CastType(Enum): # When the compiler creates a cast expression for references. diff --git a/src/starkware/cairo/lang/compiler/ast/code_elements.py b/src/starkware/cairo/lang/compiler/ast/code_elements.py index 9710c78f..73d2d694 100644 --- a/src/starkware/cairo/lang/compiler/ast/code_elements.py +++ b/src/starkware/cairo/lang/compiler/ast/code_elements.py @@ -1,10 +1,11 @@ import dataclasses from abc import abstractmethod -from typing import Any, Dict, List, Optional, Sequence +from typing import Any, Dict, List, Optional, Sequence, Union from starkware.cairo.lang.compiler.ast.aliased_identifier import AliasedIdentifier from starkware.cairo.lang.compiler.ast.arguments import IdentifierList from starkware.cairo.lang.compiler.ast.bool_expr import BoolExpr +from starkware.cairo.lang.compiler.ast.cairo_types import CairoType from starkware.cairo.lang.compiler.ast.expr import ( ExprAssignment, Expression, @@ -15,7 +16,7 @@ INDENTATION, LocationField, ParticleFormattingConfig, - create_particle_sublist, + ParticleList, particles_in_lines, ) from starkware.cairo.lang.compiler.ast.instructions import InstructionAst @@ -169,10 +170,9 @@ class CodeElementReturn(CodeElement): def format(self, allowed_line_length): expr_codes = [x.format() for x in self.exprs] - particles = ["return (", create_particle_sublist(expr_codes, ")")] return particles_in_lines( - particles=particles, + particles=["return (", ParticleList(elements=expr_codes, end=")")], config=ParticleFormattingConfig( allowed_line_length=allowed_line_length, line_indent=INDENTATION, one_per_line=True ), @@ -274,11 +274,10 @@ def format(self, allowed_line_length): particles = self.rvalue.get_particles() end_particle = ") = " + particles[0] - particles = ( - ["let ("] - + create_particle_sublist(self.unpacking_list.get_particles(), end_particle) - + particles[1:] + unpacking_list_particles = ParticleList( + elements=self.unpacking_list.get_particles(), end=end_particle ) + particles = ["let ("] + unpacking_list_particles.to_strings() + particles[1:] return particles_in_lines( particles=particles, @@ -428,13 +427,14 @@ def name(self): def format(self, allowed_line_length): code = self.code_block.format(allowed_line_length=allowed_line_length - INDENTATION) code = indent(code, INDENTATION) + particles: List[Union[str, ParticleList]] if self.element_type in ["struct", "namespace"]: particles = [f"{self.element_type} {self.name}:"] else: if self.implicit_arguments is not None: first_particle_suffix = "{" implicit_args_particles = [ - create_particle_sublist(self.implicit_arguments.get_particles(), "}(") + ParticleList(elements=self.implicit_arguments.get_particles(), end="}(") ] else: first_particle_suffix = "(" @@ -444,21 +444,23 @@ def format(self, allowed_line_length): particles = [ f"{self.element_type} {self.name}{first_particle_suffix}", *implicit_args_particles, - create_particle_sublist(self.arguments.get_particles(), ") -> ("), - create_particle_sublist(self.returns.get_particles(), "):"), + ParticleList(elements=self.arguments.get_particles(), end=") -> ("), + ParticleList(elements=self.returns.get_particles(), end="):"), ] else: particles = [ f"{self.element_type} {self.name}{first_particle_suffix}", *implicit_args_particles, - create_particle_sublist(self.arguments.get_particles(), "):"), + ParticleList(elements=self.arguments.get_particles(), end="):"), ] decorators = "".join(f"@{decorator.format()}\n" for decorator in self.decorators) header = particles_in_lines( particles=particles, config=ParticleFormattingConfig( - allowed_line_length=allowed_line_length, line_indent=INDENTATION * 2 + allowed_line_length=allowed_line_length, + line_indent=INDENTATION, + double_indentation=True, ), ) return f"{decorators}{header}\n{code}end" @@ -473,6 +475,30 @@ def get_children(self) -> Sequence[Optional[AstNode]]: ] +@dataclasses.dataclass +class CodeElementTypeDef(CodeElement): + """ + Represents a statement of the form: + using new_type_name = old_type + For example, + using Point = (x : felt, y : felt) + """ + + identifier: ExprIdentifier + cairo_type: CairoType + location: Optional[Location] = LocationField + + def format(self, allowed_line_length): + return f"using {self.identifier.format()} = {self.cairo_type.format()}" + + @property + def name(self) -> str: + return self.identifier.name + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.identifier, self.cairo_type] + + @dataclasses.dataclass class CodeElementWithAttr(CodeElement): attribute_name: ExprIdentifier @@ -620,9 +646,8 @@ def format(self, allowed_line_length): if len(one_liner) <= allowed_line_length: return one_liner - particles = [f"{prefix}(", create_particle_sublist(items, ")")] return particles_in_lines( - particles=particles, + particles=[f"{prefix}(", ParticleList(elements=items, end=")")], config=ParticleFormattingConfig( allowed_line_length=allowed_line_length, line_indent=INDENTATION, one_per_line=False ), diff --git a/src/starkware/cairo/lang/compiler/ast/expr.py b/src/starkware/cairo/lang/compiler/ast/expr.py index 06a485cb..1b621946 100644 --- a/src/starkware/cairo/lang/compiler/ast/expr.py +++ b/src/starkware/cairo/lang/compiler/ast/expr.py @@ -13,7 +13,7 @@ from starkware.cairo.lang.compiler.error_handling import Location from starkware.cairo.lang.compiler.instruction import Register from starkware.python.expression_string import ExpressionString -from starkware.python.utils import indent +from starkware.python.utils import indent, safe_zip class Expression(AstNode): @@ -146,6 +146,9 @@ class ArgList(AstNode): has_trailing_comma: bool location: Optional[Location] = LocationField + def __post_init__(self): + assert len(self.notes) == len(self.args) + 1 + def assert_no_comments(self): for note in self.notes: note.assert_no_comments() @@ -157,7 +160,7 @@ def format(self): code = "" assert len(self.args) + 1 == len(self.notes) - for notes, arg in zip(self.notes[:-1], self.args): + for notes, arg in safe_zip(self.notes[:-1], self.args): if code != "": code += "," if notes.empty: @@ -375,20 +378,39 @@ def get_children(self) -> Sequence[Optional[AstNode]]: return [self.members] -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class ExprFutureLabel(Expression): """ Represents a future label whose current pc is not known yet. """ identifier: ExprIdentifier + # True if the label should be considered of type codeoffset (otherwise it is considered felt). + is_typed: bool + location: Optional[Location] = LocationField def to_expr_str(self): return self.identifier.to_expr_str() - @property - def location(self): - return self.identifier.location - def get_children(self) -> Sequence[Optional[AstNode]]: return [self.identifier] + + +@dataclasses.dataclass +class ExprNewOperator(Expression): + """ + Represents an expression of the form "new ()". + For example, "new MyStruct(1, 2, z=3)". + """ + + expr: Expression + # True if the type of the expression should be a pointer to the type of 'expr'. + # False, if the type should be considered as felt. + is_typed: bool + location: Optional[Location] = LocationField + + def to_expr_str(self): + return self.expr.to_expr_str().operator_new() + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.expr] diff --git a/src/starkware/cairo/lang/compiler/ast/formatting_utils.py b/src/starkware/cairo/lang/compiler/ast/formatting_utils.py index 6592caa3..e6c68689 100644 --- a/src/starkware/cairo/lang/compiler/ast/formatting_utils.py +++ b/src/starkware/cairo/lang/compiler/ast/formatting_utils.py @@ -6,7 +6,7 @@ from contextlib import contextmanager from contextvars import ContextVar from dataclasses import field -from typing import List +from typing import List, Union import marshmallow @@ -20,6 +20,7 @@ metadata=dict(marshmallow_field=marshmallow.fields.Field(load_only=True, dump_only=True)), ) max_line_length_ctx_var: ContextVar[int] = ContextVar("max_line_length", default=100) +one_item_per_line_ctx_var: ContextVar[bool] = ContextVar("one_item_per_line", default=False) def get_max_line_length(): @@ -27,14 +28,29 @@ def get_max_line_length(): @contextmanager -def set_max_line_length(line_length: bool): +def set_max_line_length(line_length: int): """ Context manager that sets max_line_length context variable. """ - previous = get_max_line_length() - max_line_length_ctx_var.set(line_length) - yield - max_line_length_ctx_var.set(previous) + token = max_line_length_ctx_var.set(line_length) + try: + yield + finally: + max_line_length_ctx_var.reset(token) + + +@contextmanager +def set_one_item_per_line(value: bool): + """ + Context manager that sets one_item_per_line context variable. + If true, each list item (e.g., function arguments) will be put in a separate line, + if the list doesn't fit a single line. + """ + token = one_item_per_line_ctx_var.set(value) + try: + yield + finally: + one_item_per_line_ctx_var.reset(token) class FormattingError(LocationError): @@ -50,7 +66,37 @@ class ParticleFormattingConfig: # The prefix of the first line. first_line_prefix: str = "" # At most one item per line. + # Note: if the one_item_per_line ContextVar is True, this field is ignored (it has a slightly + # different formatting). one_per_line: bool = False + # If True, line_indent is doubled. + # Note: if the one_item_per_line ContextVar is True, this field is ignored. + double_indentation: bool = False + + +@dataclasses.dataclass(frozen=True) +class ParticleList: + """ + A list of particles, which is a part of a larger list of particles that constructs one or + more lines. + """ + + elements: List[str] + separator: str = ", " + end: str = "" + + def to_strings(self) -> List[str]: + if len(self.elements) == 0: + # If the list is empty, return the single element 'end'. + return [self.end] + # Concatenate the 'separator' to all elements and 'end' to the last one. + return [elm + self.separator for elm in self.elements[:-1]] + [self.elements[-1] + self.end] + + def elements_to_string(self) -> str: + """ + Returns a concatenation of the strings in self.elements, separated with self.separator. + """ + return self.separator.join(self.elements) class ParticleLineBuilder: @@ -65,7 +111,7 @@ def __init__(self, config: ParticleFormattingConfig): self.config = config - def newline(self): + def newline(self, indent: bool = True): """ Opens a new line. """ @@ -73,7 +119,7 @@ def newline(self): return self.lines.append(self.line) self.line_is_new = True - self.line = " " * self.config.line_indent + self.line = (" " * self.config.line_indent) if indent else "" def add_to_line(self, string): """ @@ -85,6 +131,12 @@ def add_to_line(self, string): self.line += string self.line_is_new = False + def can_fit_in_line(self, string: str) -> bool: + """ + Returns True if the given string can fit in the current line. + """ + return len(self.line) + len(string) <= self.config.allowed_line_length + def finalize(self): """ Finalizes the particle lines and returns the result. @@ -94,62 +146,126 @@ def finalize(self): return "\n".join(line.rstrip() for line in self.lines) -def create_particle_sublist(lst: List[str], end: str = "", separator: str = ", ") -> List[str]: - if len(lst) == 0: - # If the list is empty, return the single element 'end'. - return [end] - # Concatenate the 'separator' to all elements of the 'lst' and 'end' to the last one. - return [elm + separator for elm in lst[:-1]] + [lst[-1] + end] +def add_list_new_format(builder: ParticleLineBuilder, particle_list: ParticleList): + """ + Adds a particle list to the current line. + If the list cannot be fully concatenated to the current line opens a new line, and puts each + element of the list in a separate line, indented by 'INDENTATION' charactes. + + For example, using this function to format a list of arguments may result in the following + formatting: + func f( + x, + y, + z, + ) -> ( + a, + b, + c + ): + + With a longer line length we will get the lists on the same line: + func f(x, y, z) -> (a, b, c): + """ + elements_string = particle_list.elements_to_string() + + # If the entire list fits in the current line, or the list is empty, add everything to the + # current line. + if ( + builder.can_fit_in_line(elements_string + particle_list.end) + or len(particle_list.elements) == 0 + ): + builder.add_to_line(elements_string + particle_list.end) + return + + # If the entire list fits in a new line, add it. + # Else, add each element of the list in a separate line. + builder.newline() + if builder.can_fit_in_line(elements_string): + builder.add_to_line(elements_string) + else: + for elm in particle_list.elements: + builder.newline() + builder.add_to_line(elm + particle_list.separator) + + builder.newline(indent=False) + builder.add_to_line(particle_list.end) -def particles_in_lines(particles, config: ParticleFormattingConfig): +def add_list_old_format( + builder: ParticleLineBuilder, particle_list: ParticleList, config: ParticleFormattingConfig +): + """ + Adds a particle list to the current line. + If the list cannot be fully concatenated to the current line opens a new line. + + For example, using this function to format a list of arguments may result in the following + formatting: + func f( + x, y, + z) -> ( + a, b, + c): + + With a longer line length we will get the lists on the same line: + func f(x, y, z) -> (a, b, c): + """ + list_particles = particle_list.to_strings() + + # If the entire list fits in a single line, add it. + if sum(map(len, list_particles), config.line_indent) < config.allowed_line_length: + builder.add_to_line("".join(list_particles)) + return + builder.newline() + for member in list_particles: + if config.one_per_line: + builder.newline() + builder.add_to_line(member) + + +def particles_in_lines( + particles: List[Union[str, ParticleList]], config: ParticleFormattingConfig +) -> str: """ Receives a list 'particles' that contains strings and particle sublists and generates lines according to the following rules: + + When one_item_per_line ContextVar is False: - The first line is not indented. All other lines start with 'line_indent' spaces. - A line containing more than one particle can be no longer than 'allowed_line_length'. - - A sublist that cannot be fully concatenated to the current line opens a new line. - - Example: - particles_in_lines( - ['func f(', - create_particle_sublist(['x', 'y', 'z'], ') -> ('), - create_particle_sublist(['a', 'b', 'c'], '):')], - 12, 4) - returns '''\ - func f( - x, y, - z) -> ( - a, b, - c):\ - ''' - With a longer line length we will get the lists on the same line: - particles_in_lines( - ['func f(', - create_particle_sublist(['x', 'y', 'z'], ') -> ('), - create_particle_sublist([], '):')], - 19, 4) - returns '''\ - func f( - x, y, z) -> ():\ - ''' + - A sublist that cannot be fully concatenated to the current line opens a new line (see + add_list_old_format). + + When one_item_per_line ContextVar is True: + - The first line is not indented. Other lines start with 'line_indent' spaces. Lines + that contruct sublists are indented as described in add_list_new_format. + - A line containing more than one particle can be no longer than 'allowed_line_length'. + - A sublist that cannot be fully concatenated to the current line opens a new line (see + add_list_new_format). + + Usage example: + particles_in_lines( + ['func f(', + ParticleList(elements=['x', 'y', 'z'], end=') -> ('), + ParticleList(elements=['a', 'b', 'c'], end='):')], + 12, 4) """ + if config.double_indentation and not one_item_per_line_ctx_var.get(): + config = dataclasses.replace( + config, line_indent=2 * config.line_indent, double_indentation=False + ) + builder = ParticleLineBuilder(config=config) for particle in particles: if isinstance(particle, str): builder.add_to_line(particle) - if isinstance(particle, list): - # If the entire sublist fits in a single line, add it. - if sum(map(len, particle), config.line_indent) < config.allowed_line_length: - builder.add_to_line("".join(particle)) - continue - builder.newline() - for member in particle: - if config.one_per_line: - builder.newline() - builder.add_to_line(member) + if isinstance(particle, ParticleList): + if one_item_per_line_ctx_var.get(): + add_list_new_format(builder, particle) + else: + add_list_old_format(builder, particle, config) return builder.finalize() diff --git a/src/starkware/cairo/lang/compiler/ast/formatting_utils_test.py b/src/starkware/cairo/lang/compiler/ast/formatting_utils_test.py index 8c5c444f..1ba4f03a 100644 --- a/src/starkware/cairo/lang/compiler/ast/formatting_utils_test.py +++ b/src/starkware/cairo/lang/compiler/ast/formatting_utils_test.py @@ -1,16 +1,40 @@ +from typing import List, Union + from starkware.cairo.lang.compiler.ast.formatting_utils import ( ParticleFormattingConfig, - create_particle_sublist, + ParticleList, particles_in_lines, + set_one_item_per_line, ) +def run_test_particles_in_lines( + particles, config: ParticleFormattingConfig, expected: str, expected_one_per_line: str +): + with set_one_item_per_line(False): + assert ( + particles_in_lines( + particles=particles, + config=config, + ) + == expected + ) + with set_one_item_per_line(True): + assert ( + particles_in_lines( + particles=particles, + config=config, + ) + == expected_one_per_line + ) + + def test_particles_in_lines(): - particles = [ + particles: List[Union[str, ParticleList]] = [ "start ", "foo ", "bar ", - create_particle_sublist(["a", "b", "c", "dddd", "e", "f"], "*"), + ParticleList(elements=["a", "b", "c", "dddd", "e", "f"], end="*"), " asdf", ] expected = """\ @@ -20,18 +44,28 @@ def test_particles_in_lines(): dddd, e, f* asdf\ """ - assert ( - particles_in_lines( - particles=particles, - config=ParticleFormattingConfig(allowed_line_length=12, line_indent=2), - ) - == expected + expected_one_per_line = """\ +start foo + bar + a, + b, + c, + dddd, + e, + f, +* asdf\ +""" + run_test_particles_in_lines( + particles=particles, + config=ParticleFormattingConfig(allowed_line_length=12, line_indent=2), + expected=expected, + expected_one_per_line=expected_one_per_line, ) particles = [ "func f(", - create_particle_sublist(["x", "y", "z"], ") -> ("), - create_particle_sublist(["a", "b", "c"], "):"), + ParticleList(elements=["x", "y", "z"], end=") -> ("), + ParticleList(elements=["a", "b", "c"], end="):"), ] expected = """\ func f( @@ -40,12 +74,18 @@ def test_particles_in_lines(): a, b, c):\ """ - assert ( - particles_in_lines( - particles=particles, - config=ParticleFormattingConfig(allowed_line_length=12, line_indent=4), - ) - == expected + expected_one_per_line = """\ +func f( + x, y, z +) -> ( + a, b, c +):\ +""" + run_test_particles_in_lines( + particles=particles, + config=ParticleFormattingConfig(allowed_line_length=12, line_indent=4), + expected=expected, + expected_one_per_line=expected_one_per_line, ) # Same particles, using one_per_line=True. @@ -86,8 +126,8 @@ def test_particles_in_lines(): particles = [ "func f(", - create_particle_sublist(["x", "y", "z"], ") -> ("), - create_particle_sublist([], "):"), + ParticleList(elements=["x", "y", "z"], end=") -> ("), + ParticleList(elements=[], end="):"), ] expected = """\ func f( @@ -107,10 +147,10 @@ def test_linebreak_on_particle_space(): Tests line breaking when the line length is exceeded by the space in the ', ' seperator at the end of a particle. """ - particles = [ + particles: List[Union[str, ParticleList]] = [ "func f(", - create_particle_sublist(["x", "y", "z"], ") -> ("), - create_particle_sublist([], "):"), + ParticleList(elements=["x", "y", "z"], end=") -> ("), + ParticleList(elements=[], end="):"), ] expected = """\ func f( @@ -118,20 +158,25 @@ def test_linebreak_on_particle_space(): z) -> ( ):\ """ - assert ( - particles_in_lines( - particles=particles, - config=ParticleFormattingConfig(allowed_line_length=9, line_indent=4), - ) - == expected + expected_one_per_line = """\ +func f( + x, + y, + z, +) -> ():\ +""" + run_test_particles_in_lines( + particles=particles, + config=ParticleFormattingConfig(allowed_line_length=9, line_indent=4), + expected=expected, + expected_one_per_line=expected_one_per_line, ) - assert ( - particles_in_lines( - particles=particles, - config=ParticleFormattingConfig(allowed_line_length=10, line_indent=4), - ) - == expected + run_test_particles_in_lines( + particles=particles, + config=ParticleFormattingConfig(allowed_line_length=10, line_indent=4), + expected=expected, + expected_one_per_line=expected_one_per_line, ) expected = """\ @@ -141,10 +186,9 @@ def test_linebreak_on_particle_space(): z) -> ( ):\ """ - assert ( - particles_in_lines( - particles=particles, - config=ParticleFormattingConfig(allowed_line_length=8, line_indent=4), - ) - == expected + run_test_particles_in_lines( + particles=particles, + config=ParticleFormattingConfig(allowed_line_length=8, line_indent=4), + expected=expected, + expected_one_per_line=expected_one_per_line, ) diff --git a/src/starkware/cairo/lang/compiler/ast/instructions.py b/src/starkware/cairo/lang/compiler/ast/instructions.py index 844e4574..c6e05e33 100644 --- a/src/starkware/cairo/lang/compiler/ast/instructions.py +++ b/src/starkware/cairo/lang/compiler/ast/instructions.py @@ -116,6 +116,9 @@ class CallLabelInstruction(InstructionBody): label: ExprIdentifier location: Optional[Location] = LocationField + # Indicates the 'label' is a fully qualified identifier, rather then a relative one. + # This field is typically set for compiler-generated calls. + fully_qualified_label: bool = False def format(self): return f"call {self.label.format()}" diff --git a/src/starkware/cairo/lang/compiler/ast/node.py b/src/starkware/cairo/lang/compiler/ast/node.py index d0208f53..b7bae5a2 100644 --- a/src/starkware/cairo/lang/compiler/ast/node.py +++ b/src/starkware/cairo/lang/compiler/ast/node.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Optional, Sequence +from typing import Iterator, Optional, Sequence class AstNode(ABC): @@ -8,3 +8,12 @@ def get_children(self) -> Sequence[Optional["AstNode"]]: """ Returns a list of the node's children (notes are not included). """ + + def get_subtree(self) -> Iterator["AstNode"]: + """ + Returns an iterator of all non-None nodes in the subtree rooted at this node, preorder + (visit each node before its children). + """ + yield self + for child in filter(None, self.get_children()): + yield from child.get_subtree() diff --git a/src/starkware/cairo/lang/compiler/ast/notes.py b/src/starkware/cairo/lang/compiler/ast/notes.py index b1cab410..b189b9dd 100644 --- a/src/starkware/cairo/lang/compiler/ast/notes.py +++ b/src/starkware/cairo/lang/compiler/ast/notes.py @@ -9,7 +9,7 @@ from starkware.cairo.lang.compiler.error_handling import Location -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class Notes(AstNode): """ Represents new-lines and comments that appear inside an expression or other code element. diff --git a/src/starkware/cairo/lang/compiler/ast/rvalue.py b/src/starkware/cairo/lang/compiler/ast/rvalue.py index 8c069069..d515c1ad 100644 --- a/src/starkware/cairo/lang/compiler/ast/rvalue.py +++ b/src/starkware/cairo/lang/compiler/ast/rvalue.py @@ -7,7 +7,7 @@ INDENTATION, LocationField, ParticleFormattingConfig, - create_particle_sublist, + ParticleList, particles_in_lines, ) from starkware.cairo.lang.compiler.ast.instructions import CallInstruction @@ -123,12 +123,12 @@ def get_particles(self): if self.implicit_arguments is not None: particles[-1] += "{" particles.append( - create_particle_sublist([x.format() for x in self.implicit_arguments.args], "}(") + ParticleList(elements=[x.format() for x in self.implicit_arguments.args], end="}(") ) else: particles[-1] += "(" - particles.append(create_particle_sublist([x.format() for x in self.arguments.args], ")")) + particles.append(ParticleList(elements=[x.format() for x in self.arguments.args], end=")")) return particles def format(self, allowed_line_length): diff --git a/src/starkware/cairo/lang/compiler/ast_objects_test.py b/src/starkware/cairo/lang/compiler/ast_objects_test.py index f8e1d05f..871312c9 100644 --- a/src/starkware/cairo/lang/compiler/ast_objects_test.py +++ b/src/starkware/cairo/lang/compiler/ast_objects_test.py @@ -2,7 +2,10 @@ from starkware.cairo.lang.compiler.ast.ast_objects_test_utils import remove_parentheses from starkware.cairo.lang.compiler.ast.expr import ExprConst, ExprNeg, ExprOperator -from starkware.cairo.lang.compiler.ast.formatting_utils import FormattingError +from starkware.cairo.lang.compiler.ast.formatting_utils import ( + FormattingError, + set_one_item_per_line, +) from starkware.cairo.lang.compiler.parser import parse_code_element, parse_expr, parse_file @@ -53,6 +56,8 @@ def test_format_parentheses(): assert remove_parentheses(parse_expr("[((x+y) + z)]")).format() == "[x + y + z]" + assert remove_parentheses(parse_expr("new (2+3)")).format() == "new (2 + 3)" + # Test that parentheses are not added if they were already present. assert parse_expr("(a * (b + c))").format() == "(a * (b + c))" assert parse_expr("((a * ((b + c))))").format() == "((a * ((b + c))))" @@ -527,6 +532,145 @@ def test_func_arg_ret_splitting(): assert parse_file(before).format(allowed_line_length=25) == after +def test_func_one_per_line_splitting(): + before = """\ +func myfunc{x, y}(a, b): + ret +end +""" + with set_one_item_per_line(True): + assert parse_file(before).format(allowed_line_length=25) == before + before = """\ +func myfunc{x: felt*, y: felt*}(a, b, c, d): + ret +end +""" + after = """\ +func myfunc{ + x : felt*, y : felt* +}(a, b, c, d): + ret +end +""" + with set_one_item_per_line(True): + assert parse_file(before).format(allowed_line_length=25) == after + before = """\ +func myfunc{long_imp_arg1, long_imp_arg2,long_imp_arg3}( + very_long_arg1, very_long_arg2): + ret +end +""" + after = """\ +func myfunc{ + long_imp_arg1, + long_imp_arg2, + long_imp_arg3, +}( + very_long_arg1, + very_long_arg2, +): + ret +end +""" + with set_one_item_per_line(True): + assert parse_file(before).format(allowed_line_length=25) == after + before = """\ +func myfunc{}(a, variable_name_which_is_way_too_long_but_has_to_be_supported): + ret +end +""" + after = """\ +func myfunc{}( + a, + variable_name_which_is_way_too_long_but_has_to_be_supported, +): + ret +end +""" + with set_one_item_per_line(True): + assert parse_file(before).format(allowed_line_length=25) == after + before = """\ +func myfunc(ab, cd, ef) -> ( + long_return_arg1, long_return_arg2): + ret +end +""" + after = """\ +func myfunc( + ab, cd, ef +) -> ( + long_return_arg1, + long_return_arg2, +): + ret +end +""" + with set_one_item_per_line(True): + assert parse_file(before).format(allowed_line_length=25) == after + + +def test_return_one_per_line_splitting(): + before = """\ +return (a, b, c, foo, bar, + variable_name_which_is_way_too_long_but_has_to_be_supported, g) +""" + after = """\ +return ( + a, + b, + c, + foo, + bar, + variable_name_which_is_way_too_long_but_has_to_be_supported, + g, +) +""" + with set_one_item_per_line(True): + assert parse_file(before).format(allowed_line_length=25) == after + + +def test_func_call_one_per_line_splitting(): + before = """\ +let (a, b, c) = foo(long_arg1, long_arg2, long_arg3) +""" + after = """\ +let (a, b, c) = foo( + long_arg1, + long_arg2, + long_arg3, +) +""" + with set_one_item_per_line(True): + assert parse_file(before).format(allowed_line_length=25) == after + before = """\ +return foo(long_arg1, long_arg2, long_arg3) +""" + after = """\ +return foo( + long_arg1, + long_arg2, + long_arg3, +) +""" + with set_one_item_per_line(True): + assert parse_file(before).format(allowed_line_length=25) == after + + +def test_import_one_per_line_splitting(): + before = """\ +from a.b.c import (import1, import2, import3) +""" + after = """\ +from a.b.c import ( + import1, + import2, + import3, +) +""" + with set_one_item_per_line(True): + assert parse_file(before).format(allowed_line_length=25) == after + + def test_directives(): code = """\ [ap] = [ap] diff --git a/src/starkware/cairo/lang/compiler/cairo.ebnf b/src/starkware/cairo/lang/compiler/cairo.ebnf index ed61a4ea..40eaa27d 100644 --- a/src/starkware/cairo/lang/compiler/cairo.ebnf +++ b/src/starkware/cairo/lang/compiler/cairo.ebnf @@ -15,19 +15,24 @@ _NEQ: "!=" _ARROW: "->" _AT: "@" +// Comma separated list with notes. +comma: "," +_expr_sep: (comma | nonempty_notes)+ +comma_separated_with_notes{item}: notes item? (_expr_sep item)* _expr_sep? + // Types. -type: "felt" -> type_felt - | identifier -> type_struct - | type "*" -> type_pointer - | type _DBL_STAR -> type_pointer2 - | "(" ((type ",")* type ","?)? ")" -> type_tuple +named_type: identifier (":" type)? | non_identifier_type +non_identifier_type: "felt" -> type_felt + | "codeoffset" -> type_codeoffset + | type "*" -> type_pointer + | type _DBL_STAR -> type_pointer2 + | "(" comma_separated_with_notes{named_type} ")" -> type_tuple +?type: non_identifier_type + | identifier -> type_struct // Expressions. expr_assignment: expr | identifier_def "=" expr -?arg_list_item: expr_assignment -comma: "," -_expr_sep: (comma | nonempty_notes)+ -arg_list: (notes arg_list_item? (_expr_sep arg_list_item)* _expr_sep?) +arg_list: comma_separated_with_notes{expr_assignment} decorator: _AT identifier_def decorator_list: (decorator _NEWLINE*)* @@ -42,6 +47,7 @@ decorator_list: (decorator _NEWLINE*)* ?unary: pow | "&" unary -> unary_addressof | "-" unary -> unary_neg + | "new" unary -> unary_new_operator ?pow: atom | atom _DBL_STAR notes pow -> expr_pow identifier: IDENTIFIER ("." IDENTIFIER)* @@ -148,6 +154,7 @@ code_element: instruction -> code_element_ | identifier ":" -> code_element_label | _func -> code_element_function | _struct -> code_element_struct + | "using" identifier_def "=" type -> code_element_typedef | _with_attr_statement -> code_element_with_attr | _with_statement -> code_element_with | HINT -> code_element_hint diff --git a/src/starkware/cairo/lang/compiler/cairo_compile.py b/src/starkware/cairo/lang/compiler/cairo_compile.py index 54b766c3..b007e380 100644 --- a/src/starkware/cairo/lang/compiler/cairo_compile.py +++ b/src/starkware/cairo/lang/compiler/cairo_compile.py @@ -94,10 +94,6 @@ def cairo_compile_common( try: codes = get_codes(args.files) - file_contents_for_debug_info = {} - if getattr(args, "proof_mode", False): - codes = add_start_code(codes) - file_contents_for_debug_info[START_FILE_NAME] = codes[0][0] out = args.output if args.output is not None else sys.stdout cairo_path: List[str] = list( @@ -107,8 +103,17 @@ def cairo_compile_common( pass_manager = pass_manager_factory(args, module_reader) + start_codes = [] + file_contents_for_debug_info = {} + if getattr(args, "proof_mode", False): + start_codes = [(get_start_code(), START_FILE_NAME)] + file_contents_for_debug_info[START_FILE_NAME] = start_codes[0][0] + preprocessed = preprocess_codes( - codes=codes, pass_manager=pass_manager, main_scope=MAIN_SCOPE + codes=codes, + pass_manager=pass_manager, + main_scope=MAIN_SCOPE, + start_codes=start_codes, ) if args.preprocess: @@ -160,10 +165,6 @@ def get_codes(file_names: List[str]) -> List[Tuple[str, str]]: return codes_with_filenames -def add_start_code(codes_with_filenames: List[Tuple[str, str]]) -> List[Tuple[str, str]]: - return [(get_start_code(), START_FILE_NAME)] + codes_with_filenames - - def compile_cairo_files( files: List[str], prime: Optional[int] = None, @@ -207,12 +208,10 @@ def compile_cairo_ex( if isinstance(code, list): codes_with_filenames = code + start_codes = [] if add_start: - codes_with_filenames = add_start_code(codes_with_filenames) - - # Add the start code to the debug info if exists. - if START_FILE_NAME == codes_with_filenames[0][1]: - file_contents_for_debug_info[START_FILE_NAME] = codes_with_filenames[0][0] + start_codes = [(get_start_code(), START_FILE_NAME)] + file_contents_for_debug_info[START_FILE_NAME] = start_codes[0][0] if pass_manager is None: assert prime is not None, "Exactly one of prime and pass_manager must be given." @@ -227,7 +226,10 @@ def compile_cairo_ex( if main_scope is None: main_scope = MAIN_SCOPE preprocessed_program = preprocess_codes( - codes=codes_with_filenames, pass_manager=pass_manager, main_scope=main_scope + codes=codes_with_filenames, + pass_manager=pass_manager, + main_scope=main_scope, + start_codes=start_codes, ) program = cairo_assemble_program( preprocessed_program, diff --git a/src/starkware/cairo/lang/compiler/cairo_format.py b/src/starkware/cairo/lang/compiler/cairo_format.py index b197a0fe..f4a0dbe0 100644 --- a/src/starkware/cairo/lang/compiler/cairo_format.py +++ b/src/starkware/cairo/lang/compiler/cairo_format.py @@ -1,6 +1,7 @@ import argparse import sys +from starkware.cairo.lang.compiler.ast.formatting_utils import set_one_item_per_line from starkware.cairo.lang.compiler.parser import parse_file from starkware.cairo.lang.version import __version__ @@ -9,33 +10,43 @@ def main(): parser = argparse.ArgumentParser(description="A tool to automatically format Cairo code.") parser.add_argument("-v", "--version", action="version", version=f"%(prog)s {__version__}") parser.add_argument("files", metavar="file", type=str, nargs="+", help="File names") + parser.add_argument( + "--one_item_per_line", + action="store_true", + help=( + "Put each list item (e.g., function arguments) in a separate line, " + "if the list doesn't fit into a single line." + ), + ) action = parser.add_mutually_exclusive_group(required=False) action.add_argument("-i", dest="inplace", action="store_true", help="Edit files inplace.") action.add_argument("-c", dest="check", action="store_true", help="Check files' formats.") args = parser.parse_args() + return_code = 0 - for path in args.files: - old_content = open(path).read() if path != "-" else sys.stdin.read() - try: - new_content = parse_file( - old_content, filename="" if path == "-" else path - ).format() - except Exception as exc: - print(exc, file=sys.stderr) - return 2 - - if args.inplace: - assert path != "-", 'Using "-i" together with "-" is not supported.' - open(path, "w").write(new_content) - elif args.check: - assert path != "-", 'Using "-c" together with "-" is not supported.' - if old_content != new_content: - print(f'File "{path}" is incorrectly formatted.', file=sys.stderr) - return_code = 1 - else: - print(new_content, end="") + with set_one_item_per_line(args.one_item_per_line): + for path in args.files: + old_content = open(path).read() if path != "-" else sys.stdin.read() + try: + new_content = parse_file( + old_content, filename="" if path == "-" else path + ).format() + except Exception as exc: + print(exc, file=sys.stderr) + return 2 + + if args.inplace: + assert path != "-", 'Using "-i" together with "-" is not supported.' + open(path, "w").write(new_content) + elif args.check: + assert path != "-", 'Using "-c" together with "-" is not supported.' + if old_content != new_content: + print(f'File "{path}" is incorrectly formatted.', file=sys.stderr) + return_code = 1 + else: + print(new_content, end="") return return_code diff --git a/src/starkware/cairo/lang/compiler/debug_info.py b/src/starkware/cairo/lang/compiler/debug_info.py index 6a70cdeb..99f92d10 100644 --- a/src/starkware/cairo/lang/compiler/debug_info.py +++ b/src/starkware/cairo/lang/compiler/debug_info.py @@ -72,6 +72,11 @@ def add_autogen_file_contents(self): ) if not is_autogen: continue + + # The following asserts are for mypy. + assert input_file.filename is not None + assert input_file.content is not None + if input_file.filename in self.file_contents: assert self.file_contents[input_file.filename] == input_file.content, ( f'Found two versions of auto-generated file "{input_file.filename}":\n' diff --git a/src/starkware/cairo/lang/compiler/encode.py b/src/starkware/cairo/lang/compiler/encode.py index 76bd79d2..847c565b 100644 --- a/src/starkware/cairo/lang/compiler/encode.py +++ b/src/starkware/cairo/lang/compiler/encode.py @@ -32,7 +32,7 @@ def encode_instruction(element: BytecodeElement, prime: int) -> List[int]: Given an Instruction, returns a list of 1 or 2 integers representing the instruction. """ if isinstance(element, BytecodeData): - return [element.data] + return [element.data % prime] assert isinstance(element, Instruction) inst = element assert prime > 2 ** (3 * OFFSET_BITS + 16) diff --git a/src/starkware/cairo/lang/compiler/encode_test.py b/src/starkware/cairo/lang/compiler/encode_test.py index 48d2ced9..235de503 100644 --- a/src/starkware/cairo/lang/compiler/encode_test.py +++ b/src/starkware/cairo/lang/compiler/encode_test.py @@ -1,5 +1,9 @@ import dataclasses +import pytest + +from starkware.cairo.lang.compiler.ast.expr import ExprConst +from starkware.cairo.lang.compiler.ast.instructions import DefineWordInstruction, InstructionAst from starkware.cairo.lang.compiler.encode import ( decode_instruction, encode_instruction, @@ -154,3 +158,17 @@ def test_addap(): assert encode_instruction(instruction, prime=PRIME) == encoded assert decode_instruction(*encoded) == instruction assert is_call_instruction(*encoded) is False + + +@pytest.mark.parametrize("value", [-2, 2 * PRIME + 3]) +def test_out_of_range_dw(value): + """ + Tests that encode_instruction handles out of range words correctly. + """ + # Build the instruction explicitly as parse_instruction might return an instruction + # that needs simplification before encoding. + instruction = InstructionAst( + body=DefineWordInstruction(expr=ExprConst(val=value)), + inc_ap=False, + ) + assert encode_instruction(build_instruction(instruction), prime=PRIME) == [value % PRIME] diff --git a/src/starkware/cairo/lang/compiler/expression_evaluator.py b/src/starkware/cairo/lang/compiler/expression_evaluator.py index 2bc477b5..66b19d08 100644 --- a/src/starkware/cairo/lang/compiler/expression_evaluator.py +++ b/src/starkware/cairo/lang/compiler/expression_evaluator.py @@ -1,27 +1,29 @@ -from typing import MutableMapping, Optional +from typing import Generic, Mapping, Optional, TypeVar, cast from starkware.cairo.lang.compiler.ast.cairo_types import TypeFelt, TypePointer from starkware.cairo.lang.compiler.ast.expr import ExprConst, ExprDeref, Expression, ExprReg -from starkware.cairo.lang.compiler.error_handling import LocationError +from starkware.cairo.lang.compiler.error_handling import Location, LocationError from starkware.cairo.lang.compiler.expression_simplifier import ExpressionSimplifier from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager from starkware.cairo.lang.compiler.instruction import Register from starkware.cairo.lang.compiler.type_system_visitor import simplify_type_system +T = TypeVar("T") + class ExpressionEvaluatorError(LocationError): pass -class ExpressionEvaluator(ExpressionSimplifier): +class ExpressionEvaluator(Generic[T], ExpressionSimplifier): prime: int def __init__( self, prime: int, - ap: Optional[int], - fp: int, - memory: MutableMapping[int, int], + ap: Optional[T], + fp: T, + memory: Mapping[T, T], identifiers: Optional[IdentifierManager] = None, ): super().__init__(prime=prime) @@ -31,7 +33,7 @@ def __init__( self.memory = memory self.identifiers = identifiers - def eval(self, expr: Expression) -> int: + def eval(self, expr: Expression) -> T: expr, expr_type = simplify_type_system(expr, identifiers=self.identifiers) assert isinstance( expr_type, (TypeFelt, TypePointer) @@ -39,14 +41,14 @@ def eval(self, expr: Expression) -> int: res = self.visit(expr) assert isinstance(res, ExprConst), f"Unable to evaluate expression '{expr.format()}'." assert self.prime is not None - return res.val % self.prime + return cast(T, res.val % self.prime) def visit_ExprReg(self, expr: ExprReg) -> ExprConst: if expr.reg is Register.AP: assert self.ap is not None, "Cannot substitute ap in the expression." - return ExprConst(val=self.ap, location=expr.location) + return self.to_expr_const(val=self.ap, location=expr.location) elif expr.reg is Register.FP: - return ExprConst(val=self.fp, location=expr.location) + return self.to_expr_const(val=self.fp, location=expr.location) else: raise NotImplementedError(f"Register of type {expr.reg} is not supported") @@ -56,6 +58,11 @@ def visit_ExprDeref(self, expr: ExprDeref) -> Expression: return expr assert self.prime is not None try: - return ExprConst(val=self.memory[addr.val % self.prime], location=expr.location) + return self.to_expr_const( + val=self.memory[cast(T, addr.val % self.prime)], location=expr.location + ) except Exception as exc: raise ExpressionEvaluatorError(str(exc), location=expr.location) + + def to_expr_const(self, val: T, location: Optional[Location]) -> ExprConst: + return ExprConst(val=cast(int, val), location=location) diff --git a/src/starkware/cairo/lang/compiler/expression_evaluator_test.py b/src/starkware/cairo/lang/compiler/expression_evaluator_test.py index c6a5ff4e..3e8bd5a1 100644 --- a/src/starkware/cairo/lang/compiler/expression_evaluator_test.py +++ b/src/starkware/cairo/lang/compiler/expression_evaluator_test.py @@ -7,7 +7,7 @@ def test_eval_registers(): fp = 10 prime = 13 - evaluator = ExpressionEvaluator(prime=prime, ap=ap, fp=fp, memory={}) + evaluator = ExpressionEvaluator[int](prime=prime, ap=ap, fp=fp, memory={}) assert evaluator.eval(parse_expr("2 * ap + 3 * fp - 5")) == (2 * ap + 3 * fp - 5) % prime @@ -16,7 +16,7 @@ def test_eval_with_types(): fp = 10 prime = 13 - evaluator = ExpressionEvaluator(prime=prime, ap=ap, fp=fp, memory={}) + evaluator = ExpressionEvaluator[int](prime=prime, ap=ap, fp=fp, memory={}) assert evaluator.eval(parse_expr("cast(ap, T*)")) == ap @@ -26,7 +26,7 @@ def test_eval_registers_and_memory(): prime = 13 memory = {(2 * ap + 3 * fp - 5) % prime: 7, 7: 5, 6: 0} - evaluator = ExpressionEvaluator(prime=prime, ap=ap, fp=fp, memory=memory) + evaluator = ExpressionEvaluator[int](prime=prime, ap=ap, fp=fp, memory=memory) assert evaluator.eval(parse_expr("[2 * ap + 3 * fp - 5]")) == 7 assert ( evaluator.eval(parse_expr("[[2 * ap + 3 * fp - 5]] + 3 * ap")) diff --git a/src/starkware/cairo/lang/compiler/expression_transformer.py b/src/starkware/cairo/lang/compiler/expression_transformer.py index b3db7a2c..e3c74599 100644 --- a/src/starkware/cairo/lang/compiler/expression_transformer.py +++ b/src/starkware/cairo/lang/compiler/expression_transformer.py @@ -13,6 +13,7 @@ ExprHint, ExprIdentifier, ExprNeg, + ExprNewOperator, ExprOperator, ExprParentheses, ExprPow, @@ -62,7 +63,11 @@ def visit_ExprIdentifier(self, expr: ExprIdentifier): return ExprIdentifier(name=expr.name, location=self.location_modifier(expr.location)) def visit_ExprFutureLabel(self, expr: ExprFutureLabel): - return ExprFutureLabel(self.visit(expr.identifier)) + return ExprFutureLabel( + identifier=self.visit(expr.identifier), + is_typed=expr.is_typed, + location=self.location_modifier(expr.location), + ) def visit_ExprReg(self, expr: ExprReg): return ExprReg(reg=expr.reg, location=self.location_modifier(expr.location)) @@ -158,6 +163,13 @@ def visit_ExprFuncCall(self, expr: ExprFuncCall): location=self.location_modifier(expr.location), ) + def visit_ExprNewOperator(self, expr: ExprNewOperator): + return ExprNewOperator( + expr=self.visit(expr.expr), + is_typed=expr.is_typed, + location=self.location_modifier(expr.location), + ) + def location_modifier(self, location: Optional[Location]) -> Optional[Location]: """ This function can be overridden by subclasses to modify location information. diff --git a/src/starkware/cairo/lang/compiler/identifier_definition.py b/src/starkware/cairo/lang/compiler/identifier_definition.py index 5fc6f3a8..bc51d2ac 100644 --- a/src/starkware/cairo/lang/compiler/identifier_definition.py +++ b/src/starkware/cairo/lang/compiler/identifier_definition.py @@ -96,6 +96,19 @@ def sort_members(self, item, many, **kwargs): return item +@marshmallow_dataclass.dataclass +class TypeDefinition(IdentifierDefinition): + """ + Represents a type alias for another type. + """ + + TYPE: ClassVar[str] = "type_definition" + Schema: ClassVar[Type[marshmallow.Schema]] = marshmallow.Schema + + cairo_type: CairoType = field(metadata=dict(marshmallow_field=CairoTypeAsStr(required=True))) + location: Optional[Location] = LocationField + + @marshmallow_dataclass.dataclass class LabelDefinition(IdentifierDefinition): TYPE: ClassVar[str] = "label" @@ -157,6 +170,7 @@ class IdentifierDefinitionSchema(OneOfSchema): ReferenceDefinition.TYPE: ReferenceDefinition.Schema, ScopeDefinition.TYPE: ScopeDefinition.Schema, StructDefinition.TYPE: StructDefinition.Schema, + TypeDefinition.TYPE: TypeDefinition.Schema, } def get_obj_type(self, obj): diff --git a/src/starkware/cairo/lang/compiler/identifier_manager.py b/src/starkware/cairo/lang/compiler/identifier_manager.py index 9900343b..1ba2fa89 100644 --- a/src/starkware/cairo/lang/compiler/identifier_manager.py +++ b/src/starkware/cairo/lang/compiler/identifier_manager.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Dict, List, MutableMapping, Optional, Set, Union +from typing import Dict, List, Mapping, MutableMapping, Optional, Set, Union from starkware.cairo.lang.compiler.identifier_definition import ( AliasDefinition, @@ -76,7 +76,7 @@ class IdentifierManager: def __init__(self): self.root = IdentifierScope(self, ScopedName()) - self.dict = {} + self.dict: MutableMapping[ScopedName, IdentifierDefinition] = {} def add_identifier(self, name: ScopedName, definition: IdentifierDefinition): """ @@ -87,7 +87,7 @@ def add_identifier(self, name: ScopedName, definition: IdentifierDefinition): @classmethod def from_dict( - cls, identifier_dict: Dict[ScopedName, IdentifierDefinition] + cls, identifier_dict: Mapping[ScopedName, IdentifierDefinition] ) -> "IdentifierManager": identifier_manager = cls() for name, identifier_definition in identifier_dict.items(): diff --git a/src/starkware/cairo/lang/compiler/import_loader_test.py b/src/starkware/cairo/lang/compiler/import_loader_test.py index a936bf6c..c842fb8e 100644 --- a/src/starkware/cairo/lang/compiler/import_loader_test.py +++ b/src/starkware/cairo/lang/compiler/import_loader_test.py @@ -1,6 +1,6 @@ import re from random import sample -from typing import Dict +from typing import Dict, List import pytest @@ -47,6 +47,7 @@ def test_unreachabale_file(): # Failed to parse internal module. with pytest.raises(ImportLoaderError) as e: collect_imports("root.file", read_file_from_dict(files)) + assert e.value.location is not None assert f""" {get_location_marks(files['root.file'], e.value.location)} {e.value.message} @@ -137,7 +138,7 @@ def test_topologycal_order(): # Initialize the dependencies DAG. A list of int lists. # j is in the i-th list iff i->j in the dependencies DAG. - dependencies = [[]] * N_VERTICES + dependencies: List[List[int]] = [[] for _ in range(N_VERTICES)] for i in range(N_VERTICES - N_NEIGHBORS): dependencies[i] = sample(range(i + 1, N_VERTICES), N_NEIGHBORS) diff --git a/src/starkware/cairo/lang/compiler/injector_test.py b/src/starkware/cairo/lang/compiler/injector_test.py index 1362a23e..73be0f6c 100644 --- a/src/starkware/cairo/lang/compiler/injector_test.py +++ b/src/starkware/cairo/lang/compiler/injector_test.py @@ -1,4 +1,9 @@ -from starkware.cairo.lang.compiler.ast.code_elements import CodeBlock, CodeElementScoped +from starkware.cairo.lang.compiler.ast.code_elements import ( + CodeBlock, + CodeElementFunction, + CodeElementIf, + CodeElementScoped, +) from starkware.cairo.lang.compiler.injector import inject_code_elements from starkware.cairo.lang.compiler.parser import parse from starkware.cairo.lang.compiler.scoped_name import ScopedName @@ -35,18 +40,28 @@ def test_injector(): expected_type=CodeBlock, ).code_elements[0] + if_block = block.code_elements[2].code_elm + assert isinstance(if_block, CodeElementIf) + + foo_func = block.code_elements[4].code_elm + assert isinstance(foo_func, CodeElementFunction) + injections = { id(block.code_elements[0].code_elm): [injected0, injected1], id(block.code_elements[1].code_elm): [], id(block.code_elements[2].code_elm): [injected0], - id(block.code_elements[2].code_elm.main_code_block.code_elements[0].code_elm): [injected1], - id(block.code_elements[4].code_elm.code_block.code_elements[0].code_elm): [injected0], + id(if_block.main_code_block.code_elements[0].code_elm): [injected1], + id(foo_func.code_block.code_elements[0].code_elm): [injected0], } + # We pass wrong types below (e.g., CodeBlock instead of CodeElement) to simplify the test. scoped_elm = CodeElementScoped( scope=ScopedName.from_string("main"), - code_elements=[block], + code_elements=[block], # type: ignore ) - block = inject_code_elements(ast=scoped_elm, injections=injections).code_elements[0] + block = inject_code_elements( # type: ignore + ast=scoped_elm, # type: ignore + injections=injections, + ).code_elements[0] assert ( block.format(allowed_line_length=100) == """\ diff --git a/src/starkware/cairo/lang/compiler/instruction_builder_test.py b/src/starkware/cairo/lang/compiler/instruction_builder_test.py index 8ed6e202..9038807b 100644 --- a/src/starkware/cairo/lang/compiler/instruction_builder_test.py +++ b/src/starkware/cairo/lang/compiler/instruction_builder_test.py @@ -596,6 +596,7 @@ def verify_exception(code_with_err): code = code_with_err.splitlines()[0] with pytest.raises(InstructionBuilderError) as e: parse_and_build(code) + assert e.value.location is not None assert ( get_location_marks(code, e.value.location) + "\n" + str(e.value.message) == code_with_err.rstrip() diff --git a/src/starkware/cairo/lang/compiler/lib/registers.cairo b/src/starkware/cairo/lang/compiler/lib/registers.cairo new file mode 100644 index 00000000..a835accf --- /dev/null +++ b/src/starkware/cairo/lang/compiler/lib/registers.cairo @@ -0,0 +1,17 @@ +# Returns the contents of the fp and pc registers of the calling function. +# The pc register's value is the address of the instruction that follows directly after the +# invocation of get_fp_and_pc(). +func get_fp_and_pc() -> (fp_val, pc_val): + # The call instruction itself already places the old fp and the return pc at [ap - 2], [ap - 1]. + return (fp_val=[ap - 2], pc_val=[ap - 1]) +end + +# Returns the content of the ap register just before this function was invoked. +@known_ap_change +func get_ap() -> (ap_val): + # Once get_ap() is invoked, fp points to ap + 2 (since the call instruction placed the old fp + # and pc in memory, advancing ap accordingly). + # Hence, the desired ap value is fp - 2. + let (fp_val, pc_val) = get_fp_and_pc() + return (ap_val=fp_val - 2) +end diff --git a/src/starkware/cairo/lang/compiler/module_reader.py b/src/starkware/cairo/lang/compiler/module_reader.py index 6ec6befe..cca300da 100644 --- a/src/starkware/cairo/lang/compiler/module_reader.py +++ b/src/starkware/cairo/lang/compiler/module_reader.py @@ -19,7 +19,7 @@ class ModuleReader: def __init__(self, paths: List[str], cairo_suffix: str): self.paths: List[str] = paths self.cairo_suffix: str = cairo_suffix - self.source_files_with_scopes: Set[Tuple[str, str]] = set() + self.source_files_with_scopes: Set[Tuple[str, ScopedName]] = set() @property def source_files(self): diff --git a/src/starkware/cairo/lang/compiler/parser.py b/src/starkware/cairo/lang/compiler/parser.py index 4755a6a7..d7986ffb 100644 --- a/src/starkware/cairo/lang/compiler/parser.py +++ b/src/starkware/cairo/lang/compiler/parser.py @@ -7,7 +7,7 @@ from starkware.cairo.lang.compiler.ast.cairo_types import CairoType from starkware.cairo.lang.compiler.ast.code_elements import CodeBlock, CodeElement -from starkware.cairo.lang.compiler.ast.expr import Expression +from starkware.cairo.lang.compiler.ast.expr import ExprConst, Expression from starkware.cairo.lang.compiler.ast.instructions import InstructionAst from starkware.cairo.lang.compiler.ast.module import CairoFile from starkware.cairo.lang.compiler.error_handling import InputFile, Location, LocationError @@ -69,6 +69,7 @@ def wrap_lark_error(err: LarkError, input_file: InputFile) -> Exception: "LPAR", "LSQB", "NONDET", + "NEW", "PYCONST", "SHORT_STRING", "register", @@ -122,6 +123,7 @@ def wrap_lark_error(err: LarkError, input_file: InputFile) -> Exception: "MEMBER": '"member"', "MINUS": '"-"', "NAMESPACE": '"namespace"', + "NEW": '"new"', "PLUS": '"+"', "RBRACE": '"}"', "RET": '"ret"', @@ -254,6 +256,16 @@ def parse_expr(code: str) -> Expression: return parse(None, code, "expr", Expression) +def parse_const(code: str) -> ExprConst: + """ + Parses the given string and returns an ExprConst instance. + """ + # Use parse_expr to share the lru cache. + expr_const = parse_expr(code=code) + assert isinstance(expr_const, ExprConst) + return expr_const + + def parse_type(code: str) -> CairoType: """ Parses the given string and returns an Expression instance. diff --git a/src/starkware/cairo/lang/compiler/parser_errors_test.py b/src/starkware/cairo/lang/compiler/parser_errors_test.py index 4ad20441..be8a05bb 100644 --- a/src/starkware/cairo/lang/compiler/parser_errors_test.py +++ b/src/starkware/cairo/lang/compiler/parser_errors_test.py @@ -175,3 +175,60 @@ def test_parser_error(): ^ """, ) + + +def test_new_operator_error(): + verify_exception( + """ +let a = new +""", + """ +file:?:?: Unexpected token Token('_NEWLINE', '\\n'). Expected: expression. +let a = new + ^ +""", + ) + + verify_exception( + """ +new = 5 +""", + """ +file:?:?: Unexpected token Token('EQUAL', '='). Expected: expression. +new = 5 + ^ +""", + ) + + verify_exception( + """ +new A() +""", + """ +file:?:?: Unexpected token Token('_NEWLINE', '\\n'). Expected one of: ".", "=", "[", operator. +new A() + ^ +""", + ) + + verify_exception( + """ +new A().f +""", + """ +file:?:?: Unexpected token Token('_NEWLINE', '\\n'). Expected one of: ".", "=", "[", operator. +new A().f + ^ +""", + ) + + verify_exception( + """ +new A() new +""", + """ +file:?:?: Unexpected token Token('NEW', 'new'). Expected one of: ".", "=", "[", operator. +new A() new + ^*^ +""", + ) diff --git a/src/starkware/cairo/lang/compiler/parser_test.py b/src/starkware/cairo/lang/compiler/parser_test.py index b88fd04e..229c025b 100644 --- a/src/starkware/cairo/lang/compiler/parser_test.py +++ b/src/starkware/cairo/lang/compiler/parser_test.py @@ -1,7 +1,15 @@ +from typing import List + import pytest from starkware.cairo.lang.compiler.ast.aliased_identifier import AliasedIdentifier -from starkware.cairo.lang.compiler.ast.cairo_types import TypeFelt, TypeTuple +from starkware.cairo.lang.compiler.ast.cairo_types import ( + CairoType, + TypeCodeoffset, + TypeFelt, + TypePointer, + TypeTuple, +) from starkware.cairo.lang.compiler.ast.code_elements import ( CodeElementImport, CodeElementReference, @@ -31,12 +39,13 @@ RetInstruction, ) from starkware.cairo.lang.compiler.ast.types import TypedIdentifier -from starkware.cairo.lang.compiler.error_handling import LocationError, get_location_marks +from starkware.cairo.lang.compiler.error_handling import Location, LocationError, get_location_marks from starkware.cairo.lang.compiler.expression_simplifier import ExpressionSimplifier from starkware.cairo.lang.compiler.instruction import Register from starkware.cairo.lang.compiler.parser import ( parse, parse_code_element, + parse_const, parse_expr, parse_instruction, parse_type, @@ -47,13 +56,15 @@ def test_int(): - expr = parse_expr(" 01234 ") - assert expr == ExprConst(val=1234) - assert expr.format_str == "01234" - assert expr.format() == "01234" + expr_const = parse_const(" 01234 ") + assert expr_const == ExprConst(val=1234) + assert expr_const.format_str == "01234" + assert expr_const.format() == "01234" expr = parse_expr("-01234") + assert isinstance(expr, ExprNeg) assert expr == ExprNeg(val=ExprConst(val=1234)) + assert isinstance(expr.val, ExprConst) assert expr.val.format_str == "01234" assert expr.format() == "-01234" @@ -61,13 +72,15 @@ def test_int(): def test_hex_int(): - expr = parse_expr(" 0x1234 ") - assert expr == ExprConst(val=0x1234) - assert expr.format_str == "0x1234" - assert expr.format() == "0x1234" + expr_const = parse_const(" 0x1234 ") + assert expr_const == ExprConst(val=0x1234) + assert expr_const.format_str == "0x1234" + assert expr_const.format() == "0x1234" expr = parse_expr("-0x01234") + assert isinstance(expr, ExprNeg) assert expr == ExprNeg(val=ExprConst(val=0x1234)) + assert isinstance(expr.val, ExprConst) assert expr.val.format_str == "0x01234" assert expr.format() == "-0x01234" @@ -75,18 +88,25 @@ def test_hex_int(): def test_short_string(): - expr = parse_expr(" 'ab' ") - assert expr == ExprConst(val=ord("a") * 256 + ord("b")) - assert expr.format_str == "'ab'" - assert expr.format() == "'ab'" + expr_const = parse_const(" 'ab' ") + assert expr_const == ExprConst(val=ord("a") * 256 + ord("b")) + assert expr_const.format_str == "'ab'" + assert expr_const.format() == "'ab'" expr = parse_expr("-'abc'") + assert isinstance(expr, ExprNeg) assert expr == ExprNeg(val=ExprConst(val=int.from_bytes(b"abc", "big"))) + assert isinstance(expr.val, ExprConst) assert expr.val.format_str == "'abc'" assert expr.format() == "-'abc'" assert parse_expr("-'abcd'") == parse_expr("- 'abcd'") + assert parse_const(r"'a\x00\x12b'").val == int.from_bytes([ord("a"), 0, 18, ord("b")], "big") + assert parse_const(r"'\x00\x1'").val == int.from_bytes( + [0, ord("\\"), ord("x"), ord("1")], "big" + ) + verify_exception( """let x = '0123456789012345678901234567890123456789'""", """ @@ -107,17 +127,71 @@ def test_short_string(): def test_types(): assert isinstance(parse_type("felt"), TypeFelt) + assert parse_type("felt").format() == "felt" + assert isinstance(parse_type("codeoffset"), TypeCodeoffset) + assert parse_type("codeoffset").format() == "codeoffset" assert parse_type("my_namespace.MyStruct * *").format() == "my_namespace.MyStruct**" assert parse_type("my_namespace.MyStruct*****").format() == "my_namespace.MyStruct*****" def test_type_tuple(): typ = parse_type("(felt)") - assert typ == TypeTuple(members=[TypeFelt()]) + assert typ == TypeTuple.unnamed([TypeFelt()]) assert typ.format() == "(felt)" assert parse_type("( felt, felt* , (felt, T.S,)* )").format() == "(felt, felt*, (felt, T.S)*)" +def test_type_named_tuple(): + typ = parse_type("(a:felt, b: felt )") + assert isinstance(typ, TypeTuple) + assert typ.members[0].name == "a" + assert typ.members[1].name == "b" + assert typ.format() == "(a : felt, b : felt)" + assert ( + parse_type("( a:(felt, felt*) , b:(c:felt,)* )").format() + == "(a : (felt, felt*), b : (c : felt)*)" + ) + + # With new lines. + assert parse_type("(\n a:felt\n\n, \n b: felt \n )").format() == "(a : felt, b : felt)" + + # With comments (parsing, but no auto-formatting). + typ2 = parse_type("( # This is a comment.\na:felt\n\n, \n b: felt \n )") + assert typ2 == typ + with pytest.raises( + FormattingError, + match="Comments inside expressions are not supported by the auto-formatter.", + ): + typ2.format() + + +def test_type_named_tuple_failure(): + verify_exception( + "local x : (a* : felt)", + """ +file:?:?: Unexpected token Token('COLON', ':'). Expected one of: ")", "*", ",". +local x : (a* : felt) + ^ +""", + ) + verify_exception( + "local x : (a.b : felt)", + """ +file:?:?: Unexpected '.' in name. +local x : (a.b : felt) + ^*^ +""", + ) + verify_exception( + "local x : (a, b : felt)", + """ +file:?:?: All fields in a named tuple must have a name. +local x : (a, b : felt) + ^***********^ +""", + ) + + def test_identifier_and_dot(): assert parse_expr("x.y . z + x ").format() == "x.y.z + x" assert parse_expr(" [x]. y . z").format() == "[x].y.z" @@ -245,7 +319,7 @@ def test_tuple_expr(): assert parse_expr("( 2,)").format() == "(2,)" assert parse_expr("( 1 , ap)").format() == "(1, ap)" assert parse_expr("( 1 , ap, )").format() == "(1, ap,)" - assert parse_expr("( 1 , a=2, b=(c=()))").format() == "(1, a=2, b=(c=()))" + assert parse_expr("( a=1 , b=2, c=(d= ()))").format() == "(a=1, b=2, c=(d=()))" verify_exception( "let x = (,)", @@ -277,6 +351,14 @@ def test_tuple_expr(): file:?:?: Expected a comma before this expression. (b)) ^*^ +""", + ) + verify_exception( + "let x = (1 , a=2, b=(c=()))", + """ +file:?:?: All fields in a named tuple must have a name. +let x = (1 , a=2, b=(c=())) + ^*****************^ """, ) @@ -616,6 +698,14 @@ def test_return_value_reference(): parse_expr("let z = call x; ap++") +def test_new_operator(): + expr = parse_expr("new Struct(a = 1, b= 2 )") + assert expr.format() == "new Struct(a=1, b=2)" + + res = parse_code_element("new ( a = 1, b= 2 ) + 5 = 17 + new 7 + 2") + assert res.format(allowed_line_length=100) == "(new (a=1, b=2)) + 5 = 17 + (new 7) + 2" + + def test_return(): res = parse_code_element("return( 1, \na= 2 )") assert res.format(allowed_line_length=100) == "return (1, a=2)" @@ -759,11 +849,15 @@ def test_func_expr(): def test_parent_location(): - parent_location = (parse_expr("1 + 2").location, "An error ocurred while processing:") + location = parse_expr("1 + 2").location + assert location is not None + parent_location = (location, "An error ocurred while processing:") - location = parse_code_element( + code_element = parse_code_element( "let x = 3 + 4", parser_context=ParserContext(parent_location=parent_location) - ).expr.location + ) + assert isinstance(code_element, CodeElementReference) + location = code_element.expr.location location_err = LocationError(message="Error", location=location) assert ( str(location_err) @@ -793,19 +887,11 @@ def test_locations(): lines = code_with_marks.splitlines() code, marks = lines[0], lines[1:] - expr = parse_instruction(code) - exprs = [ - expr, - expr.body, - expr.body.a, - expr.body.a.addr, - expr.body.b, - expr.body.b.addr, - expr.body.b.addr.a, - expr.body.b.addr.b, - ] - for expr, mark in safe_zip(exprs, marks): - assert get_location_marks(code, expr.location) == code + "\n" + mark + expr = parse_instruction(code=code) + for inner_expr, mark in safe_zip(expr.get_subtree(), marks): + location = getattr(inner_expr, "location") + assert isinstance(location, Location) + assert get_location_marks(content=code, location=location) == code + "\n" + mark def test_pointer(): @@ -822,10 +908,13 @@ def test_pointer(): lines = code_with_marks.splitlines() code, marks = lines[0], lines[1:] typ = parse_type(code) - exprs = [ + types: List[CairoType] = [ typ, ] - for i in range(len(marks) - 1): - exprs.append(exprs[-1].pointee) - for expr, mark in safe_zip(exprs, marks): - assert get_location_marks(code, expr.location) == code + "\n" + mark + for _ in range(len(marks) - 1): + typ = types[-1] + assert isinstance(typ, TypePointer) + types.append(typ.pointee) + for typ, mark in safe_zip(types, marks): + assert typ.location is not None + assert get_location_marks(code, typ.location) == code + "\n" + mark diff --git a/src/starkware/cairo/lang/compiler/parser_transformer.py b/src/starkware/cairo/lang/compiler/parser_transformer.py index 5e6e200e..8a2afc89 100644 --- a/src/starkware/cairo/lang/compiler/parser_transformer.py +++ b/src/starkware/cairo/lang/compiler/parser_transformer.py @@ -1,5 +1,6 @@ import dataclasses -from typing import List, Optional +import re +from typing import List, Optional, Tuple from lark import Transformer, v_args @@ -7,6 +8,8 @@ from starkware.cairo.lang.compiler.ast.arguments import IdentifierList from starkware.cairo.lang.compiler.ast.bool_expr import BoolExpr from starkware.cairo.lang.compiler.ast.cairo_types import ( + CairoType, + TypeCodeoffset, TypeFelt, TypePointer, TypeStruct, @@ -35,6 +38,7 @@ CodeElementStaticAssert, CodeElementTailCall, CodeElementTemporaryVariable, + CodeElementTypeDef, CodeElementUnpackBinding, CodeElementWith, CodeElementWithAttr, @@ -52,6 +56,7 @@ ExprHint, ExprIdentifier, ExprNeg, + ExprNewOperator, ExprOperator, ExprParentheses, ExprPow, @@ -116,6 +121,17 @@ class Comma: location: Optional[Location] +@dataclasses.dataclass +class CommaSeparatedWithNotes: + """ + Represents a list of comma separated values, such as expressions or types. + """ + + args: list + notes: List[Notes] + has_trailing_comma: bool + + class ParserTransformer(Transformer): """ Transforms the lark tree into an AST based on the classes defined in ast/*.py. @@ -128,12 +144,79 @@ def __init__(self, input_file: InputFile, parser_context: Optional[ParserContext def __default__(self, data: str, children, meta): raise TypeError(f"Unable to parse tree node of type {data}") + # Comma separated list with notes. + + @v_args(meta=True) + def comma(self, value, meta): + return Comma(location=self.meta2loc(meta)) + + def comma_separated_with_notes(self, value) -> CommaSeparatedWithNotes: + saw_comma = True + all_notes: List[Notes] = [] + current_notes: List[Notes] = [] + args: list = [] + for v in value: + if isinstance(v, Notes): + # Join the notes before and after the comma. + current_notes.append(v) + elif isinstance(v, Comma): + if saw_comma: + raise ParserError("Unexpected comma.", location=v.location) + saw_comma = True + else: + if not saw_comma: + raise ParserError( + "Expected a comma before this expression.", location=v.location + ) + all_notes.append(Notes.merge(current_notes)) + args.append(v) + + # Reset state. + saw_comma = False + current_notes = [] + + all_notes.append(Notes.merge(current_notes)) + + return CommaSeparatedWithNotes( + args=args, + notes=all_notes, + has_trailing_comma=saw_comma, + ) + # Types. + @v_args(meta=True) + def named_type(self, value, meta) -> TypeTuple.Item: + name: Optional[str] + if len(value) == 1: + # Unnamed type. + (typ,) = value + name = None + if isinstance(typ, ExprIdentifier): + typ = self.type_struct([typ]) + elif len(value) == 2: + # Named type. + identifier, typ = value + assert isinstance(identifier, ExprIdentifier) + assert isinstance(typ, CairoType) + if ScopedName.SEPARATOR in identifier.name: + raise ParserError( + f"Unexpected '{ScopedName.SEPARATOR}' in name.", location=identifier.location + ) + name = identifier.name + else: + raise NotImplementedError(f"Unexpected number of values. {value}") + + return TypeTuple.Item(name=name, typ=typ, location=self.meta2loc(meta)) + @v_args(meta=True) def type_felt(self, value, meta): return TypeFelt(location=self.meta2loc(meta)) + @v_args(meta=True) + def type_codeoffset(self, value, meta): + return TypeCodeoffset(location=self.meta2loc(meta)) + def type_struct(self, value): assert len(value) == 1 and isinstance(value[0], ExprIdentifier) return TypeStruct( @@ -155,48 +238,28 @@ def type_pointer2(self, value, meta): ) @v_args(meta=True) - def type_tuple(self, value, meta): - return TypeTuple(members=value, location=self.meta2loc(meta)) - - @v_args(meta=True) - def comma(self, value, meta): - return Comma(location=self.meta2loc(meta)) + def type_tuple(self, value: Tuple[CommaSeparatedWithNotes], meta): + (lst,) = value + is_named = set((member.name is not None) for member in lst.args) + if is_named == {True, False}: + raise ParserError( + "All fields in a named tuple must have a name.", location=self.meta2loc(meta) + ) + return TypeTuple( + members=lst.args, + notes=lst.notes, + has_trailing_comma=lst.has_trailing_comma, + location=self.meta2loc(meta), + ) # Expression. @v_args(meta=True) - def arg_list(self, value, meta): - saw_comma = True - all_notes: List[Notes] = [] - current_notes: List[Notes] = [] - args: List[ExprAssignment] = [] - for v in value: - if isinstance(v, ExprAssignment): - if not saw_comma: - raise ParserError( - "Expected a comma before this expression.", location=v.location - ) - all_notes.append(Notes.merge(current_notes)) - args.append(v) - - # Reset state. - saw_comma = False - current_notes = [] - elif isinstance(v, Notes): - # Join the notes before and after the comma. - current_notes.append(v) - elif isinstance(v, Comma): - if saw_comma: - raise ParserError("Unexpected comma.", location=v.location) - saw_comma = True - else: - raise NotImplementedError(f"Unexpected parser item {type(v).__name__}") - - all_notes.append(Notes.merge(current_notes)) - + def arg_list(self, value: Tuple[CommaSeparatedWithNotes], meta): + (lst,) = value return ArgList( - args=args, - notes=all_notes, - has_trailing_comma=saw_comma, + args=lst.args, + notes=lst.notes, + has_trailing_comma=lst.has_trailing_comma, location=self.meta2loc(meta), ) @@ -245,6 +308,8 @@ def atom_short_string(self, value, meta): text_bytes = text.encode("ascii") except UnicodeEncodeError: raise ParserError(f"Expected an ascii string. Found: {repr(text)}.", location=location) + + text_bytes = backslash_to_hex(text_bytes) return ExprConst( val=int.from_bytes(text_bytes, "big"), format_str=token_text, @@ -295,6 +360,10 @@ def unary_addressof(self, value, meta): def unary_neg(self, value, meta): return ExprNeg(val=value[0], location=self.meta2loc(meta)) + @v_args(meta=True) + def unary_new_operator(self, value, meta): + return ExprNewOperator(expr=value[0], is_typed=True, location=self.meta2loc(meta)) + @v_args(meta=True) def expr_pow(self, value, meta): return ExprPow(a=value[0], b=value[2], notes=value[1], location=self.meta2loc(meta)) @@ -332,6 +401,12 @@ def atom_tuple_or_parentheses(self, value, meta): val=args[0].expr, notes=arg_list.notes[0], location=arg_list.location ) + is_named = set((member.identifier is not None) for member in args) + if is_named == {True, False}: + raise ParserError( + "All fields in a named tuple must have a name.", location=self.meta2loc(meta) + ) + return ExprTuple(members=arg_list, location=self.meta2loc(meta)) # Register. @@ -541,8 +616,10 @@ def code_element_func_call(self, value): def code_element_label(self, value): identifier = value[0] - if "." in identifier.name: - raise ParserError("Unexpected '.' in label name.", location=identifier.location) + if ScopedName.SEPARATOR in identifier.name: + raise ParserError( + f"Unexpected '{ScopedName.SEPARATOR}' in label name.", location=identifier.location + ) return CodeElementLabel(identifier=identifier) @v_args(meta=True) @@ -621,6 +698,12 @@ def code_element_struct(self, value): decorators=decorators, ) + @v_args(meta=True) + def code_element_typedef(self, value, meta): + return CodeElementTypeDef( + identifier=value[0], cairo_type=value[1], location=self.meta2loc(meta) + ) + def code_element_with(self, value): assert len(value) > 1 return CodeElementWith( @@ -769,3 +852,12 @@ def meta2loc(self, meta): input_file=self.input_file, parent_location=self.parser_context.parent_location, ) + + +def backslash_to_hex(value: bytes) -> bytes: + r""" + Replaces substrings of the form '\x**' with the corresponding byte. + """ + pattern = br"\\x([0-9a-fA-F]{2})" + replacer = lambda m: bytes.fromhex(m.group(1).decode("ascii")) + return re.sub(pattern, replacer, value) diff --git a/src/starkware/cairo/lang/compiler/preprocessor/compound_expressions.py b/src/starkware/cairo/lang/compiler/preprocessor/compound_expressions.py index 43e78d9f..c75da62b 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/compound_expressions.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/compound_expressions.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from enum import Enum, auto -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Union from starkware.cairo.lang.compiler.ast.cairo_types import TypeFelt from starkware.cairo.lang.compiler.ast.code_elements import ( @@ -15,17 +15,20 @@ ExprHint, ExprIdentifier, ExprNeg, + ExprNewOperator, ExprOperator, ExprReg, ) from starkware.cairo.lang.compiler.ast.types import TypedIdentifier from starkware.cairo.lang.compiler.error_handling import Location +from starkware.cairo.lang.compiler.expression_simplifier import ExpressionSimplifier from starkware.cairo.lang.compiler.instruction import Register from starkware.cairo.lang.compiler.instruction_builder import ( InstructionBuilderError, _parse_offset, _parse_register_offset, ) +from starkware.cairo.lang.compiler.preprocessor.flow import RegTrackingData from starkware.cairo.lang.compiler.preprocessor.preprocessor_error import PreprocessorError from starkware.cairo.lang.compiler.references import translate_ap @@ -55,8 +58,20 @@ def get_fp_val(self, location: Optional[Location]) -> Expression: Usually, this should resolve the expression "__fp__". """ + @abstractmethod + def visit(self, elm: CodeElement): + """ + Visits the given code element. + """ -def is_simple_deref(expr: Expression): + @abstractmethod + def get_ap_tracking(self) -> RegTrackingData: + """ + Returns the current ap tracking data. + """ + + +def is_simple_deref(expr: Expression) -> bool: """ Returns True if expr is of the form [reg + offset]. """ @@ -81,10 +96,9 @@ def __init__(self, context: CompoundExpressionContext): """ Constructs a CompoundExpressionVisitor. """ - self.code_elements: List[CodeElement] = [] self.context = context - # Number of variables created so far. - self.n_vars: int = 0 + # Ap tracking information at the start of the processing. + self.ap_tracking = context.get_ap_tracking() def rewrite(self, expr: Expression, sim: SimplicityLevel): """ @@ -108,7 +122,12 @@ def rewrite_ExprReg(self, expr: ExprReg, sim: SimplicityLevel): location=expr.location, ) elif expr.reg is Register.FP: - return self.rewrite(expr=self.context.get_fp_val(expr.location), sim=sim) + # Note that self.context.get_fp_val returns the value of the __fp__ reference translated + # according to the change in ap (caused by previous calls to rewrite() and wrap()). + # Since rewrite expects to get the expression untranslated, we call untranslate_ap. + return self.rewrite( + expr=self.untranslate_ap(self.context.get_fp_val(expr.location)), sim=sim + ) else: raise NotImplementedError(f"Unknown register {expr.reg}.") @@ -163,6 +182,9 @@ def rewrite_ExprDeref(self, expr: ExprDeref, sim: SimplicityLevel): return expr if sim is SimplicityLevel.OPERATION else self.wrap(expr) def rewrite_ExprFutureLabel(self, expr: ExprFutureLabel, sim: SimplicityLevel): + assert ( + not expr.is_typed + ), "The CompoundExpressionVisitor expects ExprFutureLabel expressions to be untyped." # Treat this as a constant. if sim in [SimplicityLevel.DEREF_CONST, SimplicityLevel.OPERATION]: return expr @@ -171,13 +193,18 @@ def rewrite_ExprFutureLabel(self, expr: ExprFutureLabel, sim: SimplicityLevel): def rewrite_ExprHint(self, expr: ExprHint, sim: SimplicityLevel): return self.wrap(expr) + def rewrite_ExprNewOperator(self, expr: ExprNewOperator, sim: SimplicityLevel): + assert ( + not expr.is_typed + ), "The CompoundExpressionVisitor expects ExprNewOperator expressions to be untyped." + return self.wrap(expr) + def wrap(self, expr: Expression) -> ExprIdentifier: identifier = ExprIdentifier(name=self.context.new_tempvar_name(), location=expr.location) expr = self.translate_ap(expr) - self.n_vars += 1 - self.code_elements.append( + self.context.visit( CodeElementTemporaryVariable( typed_identifier=TypedIdentifier( identifier=identifier, expr_type=TypeFelt(location=expr.location) @@ -188,15 +215,32 @@ def wrap(self, expr: Expression) -> ExprIdentifier: ) return identifier - def translate_ap(self, expr): - return translate_ap(expr, self.n_vars) + def translate_ap(self, expr: Expression) -> Expression: + """ + Translates ap according to the change in the ap register from the beginning of the use + of the class. + """ + return translate_ap(expr, self.context.get_ap_tracking() - self.ap_tracking) + + def untranslate_ap(self, expr: Expression) -> Expression: + """ + Gets an expression whose ap was translated (according to the change in the ap register + from the beginning of the use of the class) and reverts the translation. + This function is the inverse of translate_ap. + """ + # Use the simplifier to convert (ap + offset_1) + offset_2 to ap + (offset_1 + offset_2), + # since the expressions are assumed to be simplified. + simplifier = ExpressionSimplifier() + return simplifier.visit( + translate_ap(expr, self.ap_tracking - self.context.get_ap_tracking()) + ) def process_compound_expressions( exprs: List[Expression], simplicity: Union[SimplicityLevel, List[SimplicityLevel]], context: CompoundExpressionContext, -) -> Tuple[List[CodeElement], List[Expression]]: +) -> List[Expression]: """ Rewrites the given list of expressions, by adding temporary variables, in the required simiplicity levels. @@ -206,8 +250,7 @@ def process_compound_expressions( 'simplicity' may be one SimplicityLevel for all the expressions or a list of SimplicityLevel for each expression separately. - Returns a list of code elements with the temporary variables and the list of simplified - expressions. + Returns the list of simplified expressions. """ if isinstance(simplicity, SimplicityLevel): simplicity = [simplicity] * len(exprs) @@ -221,12 +264,12 @@ def process_compound_expressions( # Second, translate ap according to the total number of instructions. simplified_exprs = [visitor.translate_ap(expr) for expr in simplified_exprs] - return visitor.code_elements, simplified_exprs + return simplified_exprs def process_compound_assert( expr_a: Expression, expr_b: Expression, context: CompoundExpressionContext -): +) -> List[Expression]: """ A version of process_compound_expressions() for assert instructions. Takes two expressions and returns them simplified to levels [DEREF, OPERATION] or [OPERATION, DEREF], @@ -246,7 +289,6 @@ def process_compound_assert( # Left-hand side is already too complicated for DEREF. simplicity = [SimplicityLevel.OPERATION, SimplicityLevel.DEREF] - code_elements, exprs = process_compound_expressions( + return process_compound_expressions( exprs=[expr_a, expr_b], simplicity=simplicity, context=context ) - return code_elements, exprs diff --git a/src/starkware/cairo/lang/compiler/preprocessor/compound_expressions_test.py b/src/starkware/cairo/lang/compiler/preprocessor/compound_expressions_test.py index 670274f5..3b6237bf 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/compound_expressions_test.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/compound_expressions_test.py @@ -1,8 +1,12 @@ import itertools -from typing import Optional +from typing import List, Optional import pytest +from starkware.cairo.lang.compiler.ast.code_elements import ( + CodeElement, + CodeElementTemporaryVariable, +) from starkware.cairo.lang.compiler.ast.expr import Expression from starkware.cairo.lang.compiler.error_handling import Location from starkware.cairo.lang.compiler.parser import parse_expr @@ -18,11 +22,14 @@ strip_comments_and_linebreaks, verify_exception, ) +from starkware.cairo.lang.compiler.preprocessor.reg_tracking import RegTrackingData class CompoundExpressionTestContext(CompoundExpressionContext): def __init__(self): self.tempvar_name_counter = itertools.count(0) + self.code_elements: List[CodeElement] = [] + self.ap_tracking = RegTrackingData() def new_tempvar_name(self) -> str: return f"x{next(self.tempvar_name_counter)}" @@ -30,6 +37,17 @@ def new_tempvar_name(self) -> str: def get_fp_val(self, location: Optional[Location]) -> Expression: raise NotImplementedError("fp is not supported in the test.") + def visit(self, elm: CodeElement): + def group_alloc(): + raise NotImplementedError("group_alloc() is not expected to be called.") + + assert isinstance(elm, CodeElementTemporaryVariable) + self.ap_tracking = self.ap_tracking.add(1, group_alloc=group_alloc) + self.code_elements.append(elm) + + def get_ap_tracking(self) -> RegTrackingData: + return self.ap_tracking + @pytest.mark.parametrize( "expr_str, to_operation, to_deref_const, to_deref_offset, to_deref", @@ -106,12 +124,13 @@ def test_compound_expression_visitor( (SimplicityLevel.DEREF_OFFSET, to_deref_offset), (SimplicityLevel.DEREF, to_deref), ]: - visitor = CompoundExpressionVisitor(context=CompoundExpressionTestContext()) + context = CompoundExpressionTestContext() + visitor = CompoundExpressionVisitor(context=context) res = visitor.rewrite(expr, sim) assert ( "".join( code_element.format(allowed_line_length=100) + "; " - for code_element in visitor.code_elements + for code_element in context.code_elements ) + res.format() == expected_result @@ -119,7 +138,8 @@ def test_compound_expression_visitor( def test_compound_expression_visitor_long(): - visitor = CompoundExpressionVisitor(context=CompoundExpressionTestContext()) + context = CompoundExpressionTestContext() + visitor = CompoundExpressionVisitor(context=context) res = visitor.rewrite( parse_expr("[ap + 100] - [fp] * [[-[ap + 200] / [ap + 300]]] + [ap] * [ap]"), SimplicityLevel.OPERATION, @@ -127,7 +147,7 @@ def test_compound_expression_visitor_long(): assert ( "".join( code_element.format(allowed_line_length=100) + "\n" - for code_element in visitor.code_elements + for code_element in context.code_elements ) == """\ tempvar x0 : felt = [ap - 0 + 200] * (-1) @@ -143,12 +163,13 @@ def test_compound_expression_visitor_long(): def test_compound_expression_visitor_inverses(): - visitor = CompoundExpressionVisitor(context=CompoundExpressionTestContext()) + context = CompoundExpressionTestContext() + visitor = CompoundExpressionVisitor(context=context) res = visitor.rewrite(parse_expr("2 - 1 / [ap] + [ap] / 3"), SimplicityLevel.DEREF) assert ( "".join( code_element.format(allowed_line_length=100) + "\n" - for code_element in visitor.code_elements + for code_element in context.code_elements ) == """\ tempvar x0 : felt = 2 @@ -163,7 +184,8 @@ def test_compound_expression_visitor_inverses(): def test_process_compound_expressions(): - code_elements, res = process_compound_expressions( + context = CompoundExpressionTestContext() + res = process_compound_expressions( list( map( parse_expr, @@ -183,11 +205,12 @@ def test_process_compound_expressions(): SimplicityLevel.OPERATION, SimplicityLevel.OPERATION, ], - context=CompoundExpressionTestContext(), + context=context, ) assert ( "".join( - code_element.format(allowed_line_length=100) + "\n" for code_element in code_elements + code_element.format(allowed_line_length=100) + "\n" + for code_element in context.code_elements ) == """\ tempvar x0 : felt = [ap - 0 - 1] * [ap - 0 - 1] diff --git a/src/starkware/cairo/lang/compiler/preprocessor/default_pass_manager.py b/src/starkware/cairo/lang/compiler/preprocessor/default_pass_manager.py index f8b6ac96..2ab58e75 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/default_pass_manager.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/default_pass_manager.py @@ -30,7 +30,15 @@ def default_pass_manager( additional_scopes_to_compile: Optional[Set[ScopedName]] = None, ) -> PassManager: manager = PassManager() - manager.add_stage("module_collector", ModuleCollector(read_module=read_module)) + manager.add_stage( + "module_collector", + ModuleCollector( + read_module=read_module, + additional_modules=[ + "starkware.cairo.lang.compiler.lib.registers", + ], + ), + ) manager.add_stage( "unique_label_creator", VisitorStage(lambda context: UniqueLabelCreator(), modify_ast=True) ) @@ -100,8 +108,38 @@ def __init__( self.read_module = read_module self.additional_modules = [] if additional_modules is None else list(additional_modules) + def collect_module( + self, code: str, filename: str, context: PassManagerContext, visited_modules: Set[str] + ): + """ + Collects the module with the given code and filename. + Updates 'context' and 'visited_modules'. + """ + + # Function used to read files given module names. + # The root module (filename) is handled separately, for this module code is returned. + def read_file_fixed(name): + return (code, filename) if name == filename else self.read_module(name) + + files = collect_imports(filename, read_file=read_file_fixed) + for module_name, ast in files.items(): + # Check if the module is one of the files given in 'context.codes'. + is_main_scope = module_name == filename + if is_main_scope: + scope = context.main_scope + else: + scope = ScopedName.from_string(module_name) + if module_name in visited_modules: + continue + visited_modules.add(module_name) + context.modules.append(CairoModule(cairo_file=ast, module_name=scope)) + def run(self, context: PassManagerContext): - visited_modules = set() + visited_modules: Set[str] = set() + for code, filename in context.start_codes: + self.collect_module( + code=code, filename=filename, context=context, visited_modules=visited_modules + ) for additional_module in self.additional_modules: files = collect_imports(additional_module, read_file=self.read_module) @@ -113,20 +151,6 @@ def run(self, context: PassManagerContext): context.modules.append(CairoModule(cairo_file=ast, module_name=scope)) for code, filename in context.codes: - # Function used to read files given module names. - # The root module (filename) is handled separately, for this module code is returned. - def read_file_fixed(name): - return (code, filename) if name == filename else self.read_module(name) - - files = collect_imports(filename, read_file=read_file_fixed) - for module_name, ast in files.items(): - # Check if the module is one of the files given in 'context.codes'. - is_main_scope = module_name == filename - if is_main_scope: - scope = context.main_scope - else: - scope = ScopedName.from_string(module_name) - if module_name in visited_modules: - continue - visited_modules.add(module_name) - context.modules.append(CairoModule(cairo_file=ast, module_name=scope)) + self.collect_module( + code=code, filename=filename, context=context, visited_modules=visited_modules + ) diff --git a/src/starkware/cairo/lang/compiler/preprocessor/dependency_graph.py b/src/starkware/cairo/lang/compiler/preprocessor/dependency_graph.py index 185a5085..6f61758a 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/dependency_graph.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/dependency_graph.py @@ -1,7 +1,12 @@ from typing import Dict, List, Optional, Set from starkware.cairo.lang.compiler.ast.code_elements import CodeElementFunction, CodeElementImport -from starkware.cairo.lang.compiler.ast.expr import ExprAssignment, ExprDot, ExprIdentifier +from starkware.cairo.lang.compiler.ast.expr import ( + ExprAssignment, + ExprDot, + ExprIdentifier, + ExprNewOperator, +) from starkware.cairo.lang.compiler.ast.module import CairoModule from starkware.cairo.lang.compiler.ast.visitor import Visitor from starkware.cairo.lang.compiler.error_handling import Location @@ -65,6 +70,22 @@ def visit_ExprDot(self, expr: ExprDot): # We override the default visitor, since we must not visit expr.member. self.visit(expr.expr) + def visit_ExprNewOperator(self, expr: ExprNewOperator): + if self.current_function is None: + # The new operator is not supported outside of a function since the 'get_ap' needs to be + # added as a dependency of some function. + raise PreprocessorError( + "The new operator is not supported outside of a function.", location=expr.location + ) + + super().visit(expr.expr) + + self.add_identifier( + name=ScopedName.from_string("starkware.cairo.lang.compiler.lib.registers.get_ap"), + location=expr.location, + is_resolved=True, + ) + def visit_CodeElementFunction(self, elm: CodeElementFunction): if elm.element_type == "func": # Update self.current_function. diff --git a/src/starkware/cairo/lang/compiler/preprocessor/identifier_aware_visitor.py b/src/starkware/cairo/lang/compiler/preprocessor/identifier_aware_visitor.py index a25a4487..736f3d6f 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/identifier_aware_visitor.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/identifier_aware_visitor.py @@ -1,8 +1,9 @@ import dataclasses -from typing import Dict, Optional +from typing import Dict, Optional, Tuple, Type from starkware.cairo.lang.compiler.ast.cairo_types import ( CairoType, + TypeCodeoffset, TypeFelt, TypePointer, TypeStruct, @@ -16,6 +17,7 @@ FutureIdentifierDefinition, IdentifierDefinition, StructDefinition, + TypeDefinition, ) from starkware.cairo.lang.compiler.identifier_manager import IdentifierError, IdentifierManager from starkware.cairo.lang.compiler.identifier_utils import get_struct_definition @@ -71,11 +73,15 @@ def add_name_definition( self.identifiers.add_identifier(name, identifier_definition) self.identifier_locations[name] = location - def get_struct_definition( - self, name: ScopedName, location: Optional[Location] - ) -> StructDefinition: + def get_identifier_definition( + self, + name: ScopedName, + supported_types: Tuple[Type[IdentifierDefinition], ...], + location: Optional[Location], + ) -> IdentifierDefinition: """ - Returns the struct definition that corresponds to the given identifier. + Returns the definition that corresponds to the given identifier. + Verifies that it's one of the given types. location is used if there is an error. """ @@ -85,15 +91,33 @@ def get_struct_definition( except IdentifierError as exc: raise PreprocessorError(str(exc), location=location) - struct_def = res.identifier_definition - if not isinstance(struct_def, StructDefinition): + identifier_definition = res.identifier_definition + if not isinstance(identifier_definition, supported_types): + possible_types = " or ".join( + supported_type.TYPE for supported_type in supported_types # type: ignore + ) raise PreprocessorError( f"""\ -Expected '{res.canonical_name}' to be a {StructDefinition.TYPE}. Found: '{struct_def.TYPE}'.""", +Expected '{res.canonical_name}' to be {possible_types}. Found: '{identifier_definition.TYPE}'.""", location=location, ) - return struct_def + return identifier_definition + + def get_struct_definition( + self, + name: ScopedName, + location: Optional[Location], + ) -> StructDefinition: + """ + Returns the struct definition that corresponds to the given identifier. + location is used if there is an error. + """ + res = self.get_identifier_definition( + name=name, supported_types=(StructDefinition,), location=location + ) + assert isinstance(res, StructDefinition) + return res def try_get_struct_definition(self, name: ScopedName) -> Optional[StructDefinition]: """ @@ -104,11 +128,16 @@ def try_get_struct_definition(self, name: ScopedName) -> Optional[StructDefiniti except PreprocessorError: return None - def get_canonical_struct_name(self, scoped_name: ScopedName, location: Optional[Location]): + def verify_possibly_future_struct( + self, + identifier_definition: IdentifierDefinition, + scoped_name: ScopedName, + location: Optional[Location], + ): """ - Returns the canonical name for the struct given by scoped_name in the current - accessible_scopes. - This function also works for structs that do not have a StructDefinition yet. + Checks that the given IdentifierSearchResult represents a struct. + This function also works for structs that do not have a StructDefinition yet + (FutureDefinition). For example when parsing: struct S: @@ -116,15 +145,11 @@ def get_canonical_struct_name(self, scoped_name: ScopedName, location: Optional[ end We have to lookup S before S is defined in the identifier manager. - location is used if there is an error. + scoped_name and location are used if there is an error. """ - result = self.identifiers.search(self.accessible_scopes, scoped_name) - canonical_name = result.get_canonical_name() - identifier_def = result.identifier_definition - - identifier_type = identifier_def.TYPE - if isinstance(identifier_def, FutureIdentifierDefinition): - identifier_type = identifier_def.identifier_type.TYPE # type: ignore + identifier_type = identifier_definition.TYPE + if isinstance(identifier_definition, FutureIdentifierDefinition): + identifier_type = identifier_definition.identifier_type.TYPE # type: ignore if identifier_type != StructDefinition.TYPE: raise PreprocessorError( @@ -133,13 +158,11 @@ def get_canonical_struct_name(self, scoped_name: ScopedName, location: Optional[ location=location, ) - return canonical_name - def resolve_type(self, cairo_type: CairoType) -> CairoType: """ Resolves a CairoType instance to fully qualified name. """ - if isinstance(cairo_type, TypeFelt): + if isinstance(cairo_type, (TypeFelt, TypeCodeoffset)): return cairo_type elif isinstance(cairo_type, TypePointer): return dataclasses.replace(cairo_type, pointee=self.resolve_type(cairo_type.pointee)) @@ -147,30 +170,59 @@ def resolve_type(self, cairo_type: CairoType) -> CairoType: if cairo_type.is_fully_resolved: return cairo_type try: + result = self.identifiers.search(self.accessible_scopes, cairo_type.scope) + result.assert_fully_parsed() + if isinstance(result.identifier_definition, TypeDefinition): + return self.resolve_type(result.identifier_definition.cairo_type) + + if ( + isinstance(result.identifier_definition, FutureIdentifierDefinition) + and result.identifier_definition.identifier_type is TypeDefinition + ): + raise PreprocessorError( + "Cannot use a type before its definition.", location=cairo_type.location + ) + + self.verify_possibly_future_struct( + identifier_definition=result.identifier_definition, + scoped_name=cairo_type.scope, + location=cairo_type.location, + ) + return dataclasses.replace( cairo_type, - scope=self.get_canonical_struct_name( - scoped_name=cairo_type.scope, location=cairo_type.location - ), + scope=result.get_canonical_name(), is_fully_resolved=True, ) except IdentifierError as exc: raise PreprocessorError(str(exc), location=cairo_type.location) elif isinstance(cairo_type, TypeTuple): + check_no_duplicate_names(cairo_type) return dataclasses.replace( - cairo_type, members=[self.resolve_type(subtype) for subtype in cairo_type.members] + cairo_type, + members=[ + dataclasses.replace(member, typ=self.resolve_type(member.typ)) + for member in cairo_type.members + ], ) else: raise NotImplementedError(f"Type {type(cairo_type).__name__} is not supported.") - def get_struct_size(self, struct_name: ScopedName, location: Optional[Location]): - return self.get_struct_definition(name=struct_name, location=location).size + def get_size_by_type_name(self, struct_name: ScopedName, location: Optional[Location]): + res = self.get_identifier_definition( + name=struct_name, supported_types=(StructDefinition, TypeDefinition), location=location + ) + assert isinstance(res, (StructDefinition, TypeDefinition)) + if isinstance(res, StructDefinition): + return res.size + else: + return self.get_size(res.cairo_type) def get_size(self, cairo_type: CairoType): """ Returns the size of the given type. """ - if isinstance(cairo_type, (TypeFelt, TypePointer)): + if isinstance(cairo_type, (TypeFelt, TypePointer, TypeCodeoffset)): return 1 elif isinstance(cairo_type, TypeStruct): if cairo_type.is_fully_resolved: @@ -181,11 +233,11 @@ def get_size(self, cairo_type: CairoType): except DefinitionError as exc: raise PreprocessorError(str(exc), location=cairo_type.location) else: - return self.get_struct_size( + return self.get_size_by_type_name( struct_name=cairo_type.scope, location=cairo_type.location ) elif isinstance(cairo_type, TypeTuple): - return sum(self.get_size(member_type) for member_type in cairo_type.members) + return sum(self.get_size(member_type) for member_type in cairo_type.types) else: raise NotImplementedError(f"Type {type(cairo_type).__name__} is not supported.") @@ -198,3 +250,22 @@ def inside_a_struct(self) -> bool: return False return parent.element_type == "struct" + + +def check_no_duplicate_names(cairo_type: TypeTuple): + """ + Verifies that there are no duplicate names in a tuple type. Raises a PreprocessorError + otherwise. + Does not check the inner types. + """ + names = set() + for member in cairo_type.members: + member_name = member.name + if member_name is None: + continue + if member_name in names: + raise PreprocessorError( + "Named tuple cannot have two entries with the same name.", + location=member.location, + ) + names.add(member_name) diff --git a/src/starkware/cairo/lang/compiler/preprocessor/identifier_collector.py b/src/starkware/cairo/lang/compiler/preprocessor/identifier_collector.py index 32bccabc..d0280e5e 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/identifier_collector.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/identifier_collector.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Dict, Optional, Type from starkware.cairo.lang.compiler.ast.arguments import IdentifierList from starkware.cairo.lang.compiler.ast.code_elements import ( @@ -13,6 +13,7 @@ CodeElementReference, CodeElementReturnValueReference, CodeElementTemporaryVariable, + CodeElementTypeDef, CodeElementUnpackBinding, CodeElementWith, ) @@ -28,6 +29,7 @@ NamespaceDefinition, ReferenceDefinition, StructDefinition, + TypeDefinition, ) from starkware.cairo.lang.compiler.identifier_manager import IdentifierError, IdentifierManager from starkware.cairo.lang.compiler.preprocessor.local_variables import N_LOCALS_CONSTANT @@ -56,13 +58,14 @@ class IdentifierCollector(Visitor): """ # A dict from code element types to the identifier type they define. - IDENTIFIER_DEFINERS = { + IDENTIFIER_DEFINERS: Dict[Type[CodeElement], Type[IdentifierDefinition]] = { CodeElementConst: ConstDefinition, CodeElementLabel: LabelDefinition, CodeElementReference: ReferenceDefinition, CodeElementLocalVariable: ReferenceDefinition, CodeElementTemporaryVariable: ReferenceDefinition, CodeElementReturnValueReference: ReferenceDefinition, + CodeElementTypeDef: TypeDefinition, } def __init__(self, identifiers: Optional[IdentifierManager] = None): diff --git a/src/starkware/cairo/lang/compiler/preprocessor/pass_manager.py b/src/starkware/cairo/lang/compiler/preprocessor/pass_manager.py index bba3456f..3166e145 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/pass_manager.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/pass_manager.py @@ -11,7 +11,9 @@ @dataclasses.dataclass class PassManagerContext: - # A list of pairs (code, filename). + # A list of pairs (code, filename) that should be compiled before any module is imported. + start_codes: List[Tuple[str, str]] + # A list of pairs (code, filename) codes to compile. codes: List[Tuple[str, str]] main_scope: ScopedName identifiers: IdentifierManager diff --git a/src/starkware/cairo/lang/compiler/preprocessor/preprocess_codes.py b/src/starkware/cairo/lang/compiler/preprocessor/preprocess_codes.py index 853c1423..f04c6aa8 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/preprocess_codes.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/preprocess_codes.py @@ -1,4 +1,4 @@ -from typing import Sequence, Tuple +from typing import List, Optional, Sequence, Tuple from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager from starkware.cairo.lang.compiler.preprocessor.pass_manager import PassManager, PassManagerContext @@ -10,6 +10,7 @@ def preprocess_codes( codes: Sequence[Tuple[str, str]], pass_manager: PassManager, main_scope: ScopedName = ScopedName(), + start_codes: Optional[List[Tuple[str, str]]] = None, ) -> PreprocessedProgram: """ Preprocesses a list of Cairo files and returns a PreprocessedProgram instance. @@ -19,6 +20,7 @@ def preprocess_codes( codes=list(codes), main_scope=main_scope, identifiers=IdentifierManager(), + start_codes=[] if start_codes is None else start_codes, ) pass_manager.run(context) diff --git a/src/starkware/cairo/lang/compiler/preprocessor/preprocessor.py b/src/starkware/cairo/lang/compiler/preprocessor/preprocessor.py index 385dc2e3..061bafea 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/preprocessor.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/preprocessor.py @@ -1,13 +1,17 @@ import dataclasses from collections import defaultdict from contextlib import contextmanager +from dataclasses import field from enum import Enum, auto -from typing import DefaultDict, Dict, List, Optional, Set, Tuple, Type, cast +from typing import DefaultDict, Dict, List, Optional, Set, Tuple, Type, Union, cast + +import marshmallow.fields as mfields from starkware.cairo.lang.compiler.ast.arguments import IdentifierList from starkware.cairo.lang.compiler.ast.cairo_types import ( CairoType, CastType, + TypeCodeoffset, TypeFelt, TypePointer, TypeStruct, @@ -38,6 +42,7 @@ CodeElementStaticAssert, CodeElementTailCall, CodeElementTemporaryVariable, + CodeElementTypeDef, CodeElementUnpackBinding, CodeElementWith, CodeElementWithAttr, @@ -54,6 +59,7 @@ ExprFutureLabel, ExprHint, ExprIdentifier, + ExprNewOperator, ExprOperator, ExprReg, ExprTuple, @@ -79,6 +85,7 @@ from starkware.cairo.lang.compiler.constants import SIZE_CONSTANT from starkware.cairo.lang.compiler.error_handling import Location from starkware.cairo.lang.compiler.expression_simplifier import ExpressionSimplifier +from starkware.cairo.lang.compiler.expression_transformer import ExpressionTransformer from starkware.cairo.lang.compiler.identifier_definition import ( ConstDefinition, DefinitionError, @@ -145,11 +152,12 @@ from starkware.cairo.lang.compiler.proxy_identifier_manager import IdentifierManagerMemento from starkware.cairo.lang.compiler.references import FlowTrackingError, Reference, translate_ap from starkware.cairo.lang.compiler.resolve_search_result import resolve_search_result -from starkware.cairo.lang.compiler.scoped_name import ScopedName +from starkware.cairo.lang.compiler.scoped_name import ScopedName, ScopedNameAsStr from starkware.cairo.lang.compiler.substitute_identifiers import substitute_identifiers from starkware.cairo.lang.compiler.type_casts import check_cast from starkware.cairo.lang.compiler.type_system_visitor import get_expr_addr, simplify_type_system from starkware.python.utils import safe_zip +from starkware.starkware_utils.validated_dataclass import ValidatedDataclass # Indicates that the compiler should be able to deduce the change in the ap register for this # function. @@ -159,6 +167,22 @@ MAX_REFERENCE_RETRIES = 4 +class ReferenceChecker(ExpressionTransformer): + """ + Checks that a reference expression is valid. Raises a PreprocessorError otherwise. + """ + + def visit_ExprHint(self, expr: ExprHint): + raise PreprocessorError( + "The use of hints in reference expressions is not allowed.", location=expr.location + ) + + def visit_ExprNewOperator(self, expr: ExprNewOperator): + raise PreprocessorError( + "The use of 'new' in reference expressions is not allowed.", location=expr.location + ) + + class ReferenceTrial: """ Keeps track of an active trial, in which the preprocessor will optimistically compile a @@ -204,16 +228,20 @@ def format(self, with_locations: bool = False) -> str: ) -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class AttributeBase: name: str value: str -@dataclasses.dataclass -class AttributeScope(AttributeBase): +@dataclasses.dataclass(frozen=True) +class AttributeScope(AttributeBase, ValidatedDataclass): start_pc: int end_pc: int + flow_tracking_data: Optional[FlowTrackingDataActual] = field(metadata=dict(load_default=None)) + accessible_scopes: List[ScopedName] = field( + metadata=dict(marshmallow_field=mfields.List(ScopedNameAsStr, load_default=list)), + ) @dataclasses.dataclass @@ -370,6 +398,9 @@ def __init__( # See ReferenceTrial. self.reference_trial: Optional[ReferenceTrial] = None + def check_reference_expression(self, expr: Expression): + ReferenceChecker().visit(expr) + def search_identifier( self, name: str, location: Optional[Location] ) -> Optional[IdentifierDefinition]: @@ -764,6 +795,9 @@ def fix_reference_revocations( injections=injections, ) + def visit_CodeElementTypeDef(self, elm: CodeElementTypeDef): + self.check_no_hints('Hints before "using" statements are not allowed.') + def visit_CodeElementWith(self, elm: CodeElementWith): new_reference_states = dict(self.reference_states) for aliased_identifier in elm.identifiers: @@ -795,6 +829,9 @@ def visit_CodeElementWith(self, elm: CodeElementWith): def visit_CodeElementWithAttr(self, elm: CodeElementWithAttr): start_pc = self.current_pc + # Retrieve the flow_tracking_data and accessible_scopes before visiting the code block. + flow_tracking_data = self.flow_tracking.get() + accessible_scopes = self.accessible_scopes.copy() self.visit(elm.code_block) end_pc = self.current_pc self.attributes.append( @@ -803,6 +840,8 @@ def visit_CodeElementWithAttr(self, elm: CodeElementWithAttr): value=elm.get_value(), start_pc=start_pc, end_pc=end_pc, + flow_tracking_data=flow_tracking_data, + accessible_scopes=accessible_scopes, ) ) @@ -833,11 +872,9 @@ def visit_CodeElementIf(self, elm: CodeElementIf): expr_a=elm.condition.a, expr_b=elm.condition.b, cond_eq=elm.condition.eq ) - compound_expressions_code_elements, (res_cond_expr,) = process_compound_expressions( + (res_cond_expr,) = process_compound_expressions( [cond_expr], [SimplicityLevel.DEREF], context=self._compound_expression_context ) - for code_element in compound_expressions_code_elements: - self.visit(code_element) # Prepare labels. assert elm.label_neq is not None @@ -971,6 +1008,7 @@ def visit_CodeElementReference(self, elm: CodeElementReference): val, val_type = self.simplify_expr(elm.expr) assert_no_modifier(elm.typed_identifier) + self.check_reference_expression(val) if elm.typed_identifier.expr_type is not None: dst_type = self.resolve_type(elm.typed_identifier.expr_type) @@ -982,6 +1020,7 @@ def visit_CodeElementReference(self, elm: CodeElementReference): dest_type=dst_type, identifier_manager=self.identifiers, cast_type=CastType.ASSIGN, + location=dst_type.location, ): raise PreprocessorError( f"Cannot assign an expression of type '{val_type.format()}' " @@ -1025,6 +1064,61 @@ def visit_CodeElementLocalVariable(self, elm: CodeElementLocalVariable): "Local variables are not supported outside of functions.", location=elm.location ) + def get_expr_for_new_operator(self, new_expr: ExprNewOperator) -> Expression: + """ + Given a new expression, pushes the inner expression onto the stack, calls get_ap + and returns a pointer to the inner expression on the stack. + """ + location = new_expr.location + + # Push new_expr.expr onto the stack. + inner_expr, inner_type = self.simplify_expr(new_expr.expr) + inner_exprs = self.simplified_expr_to_felt_expr_list(expr=inner_expr, expr_type=inner_type) + self.push_compound_expressions(compound_expressions=inner_exprs, location=location) + + # Call get_ap(). + code_elm_call = CodeElementInstruction( + instruction=InstructionAst( + body=CallLabelInstruction( + label=ExprIdentifier( + name="starkware.cairo.lang.compiler.lib.registers.get_ap", location=location + ), + location=location, + fully_qualified_label=True, + ), + inc_ap=False, + location=location, + ) + ) + self.visit(code_elm_call) + + inner_expr_size = self.get_size(cairo_type=inner_type) + + # Create the expression that computes '[ap - 1] - inner_expr_size'. + # Note that here [ap - 1] is the value returned by get_ap(). + current_ap_expr = create_simple_ref_expr( + reg=Register.AP, + offset=-1, + cairo_type=TypeFelt(location=location), + location=location, + ) + + expr: Expression = ExprOperator( + a=current_ap_expr, + op="-", + b=ExprConst(inner_expr_size, location=location), + location=location, + ) + + if new_expr.is_typed: + # Cast pointer_expr to the correct type. + expr = ExprCast( + expr=expr, + dest_type=TypePointer(pointee=inner_type, location=location), + location=location, + ) + return expr + def visit_CodeElementTemporaryVariable(self, elm: CodeElementTemporaryVariable): assert_no_modifier(elm.typed_identifier) @@ -1070,7 +1164,12 @@ def visit_CodeElementTemporaryVariable(self, elm: CodeElementTemporaryVariable): ), ) else: - expr, src_type = self.simplify_expr(elm.expr) + if isinstance(elm.expr, ExprNewOperator): + expr = self.get_expr_for_new_operator(elm.expr) + else: + expr = elm.expr + + expr, src_type = self.simplify_expr(expr) src_size = self.get_size(src_type) if elm.typed_identifier.expr_type is None: @@ -1082,6 +1181,7 @@ def visit_CodeElementTemporaryVariable(self, elm: CodeElementTemporaryVariable): dest_type=dest_type, identifier_manager=self.identifiers, cast_type=CastType.ASSIGN, + location=elm.location, ): raise PreprocessorError( f"Cannot assign an expression of type '{src_type.format()}' " @@ -1126,9 +1226,7 @@ def visit_CodeElementCompoundAssertEq(self, instruction: CodeElementCompoundAsse ap_diff = self.flow_tracking.get_ap_tracking() - original_ap_tracking dst = self.simplifier.visit(translate_ap(dst, ap_diff)) src = self.simplifier.visit(translate_ap(src, ap_diff)) - compound_expressions_code_elements, (expr_a, expr_b) = process_compound_assert( - dst, src, self._compound_expression_context - ) + (expr_a, expr_b) = process_compound_assert(dst, src, self._compound_expression_context) assert_eq = CodeElementInstruction( instruction=InstructionAst( body=AssertEqInstruction(a=expr_a, b=expr_b, location=instruction.location), @@ -1137,8 +1235,6 @@ def visit_CodeElementCompoundAssertEq(self, instruction: CodeElementCompoundAsse ) ) - for code_element in compound_expressions_code_elements: - self.visit(code_element) self.visit(assert_eq) if self.auxiliary_info is not None: @@ -1354,15 +1450,12 @@ def push_compound_expressions( location - location to attach to errors if no finer location is relevant. """ # Generate instructions. - compound_expressions_code_elements, simple_exprs = process_compound_expressions( + simple_exprs = process_compound_expressions( compound_expressions, SimplicityLevel.OPERATION, context=self._compound_expression_context, ) - for code_element in compound_expressions_code_elements: - self.visit(code_element) - assert len(simple_exprs) == len(compound_expressions) simple_exprs = self.optimize_expressions_for_push(simple_exprs) compound_expressions = compound_expressions[-len(simple_exprs) :] @@ -1447,7 +1540,10 @@ def visit_CodeElementReturn(self, elm: CodeElementReturn): self.auxiliary_info.finish_return(exprs=elm.exprs) def check_tail_call_cast( - self, src_struct: StructDefinition, dest_struct: StructDefinition + self, + src_struct: StructDefinition, + dest_struct: StructDefinition, + location: Optional[Location], ) -> bool: """ Checks if src_struct can be converted to dest_struct in the context of a tail call. @@ -1464,6 +1560,7 @@ def check_tail_call_cast( dest_type=dest_member.cairo_type, identifier_manager=self.identifiers, cast_type=CastType.ASSIGN, + location=location, ): return False @@ -1493,7 +1590,9 @@ def visit_CodeElementTailCall(self, elm: CodeElementTailCall): identifier_manager=self.identifiers, ) - if not self.check_tail_call_cast(src_struct=src_struct, dest_struct=dest_struct): + if not self.check_tail_call_cast( + src_struct=src_struct, dest_struct=dest_struct, location=elm.location + ): raise PreprocessorError( f"""\ Cannot convert the return type of {func_name} to the return type of {self.current_scope[-1:]}.""", @@ -1553,7 +1652,7 @@ def add_implicit_return_references( implicit_args_struct = self.get_struct_definition( name=called_function + CodeElementFunction.IMPLICIT_ARGUMENT_SCOPE, location=location ) - return_size = self.get_struct_size( + return_size = self.get_size_by_type_name( struct_name=called_function + CodeElementFunction.RETURN_SCOPE, location=location ) @@ -1697,6 +1796,9 @@ def visit_CodeElementReturnValueReference(self, elm: CodeElementReturnValueRefer func_ident = None if isinstance(elm.func_call.call_inst, CallLabelInstruction): func_ident = elm.func_call.call_inst.label + assert ( + not elm.func_call.call_inst.fully_qualified_label + ), "Expecting a relative label." elif isinstance(elm.func_call, RvalueFuncCall): # If the function name is the name of a struct, replace the # CodeElementReturnValueReference with a regular reference. @@ -1804,6 +1906,7 @@ def visit_CodeElementUnpackBinding(self, elm: CodeElementUnpackBinding): dest_type=cairo_type, identifier_manager=self.identifiers, cast_type=CastType.UNPACKING, + location=typed_identifier.location, ): raise PreprocessorError( f"""\ @@ -1888,7 +1991,11 @@ def visit_CodeElementLabel(self, elm: CodeElementLabel): self.add_label(elm.identifier) if self.auxiliary_info is not None: - _, label_full_name = self.get_label(elm.identifier.name, elm.identifier.location) + _, label_full_name = self.get_label( + label_name=elm.identifier.name, + fully_qualified_label=False, + location=elm.identifier.location, + ) self.auxiliary_info.record_label(label_full_name=label_full_name) def visit_CodeElementHint(self, elm: CodeElementHint): @@ -1938,7 +2045,11 @@ def visit_JumpInstruction(self, instruction: JumpInstruction): def visit_JumpToLabelInstruction(self, instruction: JumpToLabelInstruction): label_name = instruction.label.name - label_pc, label_full_name = self.get_label(label_name, instruction.label.location) + label_pc, label_full_name = self.get_label( + label_name=label_name, + fully_qualified_label=False, + location=instruction.label.location, + ) # Process instruction. res_instruction: InstructionBody @@ -2016,7 +2127,11 @@ def visit_CallInstruction(self, instruction: CallInstruction): def visit_CallLabelInstruction(self, instruction: CallLabelInstruction): label_name = instruction.label.name - label_pc, full_label_scope = self.get_label(label_name, instruction.label.location) + label_pc, full_label_scope = self.get_label( + label_name=label_name, + fully_qualified_label=instruction.fully_qualified_label, + location=instruction.label.location, + ) # If the function has a known reg change, use it. ap_change = RegChangeUnknown() @@ -2089,7 +2204,7 @@ def simplify_expr_as_felt(self, expr) -> Expression: to felt (felt or pointer) and it does not return the type. """ expr, expr_type = self.simplify_expr(expr) - if not isinstance(expr_type, (TypeFelt, TypePointer)): + if not isinstance(expr_type, (TypeFelt, TypePointer, TypeCodeoffset)): raise PreprocessorError( f"Expected a 'felt' or a pointer type. Got: '{expr_type.format()}'.", location=expr.location, @@ -2114,6 +2229,7 @@ def simplify_expr_to_felt_expr_list( dest_type=expected_type, identifier_manager=self.identifiers, cast_type=CastType.ASSIGN, + location=location, ): raise PreprocessorError( f"""\ @@ -2131,12 +2247,12 @@ def simplified_expr_to_felt_expr_list( that can be passed to process_compound_expressions. """ - if isinstance(expr_type, (TypeFelt, TypePointer)): + if isinstance(expr_type, (TypeFelt, TypePointer, TypeCodeoffset)): return [expr] # Get the list of member types. if isinstance(expr_type, TypeTuple): - member_types = expr_type.members + member_types = expr_type.types elif isinstance(expr_type, TypeStruct): struct_definition = get_struct_definition( expr_type.scope, identifier_manager=self.identifiers @@ -2183,16 +2299,24 @@ def simplified_expr_to_felt_expr_list( return expr_list def get_label( - self, label_name: str, location: Optional[Location] + self, label_name: str, fully_qualified_label: bool, location: Optional[Location] ) -> Tuple[Optional[int], Optional[ScopedName]]: """ Returns a pair (pc, canonical_name) for the given label, or (None, None) if this label hasn't been processed yet. + + fully_qualified_label indicates that 'label_name' is a fully qualified identifier, + rather than a relative one. """ try: - search_result = self.identifiers.search( - accessible_scopes=self.accessible_scopes, name=ScopedName.from_string(label_name) - ) + scoped_name = ScopedName.from_string(label_name) + if fully_qualified_label: + search_result = self.identifiers.get(name=scoped_name) + else: + search_result = self.identifiers.search( + accessible_scopes=self.accessible_scopes, + name=scoped_name, + ) search_result.assert_fully_parsed() except IdentifierError as exc: raise PreprocessorError(str(exc), location=location) @@ -2208,7 +2332,7 @@ def get_label( ) return search_result.identifier_definition.pc, search_result.canonical_name - def get_variable(self, var: ExprIdentifier): + def get_variable(self, var: ExprIdentifier) -> Union[int, Expression]: identifier_definition = self.search_identifier(var.name, var.location) # Check that identifier_definition is not None for mypy. assert identifier_definition is not None @@ -2216,7 +2340,7 @@ def get_variable(self, var: ExprIdentifier): if isinstance(identifier_definition, FutureIdentifierDefinition): if identifier_definition.identifier_type in [LabelDefinition, FunctionDefinition]: # Allow future label assignment. - return ExprFutureLabel(identifier=var) + return ExprFutureLabel(identifier=var, is_typed=True, location=var.location) raise PreprocessorError( f"Identifier '{var.name}' referenced before definition.", location=var.location ) @@ -2224,12 +2348,17 @@ def get_variable(self, var: ExprIdentifier): if isinstance(identifier_definition, ConstDefinition): return identifier_definition.value - if isinstance(identifier_definition, LabelDefinition): - return identifier_definition.pc - if isinstance(identifier_definition, MemberDefinition): return identifier_definition.offset + if isinstance(identifier_definition, LabelDefinition): + location = var.location + return ExprCast( + expr=ExprConst(identifier_definition.pc, location=location), + dest_type=TypeCodeoffset(location=location), + location=location, + ) + if isinstance(identifier_definition, (ReferenceDefinition, OffsetReferenceDefinition)): try: res_expr = identifier_definition.eval( @@ -2369,3 +2498,9 @@ def get_fp_val(self, location: Optional[Location]) -> Expression: "Using the value of fp directly, requires defining a variable named __fp__.", location=exc.location, ) + + def visit(self, elm: CodeElement): + self.preprocessor.visit(elm) + + def get_ap_tracking(self) -> RegTrackingData: + return self.preprocessor.flow_tracking.get_ap_tracking() diff --git a/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test.py b/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test.py index 5d17646c..32f80c5c 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test.py @@ -1,27 +1,32 @@ import pytest +from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME from starkware.cairo.lang.compiler.ast.code_elements import CodeElementFunction -from starkware.cairo.lang.compiler.ast.module import CairoModule +from starkware.cairo.lang.compiler.cairo_compile import compile_cairo from starkware.cairo.lang.compiler.error_handling import LocationError from starkware.cairo.lang.compiler.identifier_definition import ( ConstDefinition, FunctionDefinition, LabelDefinition, ReferenceDefinition, + TypeDefinition, ) from starkware.cairo.lang.compiler.identifier_manager import IdentifierError from starkware.cairo.lang.compiler.instruction_builder import InstructionBuilderError from starkware.cairo.lang.compiler.parser import parse_type from starkware.cairo.lang.compiler.preprocessor.default_pass_manager import default_pass_manager +from starkware.cairo.lang.compiler.preprocessor.flow import FlowTrackingDataActual from starkware.cairo.lang.compiler.preprocessor.preprocess_codes import preprocess_codes from starkware.cairo.lang.compiler.preprocessor.preprocessor import AttributeScope from starkware.cairo.lang.compiler.preprocessor.preprocessor_test_utils import ( + CAIRO_TEST_MODULES, PRIME, TEST_SCOPE, preprocess_str, strip_comments_and_linebreaks, verify_exception, ) +from starkware.cairo.lang.compiler.preprocessor.reg_tracking import RegTrackingData from starkware.cairo.lang.compiler.scoped_name import ScopedName from starkware.cairo.lang.compiler.test_utils import read_file_from_dict from starkware.cairo.lang.compiler.type_casts import CairoTypeError @@ -173,8 +178,9 @@ def test_assign_future_label(): def test_assign_future_function_label(): code = """\ -g(f) -g((f + 1) * 2 + 3) +start: +g(f - start) +g((f - start + 1) * 2 + 3) func f() -> (): ret @@ -825,27 +831,73 @@ def test_with_statement_failure(): def test_with_attr_statement(): code = """ -[ap] = 0 -with_attr attr_name("attribute value"): - [ap] = 1 +func a(): + alloc_locals + local x = 0 + ap += 7 + with_attr attr_name("attribute value"): + [ap] = 1 + end + [ap] = 2 + ret end -[ap] = 2 """ program = preprocess_str(code=code, prime=PRIME) assert ( program.format() == """\ -[ap] = 0 +ap += 1 +[fp] = 0 +ap += 7 [ap] = 1 [ap] = 2 +ret """ ) + expected_flow_tracking_data = FlowTrackingDataActual( + ap_tracking=RegTrackingData(group=0, offset=8), + reference_ids={ScopedName.from_string("test_scope.a.x"): 0}, + ) + expected_accessible_scopes = [ + ScopedName.from_string("test_scope"), + ScopedName.from_string("test_scope.a"), + ] expected_attributes = [ - AttributeScope(name="attr_name", value="attribute value", start_pc=2, end_pc=4) + AttributeScope( + name="attr_name", + value="attribute value", + start_pc=6, + end_pc=8, + flow_tracking_data=expected_flow_tracking_data, + accessible_scopes=expected_accessible_scopes, + ) ] assert program.attributes == expected_attributes +def test_attribute_scope_deserialization_with_missing_fields(): + """ + Check that AttributeScope can be deserialized even if accessible_scopes or flow_tracking_data + are missing from the serialization. + """ + code = """ +with_attr attr_name("attribute value"): + [ap] = 1 +end +""" + program = compile_cairo(code, prime=DEFAULT_PRIME) + assert len(program.attributes) == 1 + + serialized_program = program.dump() + serialized_attribute = serialized_program["attributes"][0] + del serialized_attribute["accessible_scopes"] + del serialized_attribute["flow_tracking_data"] + + deserialized_attribute = program.load(data=serialized_program).attributes[0] + assert deserialized_attribute.accessible_scopes == [] + assert deserialized_attribute.flow_tracking_data is None + + def test_implicit_args(): code = """\ struct T: @@ -1217,9 +1269,10 @@ def test_function_call_by_value_args(): member s : felt member t : S end -func f(x, y : T, z : T): +func f(w : S, x, y : T, z : T): + let s = S(a=13, b=17) let t : T = [cast(ap, T*)] - let res = f(x=2, y=z, z=t) + let res = f(w=s, x=2, y=z, z=t) return() end """ @@ -1227,14 +1280,16 @@ def test_function_call_by_value_args(): assert ( program.format() == """\ +[ap] = 13; ap++ +[ap] = 17; ap++ [ap] = 2; ap++ [ap] = [fp + (-5)]; ap++ [ap] = [fp + (-4)]; ap++ [ap] = [fp + (-3)]; ap++ -[ap] = [ap + (-4)]; ap++ -[ap] = [ap + (-4)]; ap++ -[ap] = [ap + (-4)]; ap++ -call rel -8 +[ap] = [ap + (-6)]; ap++ +[ap] = [ap + (-6)]; ap++ +[ap] = [ap + (-6)]; ap++ +call rel -12 ret """ ) @@ -1390,7 +1445,9 @@ def test_import(): } program = preprocess_codes( codes=[(files["."], ".")], - pass_manager=default_pass_manager(prime=PRIME, read_module=read_file_from_dict(files)), + pass_manager=default_pass_manager( + prime=PRIME, read_module=read_file_from_dict(dct={**files, **CAIRO_TEST_MODULES}) + ), ) assert ( @@ -1437,7 +1494,9 @@ def get_full_name(name, curr_scope=""): # Preprocess program. program = preprocess_codes( codes=[(files["."], ".")], - pass_manager=default_pass_manager(prime=PRIME, read_module=read_file_from_dict(files)), + pass_manager=default_pass_manager( + prime=PRIME, read_module=read_file_from_dict(dct={**files, **CAIRO_TEST_MODULES}) + ), main_scope=scope("__main__"), ) @@ -1648,7 +1707,6 @@ def test_process_file_scope(): valid_scope = ScopedName.from_string("some.valid.scope") program = preprocess_str("const x = 4", prime=PRIME, main_scope=valid_scope) - module = CairoModule(cairo_file=program, module_name=valid_scope) assert program.identifiers.as_dict() == {valid_scope + "x": ConstDefinition(4)} @@ -2210,6 +2268,47 @@ def test_rebind_reference_failures(): ) +def test_invalid_references(): + verify_exception( + """ +let x = 3 * cast(nondet %{ rnadom.randrange(10) %}, felt) + 5 +""", + """ +file:?:?: The use of hints in reference expressions is not allowed. +let x = 3 * cast(nondet %{ rnadom.randrange(10) %}, felt) + 5 + ^*******************************************^ +""", + ) + + +def test_rvalue_func_call_reference_with_nondet(): + """ + Tests that nondet hints are computed exactly once when used as arguments to function calls + in a reference expression. + """ + program = preprocess_str( + code=""" +func foo(val) -> (res): + return (res=val) +end +let x = foo(nondet %{ 5 %}) +assert x = x +""", + prime=PRIME, + ) + assert ( + program.format() + == """\ +[ap] = [fp + (-3)]; ap++ +ret +%{ memory[ap] = to_felt_or_relocatable(5) %} +ap += 1 +call rel -4 +[ap + (-1)] = [ap + (-1)] +""" + ) + + def test_reference_over_calls(): program = preprocess_str( code=""" @@ -3064,7 +3163,7 @@ def test_bad_type_annotation(): end """, """ -file:?:?: Expected 'test_scope.foo' to be a struct. Found: 'function'. +file:?:?: Expected 'test_scope.foo' to be struct or type_definition. Found: 'function'. local a : foo ^*^ """, @@ -3621,6 +3720,63 @@ def test_tuple_expression_failures(): ) +def test_named_tuple_types(): + code = """ +func foo(a : (felt, (x : felt, y : felt))): + tempvar tmp0 = a[0] + tempvar tmp1 = a[1].y + tempvar tmp2 = a[1][1] + ret +end + +foo((0, (x=1, y=2))) +# You can pass named to unnamed and vice versa. +foo((a=0, b=(1, 2))) +""" + program = preprocess_str(code=code, prime=PRIME) + assert ( + program.format() + == """\ +[ap] = [fp + (-5)]; ap++ +[ap] = [fp + (-3)]; ap++ +[ap] = [fp + (-3)]; ap++ +ret +[ap] = 0; ap++ +[ap] = 1; ap++ +[ap] = 2; ap++ +call rel -10 +[ap] = 0; ap++ +[ap] = 1; ap++ +[ap] = 2; ap++ +call rel -18 +""" + ) + + +def test_named_tuple_types_failure(): + verify_exception( + """ +func foo(arg : (a : (b : felt, b : felt), c : felt)): +end +""", + """ +file:?:?: Named tuple cannot have two entries with the same name. +func foo(arg : (a : (b : felt, b : felt), c : felt)): + ^******^ +""", + ) + verify_exception( + """ +let x = (a=(b=1, b=2), c=0) +""", + """ +file:?:?: Named tuple cannot have two entries with the same name. +let x = (a=(b=1, b=2), c=0) + ^*^ +""", + ) + + def test_struct_constructor(): code = """\ struct M: @@ -3744,7 +3900,7 @@ def verify_exception_for_expr(expr_str: str, expected_error: str): verify_exception_for_expr( "T(x=5, y=6, z=7)", """ -file:?:?: Cannot cast an expression of type '(felt, felt, felt)' to 'test_scope.T'. +file:?:?: Cannot cast an expression of type '(x : felt, y : felt, z : felt)' to 'test_scope.T'. The former has 3 members while the latter has 2 members. local a : T = T(x=5, y=6, z=7) ^**************^ @@ -3754,7 +3910,7 @@ def verify_exception_for_expr(expr_str: str, expected_error: str): verify_exception_for_expr( "T(x=5)", """ -file:?:?: Cannot cast an expression of type '(felt)' to 'test_scope.T'. +file:?:?: Cannot cast an expression of type '(x : felt)' to 'test_scope.T'. The former has 1 members while the latter has 2 members. local a : T = T(x=5) ^****^ @@ -3791,7 +3947,7 @@ def verify_exception_for_expr(expr_str: str, expected_error: str): verify_exception_for_expr( "T(5, 6).x", """ -file:?:?: Accessing struct members for r-value structs is not supported yet. +file:?:?: Accessing struct/tuple members for r-value structs is not supported yet. local a : T = T(5, 6).x ^*******^ """, @@ -3846,7 +4002,9 @@ def test_skipped_functions(): } program = preprocess_codes( codes=[(files["."], ".")], - pass_manager=default_pass_manager(prime=PRIME, read_module=read_file_from_dict(files)), + pass_manager=default_pass_manager( + prime=PRIME, read_module=read_file_from_dict(dct={**files, **CAIRO_TEST_MODULES}) + ), ) assert ( program.format() @@ -3863,13 +4021,13 @@ def test_skipped_functions(): codes=[(files["."], ".")], pass_manager=default_pass_manager( prime=PRIME, - read_module=read_file_from_dict(files), + read_module=read_file_from_dict(dct={**files, **CAIRO_TEST_MODULES}), opt_unused_functions=False, ), ) - assert ( - program.format() - == """\ + assert program.format() == strip_comments_and_linebreaks( + """\ +ret # get_ap is here because opt_unused_functions is false. [ap] = 0; ap++ ret [ap] = 1; ap++ @@ -3937,3 +4095,334 @@ def test_define_word_failure(): """, exc_type=InstructionBuilderError, ) + + +def test_label_arithmetic_flow(): + code = """ +label1: +assert 4 = label2 - label1 +label2: +assert 4 = label2 - label1 +""" + program = preprocess_str(code=code, prime=PRIME) + assert ( + program.format() + == """\ +[ap] = 4; ap++ +4 = [ap + (-1)] +[ap] = 4; ap++ +4 = [ap + (-1)] +""" + ) + + +def test_label_arithmetic_failure(): + verify_exception( + """ +tempvar code_offset : codeoffset +assert 0 = code_offset - 5 +""", + """ +file:?:?: Operator '-' is not implemented for types 'codeoffset' and 'felt'. +assert 0 = code_offset - 5 + ^*************^ +""", + exc_type=CairoTypeError, + ) + + verify_exception( + """ +func foo(): + assert foo = 5 + return() +end + +""", + """ +file:?:?: Cannot compare 'codeoffset' and 'felt'. + assert foo = 5 + ^************^ +""", + ) + + +def test_future_label_substraction_failure(): + """ + Subtracting two future labels doesn't work at the moment. + The test is here to check the error message. + """ + + verify_exception( + """ +assert 0 = label2 - label1 +label1: +label2: +""", + """ +file:?:?: Expected a constant expression or a dereference expression. +assert 0 = label2 - label1 + ^****^ +Preprocessed instruction: +[ap] = [ap + (-1)] - label1; ap++ +""", + exc_type=InstructionBuilderError, + ) + + +def test_future_label_minus_tempvar(): + code = """ +tempvar a = cast(0, codeoffset) +assert 0 = label1 - a +label1: +""" + program = preprocess_str(code=code, prime=PRIME) + assert ( + program.format() + == """\ +[ap] = 0; ap++ +[ap] = 7; ap++ +[ap] = [ap + (-1)] - [ap + (-2)]; ap++ +0 = [ap + (-1)] +""" + ) + + +def test_new_operator_flow(): + code = """ +struct MyStruct: + member a : felt + member b : felt +end + +func test() -> (my_struct: MyStruct*): + tempvar t = 37 + tempvar my_struct = new MyStruct(a=1, b=2) + + # Check that 't' wasn't revoked and that the type of a is MyStruct*. + assert cast(t, MyStruct*) = my_struct + assert cast(t, MyStruct*) = new MyStruct(a=3, b=[new 4]) + return (my_struct=my_struct) +end +""" + program = preprocess_str(code=code, prime=PRIME) + assert program.format() == strip_comments_and_linebreaks( + """\ +# A dummy get_ap(). +ret + +[ap] = 37; ap++ + +# tempvar my_struct = new MyStruct(a=1, b=2). +[ap] = 1; ap++ +[ap] = 2; ap++ +call rel -7 # call get_ap() +[ap] = [ap + (-1)] + (-2); ap++ # [ap] = get_ap() - MyStruct.SIZE +# assert cast(t, MyStruct*) = my_struct. +[ap + (-6)] = [ap + (-1)] + +# new 4. +[ap] = 4; ap++ +call rel -14 +[ap] = [ap + (-1)] + (-1); ap++ + +# new MyStruct(a=3, b=[new 4]). +[ap] = 3; ap++ +[ap] = [[ap + (-2)]]; ap++ +call rel -21 +[ap] = [ap + (-1)] + (-2); ap++ + +# assert cast(t, MyStruct*) = new MyStruct(a=3, b=[new 4]). +[ap + (-15)] = [ap + (-1)] +# return (my_struct=my_struct). +[ap] = [ap + (-10)]; ap++ +ret +""" + ) + + code = """ +func test() -> (felt_ptr : felt*): + return (felt_ptr=new ([fp] + 5)) +end +""" + program = preprocess_str(code=code, prime=PRIME) + assert program.format() == strip_comments_and_linebreaks( + """\ +# A dummy get_ap(). +ret + +[ap] = [fp] + 5; ap++ +call rel -3 # call get_ap() +[ap] = [ap + (-1)] + (-1); ap++ +ret +""" + ) + + code = """ +func test() -> (tuple_ptr : (felt, felt)*): + return (tuple_ptr=new (7, 8)) +end +""" + program = preprocess_str(code=code, prime=PRIME) + assert program.format() == strip_comments_and_linebreaks( + """\ +# A dummy get_ap(). +ret + +[ap] = 7; ap++ +[ap] = 8; ap++ +call rel -5 # call get_ap() +[ap] = [ap + (-1)] + (-2); ap++ +ret +""" + ) + + +def test_new_operator_failure(): + verify_exception( + """ +new MyStruct(a=1, b=2) = 13 +""", + """ +file:?:?: The new operator is not supported outside of a function. +new MyStruct(a=1, b=2) = 13 +^********************^ +""", + ) + + verify_exception( + """ +struct MyStruct: + member a : felt + member b : felt +end +func test(): + new MyStruct(a=1, b=2) = 13 + return() +end +""", + """ +file:?:?: Expected a dereference expression. + new MyStruct(a=1, b=2) = 13 + ^********************^ +Preprocessed instruction: +new (1, 2) = 13 +""", + exc_type=InstructionBuilderError, + ) + + verify_exception( + """ +new MyStruct(a=1, b=2) = 13 +""", + """ +file:?:?: The new operator is not supported outside of a function. +new MyStruct(a=1, b=2) = 13 +^********************^ +""", + ) + verify_exception( + """ +func test(): + alloc_locals + local x = new 5 + return() +end + +""", + """ +file:?:?: Cannot cast 'felt*' to 'felt'. + local x = new 5 + ^***^ +""", + exc_type=CairoTypeError, + ) + + verify_exception( + """ +func test(): + let x = new 5 + return() +end + +""", + """ +file:?:?: The use of 'new' in reference expressions is not allowed. + let x = new 5 + ^***^ +""", + ) + + +def test_type_definition(): + code = """ +namespace a: + namespace b: + using Point = (felt, felt) + end + namespace c: + using TwoPoints = (b.Point, b.Point) + end +end + +func foo(z : a.b.Point): + alloc_locals + tempvar x : a.c.TwoPoints = ((0, 1), [cast(fp, a.b.Point*)]) + local y : a.c.TwoPoints + assert x[0] = z + return () +end +""" + program = preprocess_str(code=code, prime=PRIME) + assert ( + program.format() + == """\ +ap += 4 +[ap] = 0; ap++ +[ap] = 1; ap++ +[ap] = [fp]; ap++ +[ap] = [fp + 1]; ap++ +[ap + (-4)] = [fp + (-4)] +[ap + (-3)] = [fp + (-3)] +ret +""" + ) + two_points = program.identifiers.get_by_full_name(TEST_SCOPE + "a.c.TwoPoints") + assert ( + isinstance(two_points, TypeDefinition) + and two_points.cairo_type.format() == "((felt, felt), (felt, felt))" + ) + + +def test_type_definition_failure(): + verify_exception( + """ +using Point = Point +""", + """ +file:?:?: Cannot use a type before its definition. +using Point = Point + ^***^ +""", + ) + verify_exception( + """ +using Point2 = (Point, Point) +using Point = (felt, felt) +""", + """ +file:?:?: Cannot use a type before its definition. +using Point2 = (Point, Point) + ^***^ +""", + ) + verify_exception( + """ +%{ %} +using Point = (felt, felt) +""", + """ +file:?:?: Hints before "using" statements are not allowed. +%{ %} +^***^ +""", + ) diff --git a/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test_utils.py b/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test_utils.py index ec8b4433..2f3a0506 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test_utils.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test_utils.py @@ -20,12 +20,22 @@ TEST_SCOPE = ScopedName.from_string("test_scope") +CAIRO_TEST_MODULES = { + "starkware.cairo.lang.compiler.lib.registers": """ +@known_ap_change +func get_ap() -> (ap_val): + ret +end +""", +} + + def strip_comments_and_linebreaks(program: str): """ Removes all comments and empty lines from the given program. """ program = re.sub(r"\s*#.*\n", "\n", program) - return re.sub("\n+", "\n", program) + return re.sub("\n+", "\n", program).lstrip() def default_read_module(module_name: str): @@ -41,7 +51,9 @@ def preprocess_str( return preprocess_str_ex( code=code, pass_manager=default_pass_manager( - prime=prime, read_module=default_read_module, preprocessor_cls=preprocessor_cls + prime=prime, + read_module=read_file_from_dict(CAIRO_TEST_MODULES), + preprocessor_cls=preprocessor_cls, ), main_scope=main_scope, ) @@ -70,7 +82,9 @@ def verify_exception( main_scope = TEST_SCOPE if pass_manager is None: - pass_manager = default_pass_manager(prime=PRIME, read_module=read_file_from_dict(files)) + pass_manager = default_pass_manager( + prime=PRIME, read_module=read_file_from_dict({**files, **CAIRO_TEST_MODULES}) + ) with pytest.raises(exc_type) as e: preprocess_codes(codes=[(code, "")], pass_manager=pass_manager, main_scope=main_scope) diff --git a/src/starkware/cairo/lang/compiler/preprocessor/reg_tracking_test.py b/src/starkware/cairo/lang/compiler/preprocessor/reg_tracking_test.py index 229de61e..8e1d1fee 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/reg_tracking_test.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/reg_tracking_test.py @@ -19,7 +19,7 @@ def test_from_expr(): assert RegChange.from_expr(ExprIdentifier("asd")) == RegChangeUnknown() with pytest.raises(TypeError): - RegChange.from_expr("wrong type") + RegChange.from_expr("wrong type") # type: ignore def test_reg_change_add(): @@ -28,7 +28,7 @@ def test_reg_change_add(): assert RegChangeUnknown() + RegChangeKnown(2) == RegChangeUnknown() with pytest.raises(TypeError): - RegChangeKnown(3) + "asd" + RegChangeKnown(3) + "asd" # type: ignore with pytest.raises(TypeError): RegChangeUnconstrained() + RegChangeKnown(0) diff --git a/src/starkware/cairo/lang/compiler/preprocessor/struct_collector.py b/src/starkware/cairo/lang/compiler/preprocessor/struct_collector.py index 3eba92ac..c8e76db9 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/struct_collector.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/struct_collector.py @@ -9,10 +9,15 @@ CodeElementEmptyLine, CodeElementFunction, CodeElementMember, + CodeElementTypeDef, ) from starkware.cairo.lang.compiler.ast.formatting_utils import LocationField from starkware.cairo.lang.compiler.error_handling import Location -from starkware.cairo.lang.compiler.identifier_definition import MemberDefinition, StructDefinition +from starkware.cairo.lang.compiler.identifier_definition import ( + MemberDefinition, + StructDefinition, + TypeDefinition, +) from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager from starkware.cairo.lang.compiler.preprocessor.identifier_aware_visitor import ( IdentifierAwareVisitor, @@ -49,7 +54,6 @@ def _visit_default(self, obj): def add_struct_definition( self, members_list: List[MemberInfo], struct_name: ScopedName, location: Optional[Location] ): - offset = 0 members: Dict[str, MemberDefinition] = {} for member_info in members_list: @@ -170,3 +174,10 @@ def visit_CodeElementFunction(self, elm: CodeElementFunction): ) self.visit(elm.code_block) + + def visit_CodeElementTypeDef(self, elm: CodeElementTypeDef): + self.add_name_definition( + self.current_scope + elm.name, + identifier_definition=TypeDefinition(cairo_type=self.resolve_type(elm.cairo_type)), + location=elm.location, + ) diff --git a/src/starkware/cairo/lang/compiler/preprocessor/unique_labels.py b/src/starkware/cairo/lang/compiler/preprocessor/unique_labels.py index fe4ddf3f..a4c88c21 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/unique_labels.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/unique_labels.py @@ -2,6 +2,15 @@ from starkware.cairo.lang.compiler.ast.node import AstNode from starkware.cairo.lang.compiler.ast.visitor import Visitor +ANONYMOUS_LABEL_PREFIX = "_anon_label" + + +def is_anonymous_label(label_name: str) -> bool: + """ + Returns True if the given label seems to have been generated by AnonymousLabelGenerator. + """ + return label_name.startswith(ANONYMOUS_LABEL_PREFIX) + class AnonymousLabelGenerator: """ @@ -13,7 +22,7 @@ def __init__(self): self.anon_label_counter = 0 def get(self): - label_name = f"_anon_label{self.anon_label_counter}" + label_name = f"{ANONYMOUS_LABEL_PREFIX}{self.anon_label_counter}" self.anon_label_counter += 1 return label_name diff --git a/src/starkware/cairo/lang/compiler/proxy_identifier_manager.py b/src/starkware/cairo/lang/compiler/proxy_identifier_manager.py index 4d0f22c3..608b106e 100644 --- a/src/starkware/cairo/lang/compiler/proxy_identifier_manager.py +++ b/src/starkware/cairo/lang/compiler/proxy_identifier_manager.py @@ -1,9 +1,10 @@ import dataclasses -from typing import ChainMap, Dict, Optional, Tuple +from typing import ChainMap, Optional, Tuple from starkware.cairo.lang.compiler.identifier_definition import IdentifierDefinition from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager, IdentifierScope from starkware.cairo.lang.compiler.preprocessor.memento import Memento +from starkware.cairo.lang.compiler.scoped_name import ScopedName class ProxyIdentifierManager(IdentifierManager): @@ -15,8 +16,8 @@ class ProxyIdentifierManager(IdentifierManager): def __init__(self, parent: IdentifierManager): self.parent = parent - self.root = ProxyIdentifierScope(manager=self, parent=parent.root) - self.dict: ChainMap[str, IdentifierDefinition] = ChainMap({}, parent.dict) + self.root: ProxyIdentifierScope = ProxyIdentifierScope(manager=self, parent=parent.root) + self.dict: ChainMap[ScopedName, IdentifierDefinition] = ChainMap({}, parent.dict) def apply(self): """ @@ -30,7 +31,6 @@ class ProxyIdentifierScope(IdentifierScope): def __init__(self, manager: IdentifierManager, parent: IdentifierScope): super().__init__(manager=manager, fullname=parent.fullname) self.parent = parent - self.subscopes: Dict[str, IdentifierScope] = {} self.identifiers: ChainMap[str, IdentifierDefinition] = ChainMap({}, parent.identifiers) def get_single_scope(self, name: str) -> Optional["IdentifierScope"]: @@ -55,6 +55,7 @@ def add_subscope(self, first_name: str): def _apply(self): self.parent.identifiers.update(self.identifiers.maps[0]) for name, subscope in self.subscopes.items(): + assert isinstance(subscope, ProxyIdentifierScope) if name not in self.parent.subscopes: self.parent.subscopes[name] = subscope.parent subscope._apply() diff --git a/src/starkware/cairo/lang/compiler/proxy_identifier_manager_test.py b/src/starkware/cairo/lang/compiler/proxy_identifier_manager_test.py index 6f82641c..5b48f99e 100644 --- a/src/starkware/cairo/lang/compiler/proxy_identifier_manager_test.py +++ b/src/starkware/cairo/lang/compiler/proxy_identifier_manager_test.py @@ -5,7 +5,10 @@ IdentifierManager, MissingIdentifierError, ) -from starkware.cairo.lang.compiler.proxy_identifier_manager import ProxyIdentifierManager +from starkware.cairo.lang.compiler.proxy_identifier_manager import ( + ProxyIdentifierManager, + ProxyIdentifierScope, +) from starkware.cairo.lang.compiler.scoped_name import ScopedName scope = ScopedName.from_string @@ -18,7 +21,9 @@ def test_identifier_manager_get(): manager = IdentifierManager.from_dict(identifier_dict) proxy = ProxyIdentifierManager(manager) for full_name in ["a", "a.b"]: - assert proxy.get_scope(scope(full_name)).parent == manager.get_scope(scope(full_name)) + proxy_scope = proxy.get_scope(scope(full_name)) + assert isinstance(proxy_scope, ProxyIdentifierScope) + assert proxy_scope.parent == manager.get_scope(scope(full_name)) assert proxy.get(scope("a.b.c")) == manager.get(scope("a.b.c")) proxy.add_identifier(scope("a.d"), ConstDefinition(value=8)) diff --git a/src/starkware/cairo/lang/compiler/scoped_name.py b/src/starkware/cairo/lang/compiler/scoped_name.py index 02498b8f..4e16e92a 100644 --- a/src/starkware/cairo/lang/compiler/scoped_name.py +++ b/src/starkware/cairo/lang/compiler/scoped_name.py @@ -14,7 +14,7 @@ def __post_init__(self): assert all([self.SEPARATOR not in part for part in self.path]) @classmethod - def from_string(cls, scope: str): + def from_string(cls, scope: str) -> "ScopedName": if scope == "": # Handle the special case of an empty tuple. return cls() diff --git a/src/starkware/cairo/lang/compiler/substitute_identifiers.py b/src/starkware/cairo/lang/compiler/substitute_identifiers.py index 5247229c..5f6da9fa 100644 --- a/src/starkware/cairo/lang/compiler/substitute_identifiers.py +++ b/src/starkware/cairo/lang/compiler/substitute_identifiers.py @@ -118,7 +118,12 @@ def visit_ExprFuncCall(self, expr: ExprFuncCall): ) def visit_ExprFutureLabel(self, expr: ExprFutureLabel): - return self.visit(expr.identifier) + res = self.visit(expr.identifier) + if isinstance(res, ExprFutureLabel): + # If expr.identifier remains unresolved, return the original expression to keep track + # of expr.is_typed. + return expr + return res def substitute_identifiers( diff --git a/src/starkware/cairo/lang/compiler/type_casts.py b/src/starkware/cairo/lang/compiler/type_casts.py index 346e4aa6..98425728 100644 --- a/src/starkware/cairo/lang/compiler/type_casts.py +++ b/src/starkware/cairo/lang/compiler/type_casts.py @@ -1,18 +1,20 @@ import itertools -from typing import Optional +from typing import Iterable, Optional, Sequence from starkware.cairo.lang.compiler.ast.cairo_types import ( CairoType, CastType, + TypeCodeoffset, TypeFelt, TypePointer, TypeStruct, TypeTuple, ) from starkware.cairo.lang.compiler.ast.expr import ExprDeref, Expression, ExprTuple -from starkware.cairo.lang.compiler.error_handling import LocationError +from starkware.cairo.lang.compiler.error_handling import Location, LocationError from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager from starkware.cairo.lang.compiler.identifier_utils import get_struct_definition +from starkware.python.utils import safe_zip FELT_STAR = TypePointer(pointee=TypeFelt()) @@ -25,17 +27,20 @@ def check_cast( src_type: CairoType, dest_type: CairoType, identifier_manager: IdentifierManager, + cast_type: CastType, + location: Optional[Location], expr: Optional[Expression] = None, - cast_type: CastType = CastType.EXPLICIT, ) -> bool: """ Returns true if the given expression can be casted from src_type to dest_type according to the given 'cast_type'. In some cases of cast failure, an exception with more specific details is raised. - - 'expr' must be specified (not None) when CastType.EXPLICIT (or above) is used. + 'location' is used as the default error location if expr is not specified. """ + if expr is not None and expr.location is not None: + location = expr.location + # CastType.ASSIGN checks: if src_type == dest_type: @@ -45,6 +50,20 @@ def check_cast( if isinstance(src_type, TypePointer) and dest_type == FELT_STAR: return True + # Allow implicit cast between named and unnamed tuples. + if isinstance(src_type, TypeTuple) and isinstance(dest_type, TypeTuple): + verify_tuple_like_cast( + src_type=src_type, + src_members=src_type.members, + dest_type=dest_type, + dest_members=dest_type.members, + identifier_manager=identifier_manager, + expr=expr, + location=location, + cast_type=cast_type, + ) + return True + if cast_type is CastType.ASSIGN: return False @@ -60,46 +79,33 @@ def check_cast( return False # CastType.EXPLICIT checks: - assert expr is not None, f"CastType.EXPLICIT requires expr != None." + + # Allow explicit cast between felts and labels. + if isinstance(src_type, (TypeFelt, TypeCodeoffset)) and isinstance( + dest_type, (TypeFelt, TypeCodeoffset) + ): + return True if isinstance(src_type, TypeTuple) and isinstance(dest_type, TypeStruct): struct_def = get_struct_definition( struct_name=dest_type.resolved_scope, identifier_manager=identifier_manager ) - - n_src_members = len(src_type.members) - n_dest_members = len(struct_def.members) - if n_src_members != n_dest_members: - raise CairoTypeError( - f"""\ -Cannot cast an expression of type '{src_type.format()}' to '{dest_type.format()}'. -The former has {n_src_members} members while the latter has {n_dest_members} members.""", - location=expr.location, - ) - - src_exprs = ( - [arg.expr for arg in expr.members.args] - if isinstance(expr, ExprTuple) - else itertools.repeat(expr) + dest_members = [ + TypeTuple.Item(name=name, typ=member.cairo_type) + for name, member in struct_def.members.items() + ] + + verify_tuple_like_cast( + src_type=src_type, + src_members=src_type.members, + dest_type=dest_type, + dest_members=dest_members, + identifier_manager=identifier_manager, + cast_type=cast_type, + location=location, + expr=expr, ) - for (src_expr, src_member_type, dest_member) in zip( - src_exprs, src_type.members, struct_def.members.values() - ): - dest_member_type = dest_member.cairo_type - if not check_cast( - src_type=src_member_type, - dest_type=dest_member_type, - identifier_manager=identifier_manager, - expr=src_expr, - cast_type=CastType.FORCED if cast_type is CastType.FORCED else CastType.ASSIGN, - ): - - raise CairoTypeError( - f"Cannot cast '{src_member_type.format()}' to '{dest_member_type.format()}'.", - location=src_expr.location, - ) - return True if cast_type is CastType.EXPLICIT: @@ -115,3 +121,60 @@ def check_cast( assert cast_type is CastType.FORCED, f"Unsupported cast type: {cast_type}." return False + + +def verify_tuple_like_cast( + src_type: CairoType, + src_members: Sequence[TypeTuple.Item], + dest_type: CairoType, + dest_members: Sequence[TypeTuple.Item], + identifier_manager: IdentifierManager, + cast_type: CastType, + location: Optional[Location], + expr: Optional[Expression], +): + n_src_members = len(src_members) + n_dest_members = len(dest_members) + if n_src_members != n_dest_members: + raise CairoTypeError( + f"""\ +Cannot cast an expression of type '{src_type.format()}' to '{dest_type.format()}'. +The former has {n_src_members} members while the latter has {n_dest_members} members.""", + location=location, + ) + + src_exprs: Iterable[Optional[Expression]] = ( + [arg.expr for arg in expr.members.args] + if isinstance(expr, ExprTuple) + else itertools.repeat(None, times=n_src_members) + ) + + for (src_expr, src_member, dest_member) in safe_zip(src_exprs, src_members, dest_members): + item_location = location + if src_expr is not None and src_expr.location is not None: + item_location = src_expr.location + + src_name = src_member.name + dest_name = dest_member.name + if src_name is not None and dest_name is not None and src_name != dest_name: + raise CairoTypeError( + f"""\ +Cannot cast '{src_type.format()}' to '{dest_type.format()}'. +Expected argument name {dest_name}. Found: {src_name}.""", + location=item_location, + ) + + src_member_type = src_member.typ + dest_member_type = dest_member.typ + if not check_cast( + src_type=src_member_type, + dest_type=dest_member_type, + identifier_manager=identifier_manager, + cast_type=CastType.FORCED if cast_type is CastType.FORCED else CastType.ASSIGN, + location=item_location, + expr=src_expr, + ): + raise CairoTypeError( + f"Cannot cast '{src_member_type.format()}' to '{dest_member_type.format()}'.", + location=item_location, + ) diff --git a/src/starkware/cairo/lang/compiler/type_casts_test.py b/src/starkware/cairo/lang/compiler/type_casts_test.py index 24a88231..5d20b9f2 100644 --- a/src/starkware/cairo/lang/compiler/type_casts_test.py +++ b/src/starkware/cairo/lang/compiler/type_casts_test.py @@ -1,9 +1,25 @@ +from typing import Union + import pytest from starkware.cairo.lang.compiler.ast.cairo_types import CastType from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager from starkware.cairo.lang.compiler.parser import parse_expr, parse_type -from starkware.cairo.lang.compiler.type_casts import check_cast +from starkware.cairo.lang.compiler.preprocessor.preprocessor_test_utils import PRIME, preprocess_str +from starkware.cairo.lang.compiler.scoped_name import ScopedName +from starkware.cairo.lang.compiler.type_casts import CairoTypeError, check_cast +from starkware.cairo.lang.compiler.type_system import mark_type_resolved +from starkware.python.test_utils import maybe_raises + + +@pytest.fixture(scope="session") +def identifier_manager() -> IdentifierManager: + code = """ +struct T: + member x : (felt, felt) +end +""" + return preprocess_str(code, PRIME, main_scope=ScopedName()).identifiers @pytest.mark.parametrize( @@ -16,26 +32,41 @@ ["T*", "felt*", True, True, True], ["felt*", "T", False, False, False], ["T", "felt*", False, False, False], + # Tuples and named tuples. ["felt", "(felt,felt)", False, False, False], + ["((felt, felt))", "T", True, False, False], + ["(x : (felt, felt))", "T", True, False, False], + ["(y : (felt, felt))", "T", "Expected argument name x. Found: y.", False, False], + ["(felt)", "(a : felt)", True, True, True], + ["(a : felt)", "(felt)", True, True, True], + ["(a : felt, b : felt)", "(a : felt, c : felt)"] + + ["Expected argument name c. Found: b."] * 3, ], ) def test_type_casts( - src: str, dest: str, explicit_cast: bool, unpacking_cast: bool, assign_cast: bool + identifier_manager: IdentifierManager, + src: str, + dest: str, + explicit_cast: Union[bool, str], + unpacking_cast: Union[bool, str], + assign_cast: Union[bool, str], ): - identifier_manager = IdentifierManager() - src_type = parse_type(src) - dest_type = parse_type(dest) + src_type = mark_type_resolved(parse_type(src)) + dest_type = mark_type_resolved(parse_type(dest)) expr = parse_expr("[ap]") - actual_results = [ - check_cast( - src_type=src_type, - dest_type=dest_type, - identifier_manager=identifier_manager, - expr=expr, - cast_type=cast_type, - ) - for cast_type in [CastType.EXPLICIT, CastType.UNPACKING, CastType.ASSIGN] - ] - expected_results = [explicit_cast, unpacking_cast, assign_cast] - assert actual_results == expected_results + for cast_type, expected_result in zip( + [CastType.EXPLICIT, CastType.UNPACKING, CastType.ASSIGN], + [explicit_cast, unpacking_cast, assign_cast], + ): + error_message = expected_result if isinstance(expected_result, str) else None + with maybe_raises(CairoTypeError, error_message): + actual_result = check_cast( + src_type=src_type, + dest_type=dest_type, + identifier_manager=identifier_manager, + cast_type=cast_type, + location=None, + expr=expr, + ) + assert actual_result == expected_result diff --git a/src/starkware/cairo/lang/compiler/type_system.py b/src/starkware/cairo/lang/compiler/type_system.py index 0fa92943..2bc9bb5c 100644 --- a/src/starkware/cairo/lang/compiler/type_system.py +++ b/src/starkware/cairo/lang/compiler/type_system.py @@ -2,6 +2,7 @@ from starkware.cairo.lang.compiler.ast.cairo_types import ( CairoType, + TypeCodeoffset, TypeFelt, TypePointer, TypeStruct, @@ -16,7 +17,7 @@ def mark_type_resolved(cairo_type: CairoType) -> CairoType: Marks the given type as resolved (struct names are absolute). This function can be used after parsing a string which is known to contain resolved types. """ - if isinstance(cairo_type, TypeFelt): + if isinstance(cairo_type, (TypeFelt, TypeCodeoffset)): return cairo_type elif isinstance(cairo_type, TypePointer): return dataclasses.replace(cairo_type, pointee=mark_type_resolved(cairo_type.pointee)) @@ -26,7 +27,11 @@ def mark_type_resolved(cairo_type: CairoType) -> CairoType: return dataclasses.replace(cairo_type, is_fully_resolved=True) elif isinstance(cairo_type, TypeTuple): return dataclasses.replace( - cairo_type, members=[mark_type_resolved(member) for member in cairo_type.members] + cairo_type, + members=[ + dataclasses.replace(member, typ=mark_type_resolved(member.typ)) + for member in cairo_type.members + ], ) else: raise NotImplementedError(f"Type {type(cairo_type).__name__} is not supported.") @@ -36,14 +41,14 @@ def is_type_resolved(cairo_type: CairoType) -> bool: """ Returns true if the type is resolved (struct names are absolute). """ - if isinstance(cairo_type, TypeFelt): + if isinstance(cairo_type, (TypeFelt, TypeCodeoffset)): return True elif isinstance(cairo_type, TypePointer): return is_type_resolved(cairo_type.pointee) elif isinstance(cairo_type, TypeStruct): return cairo_type.is_fully_resolved elif isinstance(cairo_type, TypeTuple): - return all(map(is_type_resolved, cairo_type.members)) + return all(map(is_type_resolved, cairo_type.types)) else: raise NotImplementedError(f"Type {type(cairo_type).__name__} is not supported.") diff --git a/src/starkware/cairo/lang/compiler/type_system_visitor.py b/src/starkware/cairo/lang/compiler/type_system_visitor.py index 18f838e4..5d6531f7 100644 --- a/src/starkware/cairo/lang/compiler/type_system_visitor.py +++ b/src/starkware/cairo/lang/compiler/type_system_visitor.py @@ -3,6 +3,7 @@ from starkware.cairo.lang.compiler.ast.cairo_types import ( CairoType, + TypeCodeoffset, TypeFelt, TypePointer, TypeStruct, @@ -19,6 +20,7 @@ ExprHint, ExprIdentifier, ExprNeg, + ExprNewOperator, ExprOperator, ExprParentheses, ExprReg, @@ -33,6 +35,7 @@ IdentifierAwareVisitor, ) from starkware.cairo.lang.compiler.type_casts import CairoTypeError, check_cast +from starkware.python.utils import safe_zip def get_expr_addr(expr: Expression): @@ -56,8 +59,9 @@ def visit_ExprConst(self, expr: ExprConst) -> Tuple[ExprConst, TypeFelt]: def visit_ExprHint(self, expr: ExprHint) -> Tuple[ExprHint, TypeFelt]: return expr, TypeFelt(location=expr.location) - def visit_ExprFutureLabel(self, expr: ExprFutureLabel) -> Tuple[ExprFutureLabel, TypeFelt]: - return expr, TypeFelt(location=expr.identifier.location) + def visit_ExprFutureLabel(self, expr: ExprFutureLabel) -> Tuple[ExprFutureLabel, CairoType]: + type_cls = TypeCodeoffset if expr.is_typed else TypeFelt + return (dataclasses.replace(expr, is_typed=False), type_cls(location=expr.location)) def visit_ExprIdentifier(self, expr: ExprIdentifier) -> Tuple[Expression, CairoType]: raise CairoTypeError( @@ -79,7 +83,7 @@ def visit_ExprOperator(self, expr: ExprOperator) -> Tuple[ExprOperator, CairoTyp result_type = a_type elif isinstance(a_type, TypeFelt) and isinstance(b_type, TypePointer) and op == "+": result_type = b_type - elif isinstance(a_type, TypePointer) and a_type == b_type and op == "-": + elif isinstance(a_type, (TypePointer, TypeCodeoffset)) and a_type == b_type and op == "-": result_type = TypeFelt(location=expr.location) else: raise CairoTypeError( @@ -160,7 +164,7 @@ def visit_ExprSubscript(self, expr: ExprSubscript) -> Tuple[Expression, CairoTyp location=expr.location, ) - item_type = inner_type.members[offset_value] + item_type = inner_type.members[offset_value].typ if isinstance(inner_expr, ExprTuple): assert len(inner_expr.members.args) == tuple_len @@ -175,7 +179,7 @@ def visit_ExprSubscript(self, expr: ExprSubscript) -> Tuple[Expression, CairoTyp # Handles pointers cast as tuples*, e.g. `[cast(ap, (felt, felt)*][0]`. addr = inner_expr.addr offset_in_felts = ExprConst( - val=sum(map(self.get_size, inner_type.members[:offset_value])), + val=sum(map(self.get_size, inner_type.types[:offset_value])), location=offset_expr.location, ) addr_with_offset = ExprOperator( @@ -230,7 +234,7 @@ def visit_ExprDot(self, expr: ExprDot) -> Tuple[ExprDeref, CairoType]: inner_expr, inner_type = self.visit(expr.expr) if isinstance(inner_type, TypePointer): - if not isinstance(inner_type.pointee, TypeStruct): + if not isinstance(inner_type.pointee, (TypeStruct, TypeTuple)): raise CairoTypeError( f"Cannot apply dot-operator to pointer-to-non-struct type " f"'{inner_type.format()}'.", @@ -238,10 +242,10 @@ def visit_ExprDot(self, expr: ExprDot) -> Tuple[ExprDeref, CairoType]: ) # Allow for . as ->, once. inner_type = inner_type.pointee - elif isinstance(inner_type, TypeStruct): + elif isinstance(inner_type, (TypeStruct, TypeTuple)): if isinstance(inner_expr, ExprTuple): raise CairoTypeError( - "Accessing struct members for r-value structs is not supported yet.", + "Accessing struct/tuple members for r-value structs is not supported yet.", location=expr.location, ) # Get the address, to evaluate . as ->. @@ -252,22 +256,45 @@ def visit_ExprDot(self, expr: ExprDot) -> Tuple[ExprDeref, CairoType]: location=expr.location, ) - try: - struct_def = get_struct_definition( - struct_name=inner_type.resolved_scope, identifier_manager=self.identifiers - ) - except Exception as exc: - raise CairoTypeError(str(exc), location=expr.location) + if isinstance(inner_type, TypeStruct): + try: + struct_def = get_struct_definition( + struct_name=inner_type.resolved_scope, identifier_manager=self.identifiers + ) + except Exception as exc: + raise CairoTypeError(str(exc), location=expr.location) - if expr.member.name not in struct_def.members: - raise CairoTypeError( - f"Member '{expr.member.name}' does not appear in definition of struct " - f"'{inner_type.format()}'.", - location=expr.location, - ) - member_definition = struct_def.members[expr.member.name] - member_type = member_definition.cairo_type - member_offset = member_definition.offset + if expr.member.name not in struct_def.members: + raise CairoTypeError( + f"Member '{expr.member.name}' does not appear in definition of struct " + f"'{inner_type.format()}'.", + location=expr.location, + ) + + member_definition = struct_def.members[expr.member.name] + member_type = member_definition.cairo_type + member_offset = member_definition.offset + else: + if isinstance(inner_type, TypeTuple) and not inner_type.is_named: + raise CairoTypeError( + f"Cannot apply dot-operator to unnamed tuple type '{inner_type.format()}'.", + location=expr.location, + ) + + assert isinstance(inner_type, TypeTuple) + member_offset = 0 + for i, member in enumerate(inner_type.members): + if member.name == expr.member.name: + member_type = member.typ + break + + member_offset += self.get_size(member.typ) + else: + raise CairoTypeError( + f"Member '{expr.member.name}' does not appear in definition of tuple type " + f"'{inner_type.format()}'.", + location=expr.location, + ) if member_offset == 0: simplified_expr = ExprDeref(addr=inner_expr, location=expr.location) @@ -288,8 +315,9 @@ def visit_ExprCast(self, expr: ExprCast) -> Tuple[Expression, CairoType]: src_type=src_type, dest_type=dest_type, identifier_manager=self.identifiers, - expr=inner_expr, cast_type=expr.cast_type, + location=expr.location, + expr=inner_expr, ): raise CairoTypeError( f"Cannot cast '{src_type.format()}' to '{dest_type.format()}'.", @@ -303,17 +331,40 @@ def visit_ExprTuple(self, expr: ExprTuple) -> Tuple[ExprTuple, TypeTuple]: args = expr.members.args # Call visit on each member to obtain a list of the form (expr, type). member_expr_types = [self.visit(arg.expr) for arg in args] + # Replace each tuple item with the corresponding type-simplified expression from + # member_expr_types, and remove the name (as the name is part of the type rather than + # the type-simplified expression). result_members = [ - dataclasses.replace(arg, expr=expr) for arg, (expr, _) in zip(args, member_expr_types) + dataclasses.replace(arg, identifier=None, expr=expr) + for arg, (expr, _) in zip(args, member_expr_types) ] result_expr = dataclasses.replace( expr, members=dataclasses.replace(expr.members, args=result_members) ) - cairo_type = TypeTuple( - members=[expr_type for expr, expr_type in member_expr_types], location=expr.location + # Construct the resulting type. Include the names of the tuple items in the returned + # named tuple type. + cairo_type = TypeTuple.from_members( + members=[ + TypeTuple.Item( + name=(None if arg.identifier is None else arg.identifier.name), + typ=expr_type, + location=arg.location, + ) + for arg, (expr, expr_type) in safe_zip(args, member_expr_types) + ], + location=expr.location, ) return result_expr, cairo_type + def visit_ExprNewOperator(self, expr: ExprNewOperator) -> Tuple[ExprNewOperator, CairoType]: + inner_expr, inner_expr_type = self.visit(expr.expr) + expr_type = ( + TypePointer(pointee=inner_expr_type, location=expr.location) + if expr.is_typed + else TypeFelt(location=expr.location) + ) + return ExprNewOperator(expr=inner_expr, is_typed=False, location=expr.location), expr_type + def simplify_type_system( expr: Expression, identifiers: Optional[IdentifierManager] = None diff --git a/src/starkware/cairo/lang/compiler/type_system_visitor_test.py b/src/starkware/cairo/lang/compiler/type_system_visitor_test.py index d0f9124f..8a67b143 100644 --- a/src/starkware/cairo/lang/compiler/type_system_visitor_test.py +++ b/src/starkware/cairo/lang/compiler/type_system_visitor_test.py @@ -4,14 +4,10 @@ import pytest from starkware.cairo.lang.compiler.ast.ast_objects_test_utils import remove_parentheses -from starkware.cairo.lang.compiler.ast.cairo_types import ( - CairoType, - TypeFelt, - TypePointer, - TypeStruct, - TypeTuple, -) +from starkware.cairo.lang.compiler.ast.cairo_types import TypeFelt, TypePointer, TypeStruct +from starkware.cairo.lang.compiler.ast.expr import ExprFutureLabel, ExprNewOperator from starkware.cairo.lang.compiler.ast_objects_test import remove_parentheses +from starkware.cairo.lang.compiler.expression_transformer import ExpressionTransformer from starkware.cairo.lang.compiler.identifier_definition import MemberDefinition, StructDefinition from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager from starkware.cairo.lang.compiler.parser import parse_expr @@ -22,58 +18,79 @@ scope = ScopedName.from_string +class FixIsTypedVisitor(ExpressionTransformer): + """ + A utility class the sets expr.is_typed to false. + + This is useful for getting parsed expressions with is_typed = False to compare + against type system simplified expressions. + """ + + def visit_ExprNewOperator(self, expr: ExprNewOperator) -> ExprNewOperator: + return ExprNewOperator( + expr=self.visit(expr=expr.expr), + is_typed=False, + location=self.location_modifier(expr.location), + ) + + def visit_ExprFutureLabel(self, expr: ExprFutureLabel) -> ExprFutureLabel: + return ExprFutureLabel( + identifier=self.visit(expr.identifier), + is_typed=False, + location=self.location_modifier(expr.location), + ) + + def simplify_type_system_test( orig_expr: str, simplified_expr: str, - simplified_type: CairoType, + simplified_type: str, identifiers: Optional[IdentifierManager] = None, ): parsed_expr = mark_types_in_expr_resolved(parse_expr(orig_expr)) - assert simplify_type_system(parsed_expr, identifiers=identifiers) == ( - parse_expr(simplified_expr), - simplified_type, + expr, typ = simplify_type_system(parsed_expr, identifiers=identifiers) + expected_expr = FixIsTypedVisitor().visit( + expr=remove_parentheses(expr=parse_expr(simplified_expr)) ) + assert expr == expected_expr, f"{expr.format()} != {expected_expr.format()}" + assert typ.format() == simplified_type def test_type_visitor(): - t = TypeStruct(scope=scope("T"), is_fully_resolved=True) - t_star = TypePointer(pointee=t) - t_star2 = TypePointer(pointee=t_star) - - simplify_type_system_test("fp + 3 + [ap]", "fp + 3 + [ap]", TypeFelt()) - simplify_type_system_test("cast(fp + 3 + [ap], T*)", "fp + 3 + [ap]", t_star) + simplify_type_system_test("fp + 3 + [ap]", "fp + 3 + [ap]", "felt") + simplify_type_system_test("cast(fp + 3 + [ap], T*)", "fp + 3 + [ap]", "T*") # Two casts. - simplify_type_system_test("cast(cast(fp, T*), felt)", "fp", TypeFelt()) + simplify_type_system_test("cast(cast(fp, T*), felt)", "fp", "felt") # Cast from T to T. - simplify_type_system_test("cast([cast(fp, T*)], T)", "[fp]", t) + simplify_type_system_test("cast([cast(fp, T*)], T)", "[fp]", "T") # Dereference. - simplify_type_system_test("[cast(fp, T**)]", "[fp]", t_star) - simplify_type_system_test("[[cast(fp, T**)]]", "[[fp]]", t) + simplify_type_system_test("[cast(fp, T**)]", "[fp]", "T*") + simplify_type_system_test("[[cast(fp, T**)]]", "[[fp]]", "T") # Address of. - simplify_type_system_test("&([[cast(fp, T**)]])", "[fp]", t_star) - simplify_type_system_test("&&[[cast(fp, T**)]]", "fp", t_star2) + simplify_type_system_test("&([[cast(fp, T**)]])", "[fp]", "T*") + simplify_type_system_test("&&[[cast(fp, T**)]]", "fp", "T**") def test_type_tuples(): - t = TypeStruct(scope=scope("T"), is_fully_resolved=True) - t_star = TypePointer(pointee=t) - # Simple tuple. simplify_type_system_test( "(fp, [cast(fp, T*)], cast(fp,T*))", "(fp, [fp], fp)", - TypeTuple( - members=[TypeFelt(), t, t_star], - ), + "(felt, T, T*)", + ) + + # Named tuple. + simplify_type_system_test( + "(a=fp, b=[cast(fp, T*)], c=cast(fp,T*))", + "(fp, [fp], fp)", + "(a : felt, b : T, c : T*)", ) # Nested. simplify_type_system_test( "(fp, (), ([cast(fp, T*)],))", "(fp, (), ([fp],))", - TypeTuple( - members=[TypeFelt(), TypeTuple(members=[]), TypeTuple(members=[t])], - ), + "(felt, (), (T))", ) @@ -93,7 +110,7 @@ def test_type_tuples_failures(): verify_exception( "1 + cast((1, 2), T).x", """ -file:?:?: Accessing struct members for r-value structs is not supported yet. +file:?:?: Accessing struct/tuple members for r-value structs is not supported yet. 1 + cast((1, 2), T).x ^***************^ """, @@ -102,18 +119,16 @@ def test_type_tuples_failures(): def test_type_subscript_op(): - felt_star_star = TypePointer(pointee=TypePointer(pointee=TypeFelt())) t = TypeStruct(scope=scope("T"), is_fully_resolved=True) - t_star = TypePointer(pointee=t) identifier_dict = {scope("T"): StructDefinition(full_name=scope("T"), members={}, size=7)} identifiers = IdentifierManager.from_dict(identifier_dict) - simplify_type_system_test("cast(fp, felt*)[3]", "[fp + 3 * 1]", TypeFelt()) - simplify_type_system_test("cast(fp, felt***)[0]", "[fp + 0 * 1]", felt_star_star) - simplify_type_system_test("[cast(fp, T****)][ap][ap]", "[[[fp] + ap * 1] + ap * 1]", t_star) + simplify_type_system_test("cast(fp, felt*)[3]", "[fp + 3 * 1]", "felt") + simplify_type_system_test("cast(fp, felt***)[0]", "[fp + 0 * 1]", "felt**") + simplify_type_system_test("[cast(fp, T****)][ap][ap]", "[[[fp] + ap * 1] + ap * 1]", "T*") simplify_type_system_test( - "cast(fp, T**)[1][2]", "[[fp + 1 * 1] + 2 * 7]", t, identifiers=identifiers + "cast(fp, T**)[1][2]", "[[fp + 1 * 1] + 2 * 7]", "T", identifiers=identifiers ) # Test that 'cast(fp, T*)[2 * ap + 3]' simplifies into '[fp + (2 * ap + 3) * 7]', but without @@ -123,12 +138,12 @@ def test_type_subscript_op(): ) == (remove_parentheses(parse_expr("[fp + (2 * ap + 3) * 7]")), t) # Test subscript operator for tuples. - simplify_type_system_test("(cast(fp, felt**), fp, cast(fp, T*))[2]", "fp", t_star) - simplify_type_system_test("(cast(fp, felt**), fp, cast(fp, T*))[0]", "fp", felt_star_star) - simplify_type_system_test("(cast(fp, felt**), ap, cast(fp, T*))[3*4 - 11]", "ap", TypeFelt()) - simplify_type_system_test("[cast(ap, (felt, felt)*)][0]", "[ap + 0]", TypeFelt()) + simplify_type_system_test("(cast(fp, felt**), fp, cast(fp, T*))[2]", "fp", "T*") + simplify_type_system_test("(cast(fp, felt**), fp, cast(fp, T*))[0]", "fp", "felt**") + simplify_type_system_test("(cast(fp, felt**), ap, cast(fp, T*))[3*4 - 11]", "ap", "felt") + simplify_type_system_test("[cast(ap, (felt, felt)*)][0]", "[ap + 0]", "felt") simplify_type_system_test( - "[cast(ap, (T*, T, felt, T*, felt*)*)][3]", "[ap + 9]", t_star, identifiers=identifiers + "[cast(ap, (T*, T, felt, T*, felt*)*)][3]", "[ap + 9]", "T*", identifiers=identifiers ) # Test failures. @@ -272,23 +287,23 @@ def test_type_dot_op(): identifiers = IdentifierManager.from_dict(identifier_dict) for (orig_expr, simplified_expr, simplified_type) in [ - ("[cast(fp, T*)].t", "[fp]", TypeFelt()), - ("[cast(fp, T*)].s", "[fp + 1]", s), - ("[cast(fp, T*)].sp", "[fp + 3]", s_star), - ("[cast(fp, T*)].s.x", "[fp + 1]", TypeFelt()), - ("[cast(fp, T*)].s.y", "[fp + 1 + 1]", TypeFelt()), - ("[[cast(fp, T*)].sp].x", "[[fp + 3]]", TypeFelt()), - ("[cast(fp, R*)]", "[fp]", r), - ("[cast(fp, R*)].r", "[fp]", r_star), - ("[[[cast(fp, R*)].r].r].r", "[[[fp]]]", r_star), + ("[cast(fp, T*)].t", "[fp]", "felt"), + ("[cast(fp, T*)].s", "[fp + 1]", "S"), + ("[cast(fp, T*)].sp", "[fp + 3]", "S*"), + ("[cast(fp, T*)].s.x", "[fp + 1]", "felt"), + ("[cast(fp, T*)].s.y", "[fp + 1 + 1]", "felt"), + ("[[cast(fp, T*)].sp].x", "[[fp + 3]]", "felt"), + ("[cast(fp, R*)]", "[fp]", "R"), + ("[cast(fp, R*)].r", "[fp]", "R*"), + ("[[[cast(fp, R*)].r].r].r", "[[[fp]]]", "R*"), # Test . as -> - ("cast(fp, T*).t", "[fp]", TypeFelt()), - ("cast(fp, T*).sp.y", "[[fp + 3] + 1]", TypeFelt()), - ("cast(fp, R*).r.r.r", "[[[fp]]]", r_star), + ("cast(fp, T*).t", "[fp]", "felt"), + ("cast(fp, T*).sp.y", "[[fp + 3] + 1]", "felt"), + ("cast(fp, R*).r.r.r", "[[[fp]]]", "R*"), # More tests. - ("(cast(fp, T*).s)", "[fp + 1]", s), - ("(cast(fp, T*).s).x", "[fp + 1]", TypeFelt()), - ("(&(cast(fp, T*).s)).x", "[fp + 1]", TypeFelt()), + ("(cast(fp, T*).s)", "[fp + 1]", "S"), + ("(cast(fp, T*).s).x", "[fp + 1]", "felt"), + ("(&(cast(fp, T*).s)).x", "[fp + 1]", "felt"), ]: simplify_type_system_test( orig_expr, simplified_expr, simplified_type, identifiers=identifiers @@ -358,6 +373,58 @@ def test_type_dot_op(): ) +def test_type_dot_op_named_tuples(): + """ + Tests type_system_visitor for ExprDot-s for named tuples. + """ + identifiers = IdentifierManager() + tuple_ref = "[cast(fp, (x : felt, y : (a : felt, b : felt), z : felt)*)]" + tuple_ptr = "cast(fp, (x : felt, y : (a : felt, b : felt)*, z : felt)*)" + for (orig_expr, simplified_expr, simplified_type) in [ + (f"{tuple_ref}.x", "[fp]", "felt"), + (f"{tuple_ref}.y", "[fp + 1]", "(a : felt, b : felt)"), + (f"{tuple_ref}.y.a", "[fp + 1]", "felt"), + (f"{tuple_ref}.y.b", "[fp + 1 + 1]", "felt"), + (f"{tuple_ref}.z", "[fp + 3]", "felt"), + # Test . as -> + (f"{tuple_ptr}.y.b", "[[fp + 1] + 1]", "felt"), + ]: + simplify_type_system_test( + orig_expr, simplified_expr, simplified_type, identifiers=identifiers + ) + + # Test failures. + + verify_exception( + "[cast(fp, (felt, felt)*)].x", + """ +file:?:?: Cannot apply dot-operator to unnamed tuple type '(felt, felt)'. +[cast(fp, (felt, felt)*)].x +^*************************^ +""", + identifiers=identifiers, + ) + verify_exception( + "[cast(fp, (a : felt, b : felt)*)].x", + """ +file:?:?: Member 'x' does not appear in definition of tuple type '(a : felt, b : felt)'. +[cast(fp, (a : felt, b : felt)*)].x +^*********************************^ +""", + identifiers=identifiers, + ) + + verify_exception( + "(x=1, y=(a=2,b=3), z=4).y.b", + """ +file:?:?: Accessing struct/tuple members for r-value structs is not supported yet. +(x=1, y=(a=2,b=3), z=4).y.b +^***********************^ +""", + identifiers=identifiers, + ) + + def test_type_visitor_failures(): verify_exception( "[cast(fp, T*)] + 3", @@ -394,12 +461,15 @@ def test_type_visitor_failures(): def test_type_visitor_pointer_arithmetic(): - t = TypeStruct(scope=scope("T"), is_fully_resolved=True) - t_star = TypePointer(pointee=t) + simplify_type_system_test("cast(fp, T*) + 3", "fp + 3", "T*") + simplify_type_system_test("cast(fp, T*) - 3", "fp - 3", "T*") + simplify_type_system_test("cast(fp, T*) - cast(3, T*)", "fp - 3", "felt") + - simplify_type_system_test("cast(fp, T*) + 3", "fp + 3", t_star) - simplify_type_system_test("cast(fp, T*) - 3", "fp - 3", t_star) - simplify_type_system_test("cast(fp, T*) - cast(3, T*)", "fp - 3", TypeFelt()) +def test_type_visitor_new_operator(): + simplify_type_system_test("new (3 + 4)", "new (3 + 4)", "felt*") + simplify_type_system_test("new [cast(fp, T*)]", "new [fp]", "T*") + simplify_type_system_test("new (ap, cast(fp, felt*))", "new (ap, fp)", "(felt, felt*)*") def test_type_visitor_pointer_arithmetic_failures(): diff --git a/src/starkware/cairo/lang/compiler/type_utils.py b/src/starkware/cairo/lang/compiler/type_utils.py index 33710a2e..3545f93b 100644 --- a/src/starkware/cairo/lang/compiler/type_utils.py +++ b/src/starkware/cairo/lang/compiler/type_utils.py @@ -31,7 +31,7 @@ def check_felts_only_type( return size elif isinstance(cairo_type, TypeTuple): size = 0 - for item_type in cairo_type.members: + for item_type in cairo_type.types: res = check_felts_only_type(item_type, identifier_manager=identifier_manager) if res is None: return None diff --git a/src/starkware/cairo/lang/ide/vim/syntax/cairo.vim b/src/starkware/cairo/lang/ide/vim/syntax/cairo.vim index c1983b69..77fd2aa7 100644 --- a/src/starkware/cairo/lang/ide/vim/syntax/cairo.vim +++ b/src/starkware/cairo/lang/ide/vim/syntax/cairo.vim @@ -19,7 +19,8 @@ hi def link num Constant hi def link specialIdentifier Special syn keyword statement call jmp ret abs rel if const let end from import static_assert local tempvar - \ felt return assert member cast else alloc_locals as with with_attr nondet dw + \ felt return assert member cast else alloc_locals as with with_attr nondet dw codeoffset new + \ using syn keyword register ap fp syn keyword specialIdentifier SIZEOF_LOCALS SIZE syn match comment '#[^\n]*\n' diff --git a/src/starkware/cairo/lang/ide/vscode-cairo/package.json b/src/starkware/cairo/lang/ide/vscode-cairo/package.json index f414fe1a..572f3ee3 100644 --- a/src/starkware/cairo/lang/ide/vscode-cairo/package.json +++ b/src/starkware/cairo/lang/ide/vscode-cairo/package.json @@ -2,7 +2,7 @@ "name": "cairo", "displayName": "Cairo", "description": "Support Cairo syntax", - "version": "0.7.1", + "version": "0.8.0", "engines": { "vscode": "^1.30.0" }, diff --git a/src/starkware/cairo/lang/ide/vscode-cairo/syntaxes/cairo.tmLanguage.json b/src/starkware/cairo/lang/ide/vscode-cairo/syntaxes/cairo.tmLanguage.json index 0a767ef0..1339a383 100644 --- a/src/starkware/cairo/lang/ide/vscode-cairo/syntaxes/cairo.tmLanguage.json +++ b/src/starkware/cairo/lang/ide/vscode-cairo/syntaxes/cairo.tmLanguage.json @@ -88,7 +88,7 @@ }, { "name": "keyword.other.meta", - "match": "\\b(const|let|local|tempvar|felt|as|from|import|static_assert|return|assert|member|cast|alloc_locals|with|with_attr|nondet|dw)\\b" + "match": "\\b(const|let|local|tempvar|felt|as|from|import|static_assert|return|assert|member|cast|alloc_locals|with|with_attr|nondet|dw|codeoffset|new|using)\\b" }, { "name": "markup.italic", diff --git a/src/starkware/cairo/lang/lang.cmake b/src/starkware/cairo/lang/lang.cmake index 3c4e6bef..d34f545b 100644 --- a/src/starkware/cairo/lang/lang.cmake +++ b/src/starkware/cairo/lang/lang.cmake @@ -36,6 +36,7 @@ python_venv(cairo_lang_package_venv sharp_client_config_lib sharp_client_lib starknet_block_hash_lib + starknet_business_logic_lib starknet_script_lib starknet_testing_lib starkware_eth_test_utils_lib diff --git a/src/starkware/cairo/lang/scripts/cairo-hash-program b/src/starkware/cairo/lang/scripts/cairo-hash-program index beafc0c3..167d8f8c 100755 --- a/src/starkware/cairo/lang/scripts/cairo-hash-program +++ b/src/starkware/cairo/lang/scripts/cairo-hash-program @@ -4,7 +4,7 @@ import os import sys sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../../..')) -from starkware.cairo.bootloader.hash_program import main # noqa +from starkware.cairo.bootloaders.hash_program import main # noqa if __name__ == '__main__': sys.exit(main()) diff --git a/src/starkware/cairo/lang/setup.py b/src/starkware/cairo/lang/setup.py index f16f7838..f8ff75fd 100644 --- a/src/starkware/cairo/lang/setup.py +++ b/src/starkware/cairo/lang/setup.py @@ -22,12 +22,12 @@ url="https://cairo-lang.org/", package_data={ "starkware.cairo.common": ["*.cairo"], - "starkware.cairo.lang.compiler": ["cairo.ebnf"], + "starkware.cairo.lang.compiler": ["cairo.ebnf", "lib/*.cairo"], "starkware.cairo.lang.tracer": ["*.html", "*.css", "*.js", "*.png"], "starkware.cairo.lang": ["VERSION"], "starkware.cairo.sharp": ["config.json"], "starkware.crypto.signature": ["pedersen_params.json"], - "starkware.starknet": ["common/*.cairo"], + "starkware.starknet": ["common/*.cairo", "definitions/*.yml"], "starkware.starknet.core.os": ["*.cairo", "*.json"], "starkware.starknet.security": ["whitelists/*.json"], "starkware.starknet.testing": ["*.json"], diff --git a/src/starkware/cairo/lang/vm/CMakeLists.txt b/src/starkware/cairo/lang/vm/CMakeLists.txt index 183ac1a1..41d513f0 100644 --- a/src/starkware/cairo/lang/vm/CMakeLists.txt +++ b/src/starkware/cairo/lang/vm/CMakeLists.txt @@ -45,6 +45,7 @@ python_lib(cairo_vm_lib cairo_compile_lib cairo_relocatable_lib cairo_vm_crypto_lib + starknet_security_lib starkware_python_utils_lib ) @@ -104,10 +105,13 @@ full_python_test(cairo_vm_test vm_test.py LIBS + cairo_common_lib + cairo_compile_lib cairo_constants_lib cairo_run_lib cairo_vm_lib cairo_vm_utils_lib + starkware_dataclasses_utils_lib starkware_python_utils_lib starkware_python_test_utils_lib pip_marshmallow_dataclass diff --git a/src/starkware/cairo/lang/vm/air_public_input.py b/src/starkware/cairo/lang/vm/air_public_input.py index e5355b89..aa2f7b04 100644 --- a/src/starkware/cairo/lang/vm/air_public_input.py +++ b/src/starkware/cairo/lang/vm/air_public_input.py @@ -1,9 +1,11 @@ +import re from dataclasses import field -from typing import ClassVar, Dict, List, Type +from typing import ClassVar, Dict, List, Tuple, Type import marshmallow import marshmallow_dataclass +from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME from starkware.cairo.lang.vm.utils import IntAsHex, MemorySegmentAddresses @@ -53,3 +55,46 @@ def extract_program_output(public_input: PublicInput, memory: Dict[int, int]) -> output_addresses = public_input.memory_segments["output"] assert output_addresses.stop_ptr is not None, "Missing stop_ptr for the output segment." return [memory[addr] for addr in range(output_addresses.begin_addr, output_addresses.stop_ptr)] + + +def get_pages_and_products( + public_memory: List[PublicMemoryEntry], z: int, alpha: int +) -> Tuple[Dict[int, List[int]], Dict[int, int]]: + """ + Rearranges memory entries of the public memory by pages. + Returns a tuple: (page, page_prods). + * pages: a dictionary from page id to a list of interleaved addresses and values. + * page_prods: a dictionary from page_id to the product of the page: + * \prod_i (z - (address_i + alpha * value_i)) + """ + pages: Dict[int, List[int]] = {} + page_prods: Dict[int, int] = {} + for cell in public_memory: + page_id, addr, val = cell.page, cell.address, cell.value + page = pages.setdefault(page_id, []) + page.append(addr) + page.append(val) + page_prods[page_id] = ( + page_prods.get(page_id, 1) * (z - (addr + alpha * val)) + ) % DEFAULT_PRIME + return pages, page_prods + + +def extract_z_and_alpha(annotations: List[str]) -> Tuple[int, int]: + """ + Extracts the interaction elements z and alpha from the proof annotations. + Returns (z, alpha) + """ + interaction_elements = [ + int(x, 16) + for x in re.findall( + r"V->P: /cpu air/STARK/Interaction: Interaction element #\d+: " + r"Field Element\(0x([0-9a-f]+)\)", + "\n".join(annotations), + ) + ] + # Make sure the number of interaction_elements is as expected - z, alpha for the memory and + # z' for the permutation range-check and possibly 3 additional elements for the diluted logic. + assert len(interaction_elements) in [3, 6] + z, alpha = interaction_elements[:2] + return z, alpha diff --git a/src/starkware/cairo/lang/vm/builtin_runner.py b/src/starkware/cairo/lang/vm/builtin_runner.py index 9b26407c..9285c5be 100644 --- a/src/starkware/cairo/lang/vm/builtin_runner.py +++ b/src/starkware/cairo/lang/vm/builtin_runner.py @@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple from starkware.cairo.lang.vm.relocatable import MaybeRelocatable, RelocatableValue -from starkware.cairo.lang.vm.utils import MemorySegmentAddresses +from starkware.cairo.lang.vm.utils import MemorySegmentRelocatableAddresses from starkware.python.math_utils import div_ceil, safe_div @@ -63,10 +63,10 @@ def finalize_segments(self, runner): """ @abstractmethod - def get_memory_segment_addresses(self, runner) -> Dict[str, MemorySegmentAddresses]: + def get_memory_segment_addresses(self, runner) -> Dict[str, MemorySegmentRelocatableAddresses]: """ - Returns a dict from segment name to MemorySegmentAddresses (begin_addr and stop_ptr of the - corresponding segment). + Returns a dict from segment name to MemorySegmentRelocatableAddresses + (begin_addr and stop_ptr of the corresponding segment). """ @abstractmethod @@ -179,16 +179,20 @@ def __init__( self.name = name self.included = included self.ratio = ratio - self.base: Optional[RelocatableValue] = None + self._base: Optional[RelocatableValue] = None self.stop_ptr: Optional[RelocatableValue] = None self.cells_per_instance = cells_per_instance self.n_input_cells = n_input_cells def initialize_segments(self, runner): - self.base = runner.segments.add() + self._base = runner.segments.add() + + @property + def base(self) -> RelocatableValue: + assert self._base is not None, "Uninitialized self.base." + return self._base def initial_stack(self) -> List[MaybeRelocatable]: - assert self.base is not None, "Uninitialized self.base." return [self.base] if self.included else [] def final_stack(self, runner, pointer): @@ -234,7 +238,7 @@ def finalize_segments(self, runner): def get_memory_segment_addresses(self, runner): return { - self.name: MemorySegmentAddresses( + self.name: MemorySegmentRelocatableAddresses( begin_addr=self.base, stop_ptr=self.stop_ptr, ) diff --git a/src/starkware/cairo/lang/vm/cairo_pie.py b/src/starkware/cairo/lang/vm/cairo_pie.py index 44669a3c..cb64ebf6 100644 --- a/src/starkware/cairo/lang/vm/cairo_pie.py +++ b/src/starkware/cairo/lang/vm/cairo_pie.py @@ -17,6 +17,7 @@ from starkware.cairo.lang.compiler.program import StrippedProgram, is_valid_builtin_name from starkware.cairo.lang.vm.memory_dict import MemoryDict +from starkware.cairo.lang.vm.memory_segments import is_valid_memory_addr, is_valid_memory_value from starkware.cairo.lang.vm.relocatable import RelocatableValue from starkware.python.utils import add_counters, sub_counters @@ -171,6 +172,12 @@ def empty(cls): def copy(self) -> "ExecutionResources": return copy.deepcopy(self) + def to_dict(self) -> Dict[str, int]: + return dict( + **self.builtin_instance_counter, + n_steps=self.n_steps + self.n_memory_holes, + ) + @dataclasses.dataclass class CairoPie: @@ -289,29 +296,13 @@ def run_validity_checks(self): def run_memory_validity_checks(self): segment_sizes = self.metadata.segment_sizes() - - def is_valid_memory_addr(addr, allow_end_of_segment: bool = False): - """ - Returns True if addr is a relocatable value, such that its segment index appears in - segment_sizes and its offset is in the valid range (if allow_end_of_segment=True, offset - may refer to the next cell *after* the segment). - """ - return ( - isinstance(addr, RelocatableValue) - and isinstance(addr.segment_index, int) - and isinstance(addr.offset, int) - and addr.segment_index in segment_sizes - and 0 - <= addr.offset - < segment_sizes[addr.segment_index] + (1 if allow_end_of_segment else 0) - ) - - def is_valid_memory_value(value): - return isinstance(value, int) or is_valid_memory_addr(value, allow_end_of_segment=True) - for addr, value in self.memory.items(): - assert is_valid_memory_addr(addr), "Invalid memory cell address." - assert is_valid_memory_value(value), f"Invalid memory cell value." + assert is_valid_memory_addr( + addr=addr, segment_sizes=segment_sizes + ), "Invalid memory cell address." + assert is_valid_memory_value( + value=value, segment_sizes=segment_sizes + ), "Invalid memory cell value." @classmethod def verify_zip_format(cls, zf: zipfile.ZipFile): diff --git a/src/starkware/cairo/lang/vm/cairo_pie_test.py b/src/starkware/cairo/lang/vm/cairo_pie_test.py index f416a39f..5d53d0d0 100644 --- a/src/starkware/cairo/lang/vm/cairo_pie_test.py +++ b/src/starkware/cairo/lang/vm/cairo_pie_test.py @@ -1,5 +1,6 @@ import io import random +from typing import Dict import pytest @@ -13,7 +14,8 @@ ) from starkware.cairo.lang.vm.cairo_runner import get_runner_from_code from starkware.cairo.lang.vm.memory_dict import MemoryDict -from starkware.cairo.lang.vm.relocatable import RelocatableValue +from starkware.cairo.lang.vm.memory_segments import SEGMENT_SIZE_UPPER_BOUND +from starkware.cairo.lang.vm.relocatable import MaybeRelocatableDict, RelocatableValue from starkware.python.utils import add_counters @@ -33,12 +35,11 @@ def test_cairo_pie_serialize_deserialize(): }, extra_segments=[], ) - memory = MemoryDict( - { - 1: 2, - RelocatableValue(3, 4): RelocatableValue(6, 7), - } - ) + memory: MaybeRelocatableDict = { + 1: 2, + RelocatableValue(3, 4): RelocatableValue(6, 7), + } + additional_data = {"c": ["d", 3]} execution_resources = ExecutionResources( n_steps=10, @@ -50,7 +51,7 @@ def test_cairo_pie_serialize_deserialize(): ) cairo_pie = CairoPie( metadata=metadata, - memory=memory, + memory=MemoryDict(memory), additional_data=additional_data, execution_resources=execution_resources, ) @@ -133,7 +134,7 @@ def test_cairo_pie_memory_invalid_value(cairo_pie: CairoPie): offset=cairo_pie.metadata.execution_segment.size, ) cairo_pie.memory.unfreeze_for_testing() - cairo_pie.memory[output_end] = output_end + 2 + cairo_pie.memory[output_end] = output_end + SEGMENT_SIZE_UPPER_BOUND # It should fail because the address is outside the segment expected size. with pytest.raises(AssertionError, match="Invalid memory cell address."): cairo_pie.run_validity_checks() @@ -151,7 +152,7 @@ def test_add_execution_resources(): dummy_builtins = ["builtin1", "builtin2", "builtin3", "builtin4"] total_execution_resources = ExecutionResources.empty() - total_builtin_instance_counter = {} + total_builtin_instance_counter: Dict[str, int] = {} total_steps = 0 # Create multiple random ExecutionResources objects, sum them using __ add __() and validate @@ -160,7 +161,7 @@ def test_add_execution_resources(): for _ in range(random_n_execution_resources): # Create an ExecutionResources object with random values (random builtin_instance_counter # and random n_steps). - random_builtin_instance_counter = {} + random_builtin_instance_counter: Dict[str, int] = {} random_n_counters = random.randint(0, 3) for _ in range(random_n_counters): random_builtin_type = random.choice(dummy_builtins) diff --git a/src/starkware/cairo/lang/vm/cairo_run.py b/src/starkware/cairo/lang/vm/cairo_run.py index f92b497e..e5b8bcdf 100644 --- a/src/starkware/cairo/lang/vm/cairo_run.py +++ b/src/starkware/cairo/lang/vm/cairo_run.py @@ -236,6 +236,7 @@ def cairo_run(args): ret_code = 0 cairo_pie_input = None if args.program is not None: + assert args.run_from_cairo_pie is None program: ProgramBase = load_program(args.program) initial_memory = MemoryDict() steps_input = args.steps @@ -270,6 +271,7 @@ def cairo_run(args): end = runner.initialize_main_entrypoint() if args.run_from_cairo_pie is not None: + assert cairo_pie_input is not None # Add extra_segments. for segment_info in cairo_pie_input.metadata.extra_segments: runner.segments.add(size=segment_info.size) diff --git a/src/starkware/cairo/lang/vm/cairo_runner.py b/src/starkware/cairo/lang/vm/cairo_runner.py index 7429a014..ed88853c 100644 --- a/src/starkware/cairo/lang/vm/cairo_runner.py +++ b/src/starkware/cairo/lang/vm/cairo_runner.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, TypeVar, Union +from typing import Any, Dict, List, Mapping, Optional, Sequence, Set, Tuple, Type, TypeVar, Union from starkware.cairo.lang.builtins.bitwise.bitwise_builtin_runner import BitwiseBuiltinRunner from starkware.cairo.lang.builtins.hash.hash_builtin_runner import HashBuiltinRunner @@ -30,7 +30,12 @@ from starkware.cairo.lang.vm.output_builtin_runner import OutputBuiltinRunner from starkware.cairo.lang.vm.relocatable import MaybeRelocatable, RelocatableValue, relocate_value from starkware.cairo.lang.vm.trace_entry import relocate_trace -from starkware.cairo.lang.vm.utils import MemorySegmentAddresses, ResourcesError, RunResources +from starkware.cairo.lang.vm.utils import ( + MemorySegmentAddresses, + MemorySegmentRelocatableAddresses, + ResourcesError, + RunResources, +) from starkware.cairo.lang.vm.vm import RunContext, VirtualMachine, get_perm_range_check_limits from starkware.crypto.signature.signature import inv_mod_curve_size from starkware.python.math_utils import next_power_of_2, safe_div @@ -200,7 +205,8 @@ def initialize_main_entrypoint(self): if self.proof_mode: # Add the dummy last fp and pc to the public memory, so that the verifier can enforce # [fp - 2] = fp. - stack = [self.execution_base + 2, 0] + stack + stack_prefix: List[MaybeRelocatable] = [self.execution_base + 2, 0] + stack = stack_prefix + stack self.execution_public_memory = list(range(len(stack))) assert isinstance( @@ -288,14 +294,14 @@ def run_until_pc(self, addr: MaybeRelocatable, run_resources: Optional[RunResour if self.vm.run_context.pc != addr: raise self.vm.as_vm_exception( - ResourcesError("Error: End of program was not reached"), self.vm.run_context.pc + ResourcesError("Error: End of program was not reached"), with_traceback=False ) def vm_step(self): if self.vm.run_context.pc == self.final_pc: raise self.vm.as_vm_exception( Exception("Error: Execution reached the end of the program."), - self.vm.run_context.pc, + with_traceback=False, ) self.vm.step() @@ -523,7 +529,7 @@ def check_diluted_check_usage(self): ) * safe_div( self.vm.current_step, - builtin_runner.ratio if hasattr(builtin_runner, "ratio") else 1, + getattr(builtin_runner, "ratio", 1), ) for builtin_runner in self.builtin_runners.values() ) @@ -572,8 +578,13 @@ def gen_arg(self, arg, apply_modulo_to_args=True): """ return self.segments.gen_arg(arg=arg, apply_modulo_to_args=apply_modulo_to_args) - def relocate_value(self, value): - return relocate_value(value, self.segment_offsets, self.program.prime) + def relocate_value(self, value: MaybeRelocatable) -> int: + assert self.segment_offsets is not None, "segment_offsets is not initialized." + relocated = relocate_value( + value=value, segment_offsets=self.segment_offsets, prime=self.program.prime + ) + assert isinstance(relocated, int) + return relocated def get_segment_offsets(self) -> Dict[int, int]: assert self.segment_offsets is not None, "segment_offsets is not initialized." @@ -582,12 +593,11 @@ def get_segment_offsets(self) -> Dict[int, int]: def relocate(self): self.segment_offsets = self.segments.relocate_segments() - self.relocated_memory = MemoryDict( - { - self.relocate_value(addr): self.relocate_value(value) - for addr, value in self.vm_memory.items() - } - ) + initializer: Mapping[MaybeRelocatable, MaybeRelocatable] = { + self.relocate_value(addr): self.relocate_value(value) + for addr, value in self.vm_memory.items() + } + self.relocated_memory = MemoryDict(initializer) self.relocated_trace = relocate_trace( self.vm.trace, self.segment_offsets, self.program.prime ) @@ -605,7 +615,7 @@ def get_relocated_debug_info(self): def get_memory_segment_addresses(self) -> Dict[str, MemorySegmentAddresses]: def get_segment_addresses( - name: str, segment_addresses: MemorySegmentAddresses + name: str, segment_addresses: MemorySegmentRelocatableAddresses ) -> MemorySegmentAddresses: stop_ptr = ( segment_addresses.stop_ptr @@ -634,7 +644,9 @@ def print_memory(self, relocated: bool): val = memory[addr] if addr != old_addr + 1: print("\u22ee") - print(f"{addr:<5} {to_field_element(val=val, prime=self.program.prime)}") + if isinstance(val, int): + val = to_field_element(val=val, prime=self.program.prime) + print(f"{addr:<5} {val}") old_addr = addr print() @@ -643,6 +655,7 @@ def print_output(self, output_callback=to_field_element): return output_runner = self.builtin_runners["output_builtin"] + assert isinstance(output_runner, OutputBuiltinRunner) print("Program output:") _, size = output_runner.get_used_cells_and_allocated_size(self) for i in range(size): @@ -710,6 +723,7 @@ def get_builtin_segments_info(self): assert segment_addresses.stop_ptr is not None, f"{name} segment stop ptr is None." segment_index = begin_addr.segment_index segment_size = segment_addresses.stop_ptr - begin_addr + assert isinstance(segment_size, int) assert name not in builtin_segments, f"Builtin segment name collision: {name}." builtin_segments[name] = SegmentInfo(index=segment_index, size=segment_size) return builtin_segments diff --git a/src/starkware/cairo/lang/vm/memory_dict.py b/src/starkware/cairo/lang/vm/memory_dict.py index 42eb5d25..f706221b 100644 --- a/src/starkware/cairo/lang/vm/memory_dict.py +++ b/src/starkware/cairo/lang/vm/memory_dict.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type, Union, cast +from typing import Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Type, Union, cast from starkware.cairo.lang.vm.memory_dict_backend import MemoryDictBackend from starkware.cairo.lang.vm.relocatable import MaybeRelocatable, RelocatableValue @@ -26,7 +26,8 @@ def __init__(self, addr, old_value, new_value): MemoryDictInitializer = Optional[ Union[ - Dict[MaybeRelocatable, MaybeRelocatable], + Mapping[MaybeRelocatable, MaybeRelocatable], + Mapping[int, int], Iterable[Tuple[MaybeRelocatable, MaybeRelocatable]], ] ] diff --git a/src/starkware/cairo/lang/vm/memory_dict_test.py b/src/starkware/cairo/lang/vm/memory_dict_test.py index 91535e65..356d760b 100644 --- a/src/starkware/cairo/lang/vm/memory_dict_test.py +++ b/src/starkware/cairo/lang/vm/memory_dict_test.py @@ -68,7 +68,7 @@ def test_memory_dict_getitem(): def test_memory_dict_check_element(): memory = MemoryDict() with pytest.raises(KeyError, match="must be an int"): - memory["not a number"] = 12 + memory["not a number"] = 12 # type: ignore with pytest.raises(KeyError, match="must be nonnegative"): memory[-12] = 13 with pytest.raises(ValueError, match="The offset of a relocatable value must be nonnegative"): @@ -78,10 +78,11 @@ def test_memory_dict_check_element(): def test_memory_dict_get(): + DEFAULT = 12345 memory = MemoryDict({14: 15}) - assert memory.get(14, "default") == 15 - assert memory.get(1234, "default") == "default" - assert memory.get(-10, "default") == "default" + assert memory.get(14, DEFAULT) == 15 + assert memory.get(1234, DEFAULT) == DEFAULT + assert memory.get(-10, DEFAULT) == DEFAULT # Attempting to read address with a negative offset is ok, it simply returns None. assert memory.get(RelocatableValue(segment_index=10, offset=-2)) is None diff --git a/src/starkware/cairo/lang/vm/memory_segments.py b/src/starkware/cairo/lang/vm/memory_segments.py index f844972c..8578b587 100644 --- a/src/starkware/cairo/lang/vm/memory_segments.py +++ b/src/starkware/cairo/lang/vm/memory_segments.py @@ -1,11 +1,13 @@ from collections import defaultdict -from typing import Dict, Iterable, List, Optional, Sequence, Set, Tuple +from typing import Dict, Iterable, List, NamedTuple, Optional, Sequence, Set, Tuple +from starkware.cairo.lang.compiler.ast.cairo_types import TypeFelt, TypePointer, TypeStruct from starkware.cairo.lang.vm.memory_dict import MemoryDict from starkware.cairo.lang.vm.relocatable import MaybeRelocatable, RelocatableValue from starkware.cairo.lang.vm.vm_exceptions import SecurityError FIRST_MEMORY_ADDR = 1 +SEGMENT_SIZE_UPPER_BOUND = 2 ** 64 class MemorySegmentManager: @@ -158,6 +160,24 @@ def gen_arg(self, arg, apply_modulo_to_args=True) -> MaybeRelocatable: return arg % self.prime return arg + def gen_typed_args(self, args: NamedTuple) -> List[MaybeRelocatable]: + """ + Takes a Cairo typed NamedTuple generated with CairoStructFactory and + returns a Cairo-friendly argument list. + """ + cairo_args = [] + for value, field_type in zip(args, args._field_types.values()): + if field_type is TypePointer or field_type is TypeFelt: + # Pointer or felt. + cairo_args.append(self.gen_arg(arg=value)) + elif field_type is TypeStruct: + # Struct. + cairo_args += self.gen_typed_args(args=value) + else: + raise NotImplementedError(f"{field_type.__name__} is not supported.") + + return cairo_args + def write_arg(self, ptr, arg, apply_modulo_to_args=True): assert isinstance(arg, Iterable) data = [self.gen_arg(arg=x, apply_modulo_to_args=apply_modulo_to_args) for x in arg] @@ -206,3 +226,35 @@ def get_segment_size(self, segment_index: int) -> int: if segment_index in self._segment_sizes else self.get_segment_used_size(segment_index=segment_index) ) + + def is_valid_memory_value(self, value: MaybeRelocatable) -> bool: + assert ( + self._segment_used_sizes is not None + ), "compute_effective_sizes must be called before is_valid_memory_value." + + return is_valid_memory_value(value=value, segment_sizes=self._segment_used_sizes) + + +def is_valid_memory_addr( + addr: MaybeRelocatable, segment_sizes: Dict[int, int], is_concrete_address: bool = True +): + """ + Returns True if addr is a relocatable value, such that its segment index appears in + segment_sizes and its offset is in the valid range (if is_concrete_address=False, offset + may exceed the segment size). + """ + return ( + isinstance(addr, RelocatableValue) + and isinstance(addr.segment_index, int) + and isinstance(addr.offset, int) + and addr.segment_index in segment_sizes + and 0 + <= addr.offset + < (segment_sizes[addr.segment_index] if is_concrete_address else SEGMENT_SIZE_UPPER_BOUND) + ) + + +def is_valid_memory_value(value: MaybeRelocatable, segment_sizes: Dict[int, int]): + return isinstance(value, int) or is_valid_memory_addr( + addr=value, segment_sizes=segment_sizes, is_concrete_address=False + ) diff --git a/src/starkware/cairo/lang/vm/memory_segments_test.py b/src/starkware/cairo/lang/vm/memory_segments_test.py index a020d9c1..810637da 100644 --- a/src/starkware/cairo/lang/vm/memory_segments_test.py +++ b/src/starkware/cairo/lang/vm/memory_segments_test.py @@ -1,8 +1,16 @@ +from typing import List, Set, Tuple + import pytest +from starkware.cairo.common.structs import CairoStructFactory +from starkware.cairo.lang.compiler.cairo_compile import compile_cairo from starkware.cairo.lang.vm.memory_dict import MemoryDict from starkware.cairo.lang.vm.memory_segments import MemorySegmentManager -from starkware.cairo.lang.vm.relocatable import RelocatableValue +from starkware.cairo.lang.vm.relocatable import ( + MaybeRelocatable, + MaybeRelocatableDict, + RelocatableValue, +) PRIME = 2 ** 251 + 17 * 2 ** 192 + 1 @@ -14,7 +22,7 @@ def test_relocate_segments(): assert segments.add() == RelocatableValue(segment_index=i, offset=0) segment_sizes = [3, 8, 0, 1, 2] - public_memory_offsets = [ + public_memory_offsets: List[List[Tuple[int, int]]] = [ [(0, 0), (1, 1)], [(i, 0) for i in range(8)], [], @@ -62,16 +70,15 @@ def test_relocate_segments(): def test_get_segment_used_size(): - memory = MemoryDict( - { - RelocatableValue(0, 0): 0, - RelocatableValue(0, 2): 0, - RelocatableValue(1, 5): 0, - RelocatableValue(1, 7): 0, - RelocatableValue(3, 0): 0, - RelocatableValue(4, 1): 0, - } - ) + memory_data: MaybeRelocatableDict = { + RelocatableValue(0, 0): 0, + RelocatableValue(0, 2): 0, + RelocatableValue(1, 5): 0, + RelocatableValue(1, 7): 0, + RelocatableValue(3, 0): 0, + RelocatableValue(4, 1): 0, + } + memory = MemoryDict(memory_data) segments = MemorySegmentManager(memory=memory, prime=PRIME) segments.n_segments = 5 memory.freeze() @@ -98,7 +105,7 @@ def test_get_memory_holes(): seg0 = segments.add(size=10) seg1 = segments.add() - accessed_addresses = {seg0, seg1, seg0 + 1, seg1 + 5} + accessed_addresses: Set[MaybeRelocatable] = {seg0, seg1, seg0 + 1, seg1 + 5} # Since segment 1 has no specified size, we must set a memory entry directly. segments.memory[seg1 + 5] = 0 @@ -108,3 +115,36 @@ def test_get_memory_holes(): seg0_holes = 10 - 2 seg1_holes = 6 - 2 assert segments.get_memory_holes(accessed_addresses) == seg0_holes + seg1_holes + + +def test_gen_typed_args(): + """ + Tests gen_typed_args. + """ + + code = """ +struct Inner: + member a : felt + member b : felt +end + +struct MyStruct: + member nested : Inner + member ptr : Inner* +end +""" + + program = compile_cairo(code=code, prime=PRIME) + + structs = CairoStructFactory.from_program(program=program).structs + my_struct = structs.MyStruct + inner = structs.Inner + + typed_args = my_struct(nested=inner(a=1, b=7), ptr=inner(a=3, b=4)) + + segments = MemorySegmentManager(memory=MemoryDict({}), prime=PRIME) + cairo_arg = segments.gen_typed_args(args=typed_args) + + assert len(cairo_arg) == 3 + assert cairo_arg[:2] == [1, 7] + assert segments.memory.get_range(addr=cairo_arg[2], size=2) == [3, 4] diff --git a/src/starkware/cairo/lang/vm/output_builtin_runner.py b/src/starkware/cairo/lang/vm/output_builtin_runner.py index 5d90b34d..5e191d1c 100644 --- a/src/starkware/cairo/lang/vm/output_builtin_runner.py +++ b/src/starkware/cairo/lang/vm/output_builtin_runner.py @@ -4,7 +4,7 @@ from starkware.cairo.lang.vm.builtin_runner import BuiltinRunner, BuiltinVerifier from starkware.cairo.lang.vm.relocatable import MaybeRelocatable, RelocatableValue -from starkware.cairo.lang.vm.utils import MemorySegmentAddresses +from starkware.cairo.lang.vm.utils import MemorySegmentRelocatableAddresses @dataclasses.dataclass @@ -21,13 +21,18 @@ def __init__(self, included: bool): # A map from attribute name to its value. Serialized as part of the additional data of the # builtin. self.attributes: Dict[str, dict] = {} + self._base: Optional[RelocatableValue] = None def initialize_segments(self, runner): - self.base = runner.segments.add() + self._base = runner.segments.add() self.stop_ptr: Optional[RelocatableValue] = None + @property + def base(self) -> RelocatableValue: + assert self._base is not None, "Uninitialized self.base." + return self._base + def initial_stack(self) -> List[MaybeRelocatable]: - assert self.base is not None, "Uninitialized self.base." return [self.base] if self.included else [] def final_stack(self, runner, pointer): @@ -63,7 +68,7 @@ def finalize_segments(self, runner): _, size = self.get_used_cells_and_allocated_size(runner) # A map from an offset to its page id. - offset_to_page = {} + offset_to_page: Dict[int, int] = {} for page_id, page in self.pages.items(): assert page.start + page.size <= size, f"Page {page_id} is out of bounds." for i in range(page.start, page.start + page.size): @@ -79,7 +84,7 @@ def finalize_segments(self, runner): def get_memory_segment_addresses(self, runner): return { - "output": MemorySegmentAddresses( + "output": MemorySegmentRelocatableAddresses( begin_addr=self.base, stop_ptr=self.stop_ptr, ) @@ -98,6 +103,7 @@ def add_page(self, page_id: int, page_start: MaybeRelocatable, page_size: int): and page_start.segment_index == self.base.segment_index ), "page_start must be in the output segment." start = page_start - self.base + assert isinstance(start, int) self.pages[page_id] = PublicMemoryPage(start=start, size=page_size) def add_attribute(self, attribute_name: str, attribute_value: dict): @@ -154,16 +160,16 @@ def set_state(self, state): output_builtin_runner.set_state(old_state) """ - self.base = state.base + self._base = state.base self.pages = state.pages self.attributes = state.attributes - def new_state(self, base): + def new_state(self, base: RelocatableValue): """ Clears the state of the output builtin and sets self.base to the given value. See set_state(). """ - self.base = base + self._base = base self.pages = {} self.attributes = {} diff --git a/src/starkware/cairo/lang/vm/relocatable.py b/src/starkware/cairo/lang/vm/relocatable.py index 01661e98..899037d5 100644 --- a/src/starkware/cairo/lang/vm/relocatable.py +++ b/src/starkware/cairo/lang/vm/relocatable.py @@ -1,8 +1,5 @@ import dataclasses -from typing import Dict, Tuple, TypeVar, Union - -MaybeRelocatable = Union[int, "RelocatableValue"] -T = TypeVar("T", int, MaybeRelocatable) +from typing import Dict, Mapping, SupportsInt, Tuple, Union RELOCATABLE_OFFSET_LOWER_BOUND = -(2 ** 63) RELOCATABLE_OFFSET_UPPER_BOUND = 2 ** 63 @@ -21,7 +18,7 @@ class RelocatableValue: SEGMENT_BITS = 16 OFFSET_BITS = 47 - def __add__(self, other: MaybeRelocatable) -> "RelocatableValue": + def __add__(self, other: "MaybeRelocatable") -> "RelocatableValue": if isinstance(other, int): return RelocatableValue(self.segment_index, self.offset + other) assert not isinstance( @@ -29,10 +26,10 @@ def __add__(self, other: MaybeRelocatable) -> "RelocatableValue": ), f"Cannot add two relocatable values: {self} + {other}." return NotImplemented - def __radd__(self, other: MaybeRelocatable) -> "RelocatableValue": + def __radd__(self, other: "MaybeRelocatable") -> "RelocatableValue": return self + other - def __sub__(self, other: MaybeRelocatable) -> MaybeRelocatable: + def __sub__(self, other: "MaybeRelocatable") -> "MaybeRelocatable": if isinstance(other, int): return RelocatableValue(self.segment_index, self.offset - other) assert self.segment_index == other.segment_index, ( @@ -44,7 +41,7 @@ def __sub__(self, other: MaybeRelocatable) -> MaybeRelocatable: def __mod__(self, other: int): return RelocatableValue(self.segment_index, self.offset % other) - def __lt__(self, other: MaybeRelocatable): + def __lt__(self, other: "MaybeRelocatable"): if isinstance(other, int): # Integers are considered smaller than all relocatable values. return False @@ -52,13 +49,13 @@ def __lt__(self, other: MaybeRelocatable): return NotImplemented return (self.segment_index, self.offset) < (other.segment_index, other.offset) - def __le__(self, other: MaybeRelocatable): + def __le__(self, other: "MaybeRelocatable"): return self < other or self == other - def __ge__(self, other: MaybeRelocatable): + def __ge__(self, other: "MaybeRelocatable"): return not (self < other) - def __gt__(self, other: MaybeRelocatable): + def __gt__(self, other: "MaybeRelocatable"): return not (self <= other) def __hash__(self): @@ -70,7 +67,8 @@ def __format__(self, format_spec): def __str__(self): return f"{self.segment_index}:{self.offset}" - def to_bytes(self, n_bytes: int, byte_order: str) -> bytes: + @staticmethod + def to_bytes(value: "MaybeRelocatable", n_bytes: int, byte_order: str) -> bytes: """ Serializes RelocatableValue as: 1bit | SEGMENT_BITS | OFFSET_BITS @@ -79,15 +77,15 @@ def to_bytes(self, n_bytes: int, byte_order: str) -> bytes: 1bit | num 0 | num """ - if isinstance(self, int): - assert self < 2 ** (8 * n_bytes - 1) - return self.to_bytes(n_bytes, byte_order) - assert n_bytes * 8 > self.SEGMENT_BITS + self.OFFSET_BITS - num = 2 ** (8 * n_bytes - 1) + self.segment_index * 2 ** self.OFFSET_BITS + self.offset + if isinstance(value, int): + assert value < 2 ** (8 * n_bytes - 1) + return value.to_bytes(n_bytes, byte_order) + assert n_bytes * 8 > value.SEGMENT_BITS + value.OFFSET_BITS + num = 2 ** (8 * n_bytes - 1) + value.segment_index * 2 ** value.OFFSET_BITS + value.offset return num.to_bytes(n_bytes, byte_order) @classmethod - def from_bytes(cls, data: bytes, byte_order: str) -> MaybeRelocatable: + def from_bytes(cls, data: bytes, byte_order: str) -> "MaybeRelocatable": n_bytes = len(data) num = int.from_bytes(data, byte_order) if num & (2 ** (8 * n_bytes - 1)): @@ -97,7 +95,7 @@ def from_bytes(cls, data: bytes, byte_order: str) -> MaybeRelocatable: return num @staticmethod - def to_tuple(value: MaybeRelocatable) -> Tuple[int, ...]: + def to_tuple(value: "MaybeRelocatable") -> Tuple[int, ...]: """ Converts a MaybeRelocatable to a tuple (which can be used to serialize the value in JSON). """ @@ -109,7 +107,7 @@ def to_tuple(value: MaybeRelocatable) -> Tuple[int, ...]: raise NotImplementedError(f"Expected MaybeRelocatable, got: {type(value).__name__}.") @staticmethod - def to_felt_or_relocatable(value: T): + def to_felt_or_relocatable(value: Union["RelocatableValue", SupportsInt]) -> "MaybeRelocatable": """ Converts to int unless value is RelocatableValue, otherwise return value as is. """ @@ -118,7 +116,7 @@ def to_felt_or_relocatable(value: T): return int(value) @classmethod - def from_tuple(cls, value: Tuple[int, ...]) -> MaybeRelocatable: + def from_tuple(cls, value: Tuple[int, ...]) -> "MaybeRelocatable": """ Converts a tuple to a MaybeRelocatable. See to_tuple(). """ @@ -130,12 +128,16 @@ def from_tuple(cls, value: Tuple[int, ...]) -> MaybeRelocatable: raise NotImplementedError(f"Expected a tuple of size 1 or 2, got: {value}.") +MaybeRelocatable = Union[int, RelocatableValue] +MaybeRelocatableDict = Dict[MaybeRelocatable, MaybeRelocatable] + + def relocate_value( value: MaybeRelocatable, - segment_offsets: Dict[int, T], + segment_offsets: Mapping[int, MaybeRelocatable], prime: int, allow_missing_segments: bool = False, -) -> T: +) -> MaybeRelocatable: if isinstance(value, int): return value elif isinstance(value, RelocatableValue): diff --git a/src/starkware/cairo/lang/vm/relocatable_fields_test.py b/src/starkware/cairo/lang/vm/relocatable_fields_test.py index 4cec427a..da8f13b5 100644 --- a/src/starkware/cairo/lang/vm/relocatable_fields_test.py +++ b/src/starkware/cairo/lang/vm/relocatable_fields_test.py @@ -1,21 +1,23 @@ from dataclasses import field -from typing import Dict import marshmallow_dataclass -from starkware.cairo.lang.vm.relocatable import MaybeRelocatable, RelocatableValue +from starkware.cairo.lang.vm.relocatable import ( + MaybeRelocatable, + MaybeRelocatableDict, + RelocatableValue, +) from starkware.cairo.lang.vm.relocatable_fields import ( MaybeRelocatableDictField, MaybeRelocatableField, ) +from starkware.starkware_utils.validated_dataclass import ValidatedMarshmallowDataclass @marshmallow_dataclass.dataclass(frozen=True) -class DummyStruct: +class DummyStruct(ValidatedMarshmallowDataclass): val: MaybeRelocatable = field(metadata=dict(marshmallow_field=MaybeRelocatableField())) - dct: Dict[MaybeRelocatable, MaybeRelocatable] = field( - metadata=dict(marshmallow_field=MaybeRelocatableDictField()) - ) + dct: MaybeRelocatableDict = field(metadata=dict(marshmallow_field=MaybeRelocatableDictField())) def test_relocatable_fields_serialize_deserialize(): diff --git a/src/starkware/cairo/lang/vm/relocatable_test.py b/src/starkware/cairo/lang/vm/relocatable_test.py index e3f7d1dd..b6b0b64b 100644 --- a/src/starkware/cairo/lang/vm/relocatable_test.py +++ b/src/starkware/cairo/lang/vm/relocatable_test.py @@ -2,7 +2,7 @@ import pytest -from starkware.cairo.lang.vm.relocatable import RelocatableValue +from starkware.cairo.lang.vm.relocatable import MaybeRelocatable, RelocatableValue def test_relocatable_operations(): @@ -14,7 +14,7 @@ def test_relocatable_operations(): assert RelocatableValue(1, 101) % 10 == RelocatableValue(1, 1) with pytest.raises(TypeError): - x * y + x * y # type: ignore with pytest.raises(AssertionError): x + x with pytest.raises(AssertionError): @@ -40,14 +40,15 @@ def test_relocatable_inequalities(): @pytest.mark.parametrize("byte_order", ["little", "big"]) @pytest.mark.parametrize("n_bytes", [16, 32]) -def test_relocatable_value_serialization(byte_order, n_bytes): - for num in [19, RelocatableValue(2, 5)]: - assert ( - RelocatableValue.from_bytes( - RelocatableValue.to_bytes(num, n_bytes, byte_order), byte_order - ) - == num +@pytest.mark.parametrize("val", [19, RelocatableValue(2, 5)]) +def test_relocatable_value_serialization(val: MaybeRelocatable, byte_order, n_bytes): + assert ( + RelocatableValue.from_bytes( + data=RelocatableValue.to_bytes(value=val, n_bytes=n_bytes, byte_order=byte_order), + byte_order=byte_order, ) + == val + ) def test_to_tuple_from_tuple(): @@ -74,4 +75,4 @@ def test_relocatable_value_frozen(): with pytest.raises( dataclasses.FrozenInstanceError, match="cannot assign to field 'no_such_field'" ): - x.no_such_field = 5 + x.no_such_field = 5 # type: ignore diff --git a/src/starkware/cairo/lang/vm/security.py b/src/starkware/cairo/lang/vm/security.py index 21bed281..e3d174f1 100644 --- a/src/starkware/cairo/lang/vm/security.py +++ b/src/starkware/cairo/lang/vm/security.py @@ -18,7 +18,7 @@ def verify_secure_runner(runner: CairoRunner, verify_builtins=True): builtin_segments = runner.get_builtin_segments_info() if verify_builtins else {} builtin_segment_names = {seg.index: name for name, seg in builtin_segments.items()} builtin_segment_sizes = {seg.index: seg.size for seg in builtin_segments.values()} - for addr in runner.vm_memory: + for addr, value in runner.vm_memory.items(): # Check pure addresses. if not isinstance(addr, RelocatableValue): raise SecurityError(f"Accessed address {addr} is not relocatable.") @@ -38,6 +38,10 @@ def verify_secure_runner(runner: CairoRunner, verify_builtins=True): if not addr.offset < len(runner.program.data): raise SecurityError(f"Out of bounds access to program segment at {addr}.") + # Check memory value, to be consistent with the CairoPie validation done by SHARP. + if not runner.segments.is_valid_memory_value(value=value): + raise SecurityError(f"Invalid memory value at address {addr}: {value}.") + # Builtin specific checks. try: for builtin_runner in runner.builtin_runners.values(): diff --git a/src/starkware/cairo/lang/vm/utils.py b/src/starkware/cairo/lang/vm/utils.py index d7efc4ec..fa402589 100644 --- a/src/starkware/cairo/lang/vm/utils.py +++ b/src/starkware/cairo/lang/vm/utils.py @@ -1,9 +1,11 @@ import dataclasses import re -from typing import Optional +from typing import Dict, Optional import marshmallow.fields as mfields +from starkware.cairo.lang.vm.relocatable import MaybeRelocatable, RelocatableValue + class IntAsHex(mfields.Field): """ @@ -37,6 +39,16 @@ class MemorySegmentAddresses: stop_ptr: Optional[int] +@dataclasses.dataclass +class MemorySegmentRelocatableAddresses: + """ + Same as MemorySegmentAddresses, except that the addresses are RelocatableValue. + """ + + begin_addr: RelocatableValue + stop_ptr: Optional[RelocatableValue] + + class ResourcesError(Exception): """ Base class for exceptions thrown due to lack of Cairo run resources. @@ -64,3 +76,30 @@ def consume_step(self): """ if self.n_steps is not None: self.n_steps -= 1 + + +def sort_segments( + memory_segments: Dict[str, MemorySegmentAddresses] +) -> Dict[str, MemorySegmentAddresses]: + """ + Sorts the segments dictionary according to the correct serialization order in the + public input. + Gets and returns a dictionary from segment name to a MemorySegmentAddresses. + """ + segment_names = ["program", "execution", "output", "pedersen", "range_check", "ecdsa"] + if "bitwise" in memory_segments: + segment_names.append("bitwise") + res = {name: memory_segments[name] for name in segment_names} + assert len(res) == len(memory_segments), f"Wrong segments given: {memory_segments}." + return res + + +def decimal_repr(val: MaybeRelocatable, prime: int) -> str: + """ + Returns a (possibly negative) decimal representation of the given value. + """ + if isinstance(val, int): + # Shift val to the range (-prime // 2, prime // 2). + return str((val + prime // 2) % prime - (prime // 2)) + else: + return str(val) diff --git a/src/starkware/cairo/lang/vm/validated_memory_dict_test.py b/src/starkware/cairo/lang/vm/validated_memory_dict_test.py index d36cdf77..de761d2e 100644 --- a/src/starkware/cairo/lang/vm/validated_memory_dict_test.py +++ b/src/starkware/cairo/lang/vm/validated_memory_dict_test.py @@ -1,8 +1,10 @@ +from typing import cast + import pytest from starkware.cairo.lang.vm.memory_dict import MemoryDict from starkware.cairo.lang.vm.relocatable import RelocatableValue -from starkware.cairo.lang.vm.validated_memory_dict import ValidatedMemoryDict +from starkware.cairo.lang.vm.validated_memory_dict import ValidatedMemoryDict, ValidationRule def test_validated_memory_dict(): @@ -29,7 +31,7 @@ def rule_constant_value(mem, addr, constant): memory_validator.add_validation_rule(1, lambda memory, addr: set()) memory_validator.add_validation_rule(2, lambda memory, addr: {addr}) memory_validator.add_validation_rule(3, rule_identical_pairs) - memory_validator.add_validation_rule(4, rule_constant_value, 0) + memory_validator.add_validation_rule(4, cast(ValidationRule, rule_constant_value), 0) addr0 = RelocatableValue.from_tuple((1, 0)) addr1 = RelocatableValue.from_tuple((2, 0)) diff --git a/src/starkware/cairo/lang/vm/virtual_machine_base.py b/src/starkware/cairo/lang/vm/virtual_machine_base.py index 1ee6e24a..779e7cdd 100644 --- a/src/starkware/cairo/lang/vm/virtual_machine_base.py +++ b/src/starkware/cairo/lang/vm/virtual_machine_base.py @@ -1,18 +1,26 @@ import dataclasses +import re import sys from abc import ABC from typing import Any, Callable, Dict, List, Optional, Tuple +from typing_extensions import Protocol + from starkware.cairo.lang.compiler.debug_info import DebugInfo, InstructionLocation from starkware.cairo.lang.compiler.encode import is_call_instruction from starkware.cairo.lang.compiler.expression_evaluator import ExpressionEvaluator from starkware.cairo.lang.compiler.instruction import decode_instruction_values +from starkware.cairo.lang.compiler.preprocessor.flow import FlowTrackingDataActual from starkware.cairo.lang.compiler.preprocessor.preprocessor import AttributeBase, AttributeScope +from starkware.cairo.lang.compiler.preprocessor.reg_tracking import RegTrackingData from starkware.cairo.lang.compiler.program import Program, ProgramBase +from starkware.cairo.lang.compiler.references import ApDeductionError +from starkware.cairo.lang.compiler.scoped_name import ScopedName from starkware.cairo.lang.vm.builtin_runner import BuiltinRunner from starkware.cairo.lang.vm.memory_dict import MemoryDict from starkware.cairo.lang.vm.relocatable import MaybeRelocatable, RelocatableValue from starkware.cairo.lang.vm.trace_entry import TraceEntry +from starkware.cairo.lang.vm.utils import decimal_repr from starkware.cairo.lang.vm.validated_memory_dict import ValidatedMemoryDict, ValidationRule from starkware.cairo.lang.vm.vm_consts import VmConsts, VmConstsContext from starkware.cairo.lang.vm.vm_exceptions import ( @@ -21,17 +29,29 @@ VmException, VmExceptionBase, ) +from starkware.starknet.security.simple_references import ( + InvalidReferenceExpressionError, + is_simple_reference, +) + + +class Rule(Protocol): + def __call__( + self, vm: "VirtualMachineBase", addr: RelocatableValue, *args: Any + ) -> Optional[MaybeRelocatable]: + pass -Rule = Callable[["VirtualMachineBase", RelocatableValue], Optional[int]] MAX_TRACEBACK_ENTRIES = 20 ERROR_MESSAGE_ATTRIBUTE = "error_message" -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class VmAttributeScope(AttributeBase): start_pc: MaybeRelocatable end_pc: MaybeRelocatable + flow_tracking_data: Optional[FlowTrackingDataActual] + accessible_scopes: List[ScopedName] @classmethod def from_attribute_scope(cls, attr: AttributeScope, program_base: MaybeRelocatable): @@ -40,6 +60,8 @@ def from_attribute_scope(cls, attr: AttributeScope, program_base: MaybeRelocatab value=attr.value, start_pc=program_base + attr.start_pc, end_pc=program_base + attr.end_pc, + flow_tracking_data=attr.flow_tracking_data, + accessible_scopes=attr.accessible_scopes, ) @@ -60,9 +82,9 @@ class RunContextBase(ABC): fp: MaybeRelocatable prime: int - def get_traceback_entries(self): + def get_traceback_entries(self) -> List[Tuple[MaybeRelocatable, MaybeRelocatable]]: """ - Returns the values of pc of the call instructions in the traceback. + Returns the values (fp, pc) corresponding to each call instruction in the traceback. Returns the most recent call last. """ entries = [] @@ -72,25 +94,27 @@ def get_traceback_entries(self): break # Get the previous fp and the return pc. - fp, ret_pc = self.memory.get(fp - 2), self.memory.get(fp - 1) + opt_fp, opt_ret_pc = self.memory.get(fp - 2), self.memory.get(fp - 1) # If one of them is not in memory, abort. - if fp is None or ret_pc is None: + if opt_fp is None or opt_ret_pc is None: break + fp, ret_pc = opt_fp, opt_ret_pc + # Get the two memory cells before ret_pc. instruction0, instruction1 = self.memory.get(ret_pc - 2), self.memory.get(ret_pc - 1) # Try to check if the call instruction is (instruction0, instruction1) or just # instruction1 (with no immediate). # In rare cases this may be ambiguous. - if instruction1 is not None and is_call_instruction( + if isinstance(instruction1, int) and is_call_instruction( encoded_instruction=instruction1, imm=None ): call_pc = ret_pc - 1 elif ( - instruction0 is not None - and instruction1 is not None + isinstance(instruction0, int) + and isinstance(instruction1, int) and is_call_instruction(encoded_instruction=instruction0, imm=instruction1) ): call_pc = ret_pc - 2 @@ -98,7 +122,7 @@ def get_traceback_entries(self): # If none of them seems like the calling instruction, abort. break - entries.append(call_pc) + entries.append((fp, call_pc)) return entries[::-1] @@ -245,39 +269,6 @@ def exit_scope(self): assert len(self.exec_scopes) > 1, "Cannot exit main scope." self.exec_scopes.pop() - def step(self): - self.skip_instruction_execution = False - # Execute hints. - for hint_index, hint in enumerate(self.hints.get(self.run_context.pc, [])): - exec_locals = self.exec_scopes[-1] - exec_locals["memory"] = memory = self.validated_memory - exec_locals["ap"] = ap = self.run_context.ap - exec_locals["fp"] = fp = self.run_context.fp - exec_locals["pc"] = pc = self.run_context.pc - exec_locals["current_step"] = self.current_step - exec_locals["ids"] = hint.consts(pc, ap, fp, memory) - - exec_locals["vm_load_program"] = self.load_program - exec_locals["vm_enter_scope"] = self.enter_scope - exec_locals["vm_exit_scope"] = self.exit_scope - exec_locals.update(self.static_locals) - - self.exec_hint(hint.compiled, exec_locals, hint_index=hint_index) - - # Clear ids (which will be rewritten by the next hint anyway) to make the VM instance - # smaller and faster to copy. - del exec_locals["ids"] - del exec_locals["memory"] - - if self.skip_instruction_execution: - return - - # Decode. - instruction = self.decode_current_instruction() - - # Run. - self.run_instruction(instruction) - def compile_hint(self, source, filename, hint_index: int): """ Compiles the given python source code. @@ -304,32 +295,25 @@ def exec_hint(self, code, globals_, hint_index): hint_exception, notes=[hint_exception.exception_str], hint_index=hint_index ) from None - @property - def last_pc(self): - """ - Returns the value of the program counter for the last instruction that was execute. - Note that this is different from self.run_context.pc which contains the value of the - next instruction to be executed. - """ - return self.trace[-1].pc - def as_vm_exception( - self, exc, pc=None, notes: Optional[List[str]] = None, hint_index: Optional[int] = None + self, + exc, + with_traceback: bool = True, + notes: Optional[List[str]] = None, + hint_index: Optional[int] = None, ): """ - Wraps the exception with a VmException, adding to it location information. If pc is not - given the current pc is used. + Wraps the exception with a VmException, adding to it location information. + The current pc is used. """ - traceback = None - if pc is None: - pc = self.run_context.pc - traceback = self.get_traceback() + pc = self.run_context.pc + traceback = self.get_traceback() if with_traceback else None return VmException( pc=pc, inst_location=self.get_location(pc=pc), inner_exc=exc, - error_attr_value=self.get_error_attr_value(pc), + error_attr_value=self.get_error_attr_value(pc=pc, fp=self.run_context.fp), traceback=traceback, notes=notes, hint_index=hint_index, @@ -338,24 +322,94 @@ def as_vm_exception( def get_location(self, pc) -> Optional[InstructionLocation]: return self.instruction_debug_info.get(pc) - def get_error_attr_value(self, pc) -> str: + def evaluate_reference( + self, + name: str, + accessible_scopes: List[ScopedName], + flow_tracking_data: FlowTrackingDataActual, + fp: MaybeRelocatable, + ) -> MaybeRelocatable: + """ + Returns the value of the given reference with respect to the given fp. + If the reference is ap-based, ApDeductionError is thrown. + """ + assert isinstance(self.program, Program) + identifier = self.program.identifiers.search( + accessible_scopes=accessible_scopes, name=ScopedName.from_string(name) + ) + reference = flow_tracking_data.resolve_reference( + reference_manager=self.program.reference_manager, + name=identifier.get_canonical_name(), + ) + + # A security check that the reference is not too complicated, doesn't rely on other + # references, doesn't contain nondet-hints, etc. + if not is_simple_reference(reference.value, simplicity_bound=20): + raise InvalidReferenceExpressionError() + + # Evaluate the reference using an invalid ap_tracking, which will throw an ApDeductionError + # exception if the reference is ap-based. + expr = reference.eval(RegTrackingData(-1, 0)) + return ExpressionEvaluator[MaybeRelocatable]( + prime=self.prime, + ap=None, + fp=fp, + memory=self.validated_memory, # type: ignore + ).eval(expr) + + def substitute_error_message_references(self, error_message_attr: VmAttributeScope, fp) -> str: + """ + Substitutes references in the given error_message attribute with their actual value. + References are defined with '{}'. E.g., 'x must be positive. Got: {x}'. + """ + error_message = error_message_attr.value + if error_message_attr.flow_tracking_data is None: + return error_message + flow_tracking_data = error_message_attr.flow_tracking_data + + invalid_references = [] + + def substitute_ref(match): + reference = match.group("name") + try: + val = self.evaluate_reference( + name=reference, + accessible_scopes=error_message_attr.accessible_scopes, + flow_tracking_data=flow_tracking_data, + fp=fp, + ) + return decimal_repr(val, self.prime) + except (ApDeductionError, InvalidReferenceExpressionError): + invalid_references.append(reference) + return match.group(0) + + error_message = re.sub(r"{(?P[a-zA-Z0-9.]+)}", substitute_ref, error_message) + if len(invalid_references) > 0: + error_message += ( + f" (Cannot evaluate ap-based or complex references: {invalid_references})" + ) + + return error_message + + def get_error_attr_value(self, pc, fp) -> str: """ Returns the error messages that correspond to the error_message attribute scopes surrounding the given pc. """ - error_value = "" + errors = "" for error_message_attr in self.error_message_attributes: if error_message_attr.start_pc <= pc < error_message_attr.end_pc: - error_value += f"Error message: {error_message_attr.value}\n" - return error_value + error_message = self.substitute_error_message_references(error_message_attr, fp) + errors += f"Error message: {error_message}\n" + return errors def get_traceback(self) -> Optional[str]: """ Returns the traceback at the current pc. """ traceback = "" - for traceback_pc in self.run_context.get_traceback_entries(): - traceback += self.get_error_attr_value(traceback_pc) + for fp, traceback_pc in self.run_context.get_traceback_entries(): + traceback += self.get_error_attr_value(pc=traceback_pc, fp=fp) location = self.get_location(pc=traceback_pc) if location is None: traceback += f"Unknown location (pc={traceback_pc})\n" @@ -417,6 +471,13 @@ def end_run(self): if len(self.exec_scopes) != 1: raise VmExceptionBase("Every enter_scope() requires a corresponding exit_scope().") + def check_eq(self, val0: MaybeRelocatable, val1: MaybeRelocatable) -> bool: + """ + Called when an instruction encounters an assertion that two values should be equal. + This function can be overridden by subclasses. + """ + return val0 == val1 + def get_perm_range_check_limits( trace: List[TraceEntry[int]], memory: MemoryDict diff --git a/src/starkware/cairo/lang/vm/vm_consts.py b/src/starkware/cairo/lang/vm/vm_consts.py index e7d7d26b..25be8f6c 100644 --- a/src/starkware/cairo/lang/vm/vm_consts.py +++ b/src/starkware/cairo/lang/vm/vm_consts.py @@ -1,6 +1,6 @@ import dataclasses from abc import ABC, abstractmethod -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, List, MutableMapping, Optional, Union from starkware.cairo.lang.compiler.ast.cairo_types import ( CairoType, @@ -30,7 +30,6 @@ from starkware.cairo.lang.compiler.references import FlowTrackingError, Reference from starkware.cairo.lang.compiler.scoped_name import ScopedName from starkware.cairo.lang.compiler.type_system_visitor import simplify_type_system -from starkware.cairo.lang.vm.memory_dict import MemoryDict from starkware.cairo.lang.vm.relocatable import MaybeRelocatable @@ -40,7 +39,7 @@ class VmConstsContext: evaluator: Callable[[Expression], Any] reference_manager: ReferenceManager flow_tracking_data: FlowTrackingData - memory: MemoryDict + memory: MutableMapping[MaybeRelocatable, MaybeRelocatable] pc: int diff --git a/src/starkware/cairo/lang/vm/vm_consts_test.py b/src/starkware/cairo/lang/vm/vm_consts_test.py index b0d17c93..a40543f5 100644 --- a/src/starkware/cairo/lang/vm/vm_consts_test.py +++ b/src/starkware/cairo/lang/vm/vm_consts_test.py @@ -1,5 +1,5 @@ import re -from typing import ClassVar +from typing import ClassVar, MutableMapping, Optional import pytest @@ -26,6 +26,7 @@ from starkware.cairo.lang.compiler.references import FlowTrackingError, Reference from starkware.cairo.lang.compiler.scoped_name import ScopedName from starkware.cairo.lang.compiler.type_system import mark_types_in_expr_resolved +from starkware.cairo.lang.vm.relocatable import MaybeRelocatable, MaybeRelocatableDict from starkware.cairo.lang.vm.vm_consts import VmConsts, VmConstsContext scope = ScopedName.from_string @@ -185,7 +186,7 @@ def test_references(): prime = 2 ** 64 + 13 ap = 100 fp = 200 - memory = { + memory: MaybeRelocatableDict = { (ap - 2) + 1: 1234, (ap - 1) + 1: 1000, (ap - 1) + 1 + 2: 13, @@ -236,20 +237,21 @@ def test_references(): assert memory == {(ap - 2) + 1: 1234} memory.clear() - consts.x.typeref.member = 1001 + # Use "type: ignore" since mypy is unable to deduce the type of members of VmConsts. + consts.x.typeref.member = 1001 # type: ignore assert memory == {(ap - 1) + 1 + 10: 1001} memory.clear() consts.x.typeref2 = 4321 assert memory == {(ap - 1) + 1: 4321} - consts.x.typeref2.member = 1 + consts.x.typeref2.member = 1 # type: ignore assert memory == { (ap - 1) + 1: 4321, 4321 + 10: 1, } - consts.x.typeref2.struct.member = 2 + consts.x.typeref2.struct.member = 2 # type: ignore assert memory == { (ap - 1) + 1: 4321, 4321 + 10: 1, @@ -270,14 +272,22 @@ def test_references(): consts.x.typeref = 1000 -def get_vm_consts(identifier_values, reference_manager, flow_tracking_data, memory={}): +def get_vm_consts( + identifier_values, + reference_manager, + flow_tracking_data, + memory: Optional[MutableMapping[MaybeRelocatable, MaybeRelocatable]] = None, +): """ Creates a simple VmConsts object. """ + memory = {} if memory is None else memory identifiers = IdentifierManager.from_dict(identifier_values) context = VmConstsContext( identifiers=identifiers, - evaluator=ExpressionEvaluator(2 ** 64 + 13, 0, 0, memory, identifiers).eval, + evaluator=ExpressionEvaluator[MaybeRelocatable]( + 2 ** 64 + 13, 0, 0, memory, identifiers + ).eval, reference_manager=reference_manager, flow_tracking_data=flow_tracking_data, memory=memory, @@ -301,7 +311,7 @@ def test_reference_rebinding(): with pytest.raises(FlowTrackingError, match="Reference 'ref' is revoked."): consts.ref - flow_tracking_data = flow_tracking_data.add_reference( + flow_tracking_data2 = flow_tracking_data.add_reference( reference_manager=reference_manager, name=scope("ref"), ref=Reference( @@ -310,7 +320,7 @@ def test_reference_rebinding(): ap_tracking_data=RegTrackingData(group=0, offset=2), ), ) - consts = get_vm_consts(identifier_values, reference_manager, flow_tracking_data) + consts = get_vm_consts(identifier_values, reference_manager, flow_tracking_data2) assert consts.ref == 10 @@ -329,7 +339,7 @@ def test_reference_to_structs(): } reference_manager = ReferenceManager() flow_tracking_data = FlowTrackingDataActual(ap_tracking=RegTrackingData()) - flow_tracking_data = flow_tracking_data.add_reference( + flow_tracking_data2 = flow_tracking_data.add_reference( reference_manager=reference_manager, name=scope("ref"), ref=Reference( @@ -338,15 +348,16 @@ def test_reference_to_structs(): ap_tracking_data=RegTrackingData(group=0, offset=2), ), ) - memory = {103: 200} - consts = get_vm_consts(identifier_values, reference_manager, flow_tracking_data, memory=memory) + memory: MaybeRelocatableDict = {103: 200} + consts = get_vm_consts(identifier_values, reference_manager, flow_tracking_data2, memory=memory) assert consts.ref.address_ == 100 assert consts.ref.x.address_ == 200 # Set the pointer ref.x.x to 300. consts.ref.x.x = 300 assert memory[203] == 300 - assert consts.ref.x.x.address_ == 300 + # Use "type: ignore" since mypy is unable to deduce the type of members of VmConsts. + assert consts.ref.x.x.address_ == 300 # type: ignore assert consts.ref.type_ == consts.T @@ -474,7 +485,7 @@ def test_revoked_reference(): prime = 2 ** 64 + 13 ap = 100 fp = 200 - memory = {} + memory: MaybeRelocatableDict = {} flow_tracking_data = FlowTrackingDataActual( ap_tracking=RegTrackingData(group=1, offset=4), diff --git a/src/starkware/cairo/lang/vm/vm_core.py b/src/starkware/cairo/lang/vm/vm_core.py index 52938b84..7a1e75ca 100644 --- a/src/starkware/cairo/lang/vm/vm_core.py +++ b/src/starkware/cairo/lang/vm/vm_core.py @@ -440,9 +440,35 @@ def run_instruction(self, instruction): self.current_step += 1 - def check_eq(self, val0, val1): - """ - Called when an instruction encounters an assertion that two values should be equal. - This function can be overridden by subclasses. - """ - return val0 == val1 + def step(self): + self.skip_instruction_execution = False + # Execute hints. + for hint_index, hint in enumerate(self.hints.get(self.run_context.pc, [])): + exec_locals = self.exec_scopes[-1] + exec_locals["memory"] = memory = self.validated_memory + exec_locals["ap"] = ap = self.run_context.ap + exec_locals["fp"] = fp = self.run_context.fp + exec_locals["pc"] = pc = self.run_context.pc + exec_locals["current_step"] = self.current_step + exec_locals["ids"] = hint.consts(pc, ap, fp, memory) + + exec_locals["vm_load_program"] = self.load_program + exec_locals["vm_enter_scope"] = self.enter_scope + exec_locals["vm_exit_scope"] = self.exit_scope + exec_locals.update(self.static_locals) + + self.exec_hint(hint.compiled, exec_locals, hint_index=hint_index) + + # Clear ids (which will be rewritten by the next hint anyway) to make the VM instance + # smaller and faster to copy. + del exec_locals["ids"] + del exec_locals["memory"] + + if self.skip_instruction_execution: + return + + # Decode. + instruction = self.decode_current_instruction() + + # Run. + self.run_instruction(instruction) diff --git a/src/starkware/cairo/lang/vm/vm_exceptions.py b/src/starkware/cairo/lang/vm/vm_exceptions.py index 50f98883..4b55792a 100644 --- a/src/starkware/cairo/lang/vm/vm_exceptions.py +++ b/src/starkware/cairo/lang/vm/vm_exceptions.py @@ -91,7 +91,7 @@ def replace_stack_item(item: traceback.FrameSummary) -> traceback.FrameSummary: return traceback.FrameSummary(filename=filename, lineno=line_num, name=item.name) tb_exception.stack = traceback.StackSummary.from_list( - map(replace_stack_item, tb_exception.stack) + map(replace_stack_item, tb_exception.stack) # type: ignore ) super().__init__(f"Got an exception while executing a hint.") self.exception_str = "".join(tb_exception.format()) diff --git a/src/starkware/cairo/lang/vm/vm_test.py b/src/starkware/cairo/lang/vm/vm_test.py index ded0de7b..5749de84 100644 --- a/src/starkware/cairo/lang/vm/vm_test.py +++ b/src/starkware/cairo/lang/vm/vm_test.py @@ -1,5 +1,5 @@ import tempfile -from typing import Dict +from typing import Optional, cast import pytest @@ -9,7 +9,12 @@ MemoryDict, UnknownMemoryError, ) -from starkware.cairo.lang.vm.relocatable import MaybeRelocatable, RelocatableValue +from starkware.cairo.lang.vm.relocatable import ( + MaybeRelocatable, + MaybeRelocatableDict, + RelocatableValue, +) +from starkware.cairo.lang.vm.virtual_machine_base import Rule, VirtualMachineBase from starkware.cairo.lang.vm.vm import RunContext, VirtualMachine from starkware.cairo.lang.vm.vm_exceptions import InconsistentAutoDeductionError, VmException from starkware.python.test_utils import maybe_raises @@ -21,7 +26,7 @@ def run_single(code: str, steps: int, *, pc=RelocatableValue(0, 10), ap=100, fp= program = compile_cairo(code, PRIME, debug_info=True) # Set memory[fp - 1] to an arbitrary value, since [fp - 1] is assumed to be set. - memory: Dict[MaybeRelocatable, MaybeRelocatable] = { + memory: MaybeRelocatableDict = { **{pc + i: v for i, v in enumerate(program.data)}, fp - 1: 1234, **extra_mem, @@ -41,7 +46,7 @@ def run_single(code: str, steps: int, *, pc=RelocatableValue(0, 10), ap=100, fp= def test_memory_dict(): - d = {1: 2} + d: MaybeRelocatableDict = {1: 2} mem = MemoryDict(d) d[2] = 3 assert 2 not in mem @@ -176,7 +181,7 @@ def test_addap(): vm = run_single(code, 3) mem = [vm.run_context.memory.get(100 + i) for i in range(32)] - assert mem == [3] + [None] * 30 + [4] + assert mem == [3, *[None] * 30, 4] assert vm.run_context.ap == 131 @@ -299,7 +304,7 @@ def f(): cairo_file.flush() program = compile_cairo(code=[(code, cairo_file.name)], prime=PRIME, debug_info=True) program_base = 10 - memory = {program_base + i: v for i, v in enumerate(program.data)} + memory: MaybeRelocatableDict = {program_base + i: v for i, v in enumerate(program.data)} # Set memory[fp - 1] to an arbitrary value, since [fp - 1] is assumed to be set. memory[99] = 1234 @@ -353,7 +358,7 @@ def f(): cairo_file.flush() program = compile_cairo(code=[(code, cairo_file.name)], prime=PRIME, debug_info=True) program_base = 10 - memory = {program_base + i: v for i, v in enumerate(program.data)} + memory: MaybeRelocatableDict = {program_base + i: v for i, v in enumerate(program.data)} # Set memory[fp - 1] to an arbitrary value, since [fp - 1] is assumed to be set. memory[99] = 1234 @@ -401,7 +406,7 @@ def f(): cairo_file.flush() program = compile_cairo(code=[(code, cairo_file.name)], prime=PRIME, debug_info=True) program_base = 10 - memory = {program_base + i: v for i, v in enumerate(program.data)} + memory: MaybeRelocatableDict = {program_base + i: v for i, v in enumerate(program.data)} # Set memory[fp - 1] to an arbitrary value, since [fp - 1] is assumed to be set. memory[99] = 1234 @@ -488,7 +493,7 @@ def test_skip_instruction_execution(): program = compile_cairo(code, PRIME, debug_info=True) initial_ap = 100 - memory: Dict[MaybeRelocatable, MaybeRelocatable] = { + memory: MaybeRelocatableDict = { **{i: v for i, v in enumerate(program.data)}, initial_ap - 1: 1234, } @@ -523,7 +528,7 @@ def test_auto_deduction_rules(): """ program = compile_cairo(code=code, prime=PRIME, debug_info=True) - memory = {i: v for i, v in enumerate(program.data)} + memory: MaybeRelocatableDict = {i: v for i, v in enumerate(program.data)} initial_ap = RelocatableValue(segment_index=1, offset=200) initial_fp = RelocatableValue(segment_index=2, offset=100) @@ -537,13 +542,15 @@ def test_auto_deduction_rules(): vm = VirtualMachine(program, context, {}) - def rule_ap_segment(vm, addr, val): + def rule_ap_segment( + vm: VirtualMachineBase, addr: MaybeRelocatable, val: MaybeRelocatable + ) -> Optional[MaybeRelocatable]: return val - vm.add_auto_deduction_rule(1, rule_ap_segment, 100) - vm.add_auto_deduction_rule(2, lambda vm, addr: None) - vm.add_auto_deduction_rule(2, lambda vm, addr: 200 if addr == initial_fp else None) - vm.add_auto_deduction_rule(2, lambda vm, addr: 456) + vm.add_auto_deduction_rule(1, cast(Rule, rule_ap_segment), 100) + vm.add_auto_deduction_rule(2, cast(Rule, lambda vm, addr: None)) + vm.add_auto_deduction_rule(2, cast(Rule, lambda vm, addr: 200 if addr == initial_fp else None)) + vm.add_auto_deduction_rule(2, cast(Rule, lambda vm, addr: 456)) vm.step() @@ -565,7 +572,7 @@ def test_memory_validation_in_hints(): program = compile_cairo(code=code, prime=PRIME, debug_info=True) initial_ap_and_fp = RelocatableValue(segment_index=1, offset=200) - memory = {i: v for i, v in enumerate(program.data)} + memory: MaybeRelocatableDict = {i: v for i, v in enumerate(program.data)} # Set memory[fp - 1] to an arbitrary value, since [fp - 1] is assumed to be set. memory[initial_ap_and_fp - 1] = 1234 @@ -619,7 +626,7 @@ def test_jmp_segment(): program_base_a = RelocatableValue(0, 10) program_base_b = RelocatableValue(1, 20) - memory = { + memory: MaybeRelocatableDict = { **{program_base_a + i: v for i, v in enumerate(program.data)}, **{program_base_b + i: v for i, v in enumerate(program.data)}, 99: 0, @@ -777,12 +784,12 @@ def test_traceback_with_attr(): call main func foo(x): - with_attr error_message("Error in foo."): + with_attr error_message("Error in foo (x={x})."): with_attr error_message("Should not appear in trace."): assert 0 = 0 end with_attr attr_name("Should not appear in trace (attr_name instead of error_message)."): - %{ assert ids.x != 0 %} + %{ assert ids.x != 1 %} [ap] = 1; ap++ end end @@ -790,8 +797,9 @@ def test_traceback_with_attr(): end func bar(x): - with_attr error_message("Error in bar."): - foo(x * x * x) + tempvar y = x + 2 + with_attr error_message("Error in bar (x={x}, y={y})."): + foo(y * y * y) end return () end @@ -802,7 +810,7 @@ def test_traceback_with_attr(): bar(x=1) end with_attr error_message("Running bar(x=0)."): - bar(x=0) # This line will cause an error. + bar(x=-1) # This line will cause an error. end end return () @@ -815,10 +823,10 @@ def test_traceback_with_attr(): assert ( str(exc_info.value) == """\ -Error message: Error in foo. +Error message: Error in foo (x=1). :10:17: Error at pc=0:16: Got an exception while executing a hint. - %{ assert ids.x != 0 %} + %{ assert ids.x != 1 %} ^*********************^ Cairo traceback (most recent call last): :2:5: (pc=0:10) @@ -826,12 +834,12 @@ def test_traceback_with_attr(): ^*******^ Error message: Running bar(x=0). Error message: Error in main. -:30:17: (pc=0:30) - bar(x=0) # This line will cause an error. - ^******^ -Error message: Error in bar. -:19:13: (pc=0:21) - foo(x * x * x) +:31:17: (pc=0:32) + bar(x=-1) # This line will cause an error. + ^*******^ +Error message: Error in bar (x=-1, y={y}). (Cannot evaluate ap-based or complex references: ['y']) +:20:13: (pc=0:23) + foo(y * y * y) ^************^ Traceback (most recent call last): diff --git a/src/starkware/cairo/sharp/client_lib_test.py b/src/starkware/cairo/sharp/client_lib_test.py index ba140608..6c762c46 100644 --- a/src/starkware/cairo/sharp/client_lib_test.py +++ b/src/starkware/cairo/sharp/client_lib_test.py @@ -3,6 +3,7 @@ import json import pytest +from pytest import MonkeyPatch from urllib3 import PoolManager from starkware.cairo.sharp.client_lib import ClientLib @@ -22,7 +23,7 @@ class Response: data: bytes -def test_add_job(monkeypatch): +def test_add_job(monkeypatch: MonkeyPatch): expected_url = "some url" expected_data = { "action": "add_job", @@ -45,7 +46,7 @@ def check_expected(_, method: str, url: str, body: str): assert res == expected_res -def test_get_status(monkeypatch): +def test_get_status(monkeypatch: MonkeyPatch): expected_url = "some url" expected_id = "some id" expected_data = {"action": "get_status", "request": {"cairo_job_key": expected_id}} @@ -66,7 +67,7 @@ def check_expected(_, method: str, url: str, body: str): assert res == expected_res -def test_error(monkeypatch): +def test_error(monkeypatch: MonkeyPatch): # A mock function enforcing expected scenario. def check_expected(_, method: str, url: str, body: str): # Return an empty response - this should be invalid. diff --git a/src/starkware/cairo/sharp/sharp_client.py b/src/starkware/cairo/sharp/sharp_client.py index aed298c8..1f1192f8 100755 --- a/src/starkware/cairo/sharp/sharp_client.py +++ b/src/starkware/cairo/sharp/sharp_client.py @@ -8,8 +8,8 @@ import tempfile from typing import List, Optional -from starkware.cairo.bootloader.generate_fact import get_cairo_pie_fact_info -from starkware.cairo.bootloader.hash_program import compute_program_hash_chain +from starkware.cairo.bootloaders.generate_fact import get_cairo_pie_fact_info +from starkware.cairo.bootloaders.hash_program import compute_program_hash_chain from starkware.cairo.lang.compiler.assembler import Program from starkware.cairo.lang.vm.crypto import get_crypto_lib_context_manager from starkware.cairo.sharp.client_lib import CairoPie, ClientLib diff --git a/src/starkware/cairo/sharp/sharp_client_test.py b/src/starkware/cairo/sharp/sharp_client_test.py index f2f99997..b2b53e83 100644 --- a/src/starkware/cairo/sharp/sharp_client_test.py +++ b/src/starkware/cairo/sharp/sharp_client_test.py @@ -2,9 +2,11 @@ import os import tempfile +from pytest import MonkeyPatch + import starkware.cairo.sharp.sharp_client as sharp_client -from starkware.cairo.bootloader.fact_topology import FactInfo -from starkware.cairo.bootloader.generate_fact import get_program_output +from starkware.cairo.bootloaders.fact_topology import FactInfo +from starkware.cairo.bootloaders.generate_fact import get_program_output from starkware.cairo.sharp.sharp_client import SharpClient DIR = os.path.dirname(__file__) @@ -49,7 +51,7 @@ def test_compile_and_run(): assert get_program_output(cairo_pie) == [3 ** 2] -def test_get_fact(monkeypatch): +def test_get_fact(monkeypatch: MonkeyPatch): """ Tests that get_fact() command computes the fact correctly. """ diff --git a/src/starkware/contracts/CMakeLists.txt b/src/starkware/contracts/CMakeLists.txt index cd8df317..754af12e 100644 --- a/src/starkware/contracts/CMakeLists.txt +++ b/src/starkware/contracts/CMakeLists.txt @@ -1,7 +1,3 @@ -add_subdirectory(libraries) +include(utils.cmake) -python_lib(starkware_contracts_utils_lib - PREFIX starkware/contracts - FILES - utils.py -) +add_subdirectory(libraries) diff --git a/src/starkware/contracts/utils.cmake b/src/starkware/contracts/utils.cmake new file mode 100644 index 00000000..c526657d --- /dev/null +++ b/src/starkware/contracts/utils.cmake @@ -0,0 +1,5 @@ +python_lib(starkware_contracts_utils_lib + PREFIX starkware/contracts + FILES + utils.py +) diff --git a/src/starkware/eth/eth_test_utils.py b/src/starkware/eth/eth_test_utils.py index 59c52868..749064d3 100644 --- a/src/starkware/eth/eth_test_utils.py +++ b/src/starkware/eth/eth_test_utils.py @@ -48,6 +48,15 @@ def context_manager(cls): finally: res.stop() + def advance_time(self, n_seconds: int): + self.w3.provider.make_request( + method=web3_types.RPCEndpoint("evm_increaseTime"), params=n_seconds + ) + self.w3.provider.make_request(method=web3_types.RPCEndpoint("evm_mine"), params=[]) + + def get_block_by_hash(self, block_hash: str) -> "EthBlock": + return EthBlock(w3_block=self.w3.eth.getBlock(block_hash)) + class Ganache: """ @@ -177,7 +186,7 @@ def get_events(self, tx: "EthReceipt", name: str) -> List[dict]: event = getattr(self.w3_contract.events, name) return [ {arg_name: handle_w3_value(arg_value) for arg_name, arg_value in event.args.items()} - for event in event().processReceipt(tx.tx_receipt) + for event in event().processReceipt(tx.w3_tx_receipt) ] def decode_transaction_data(self, data): @@ -210,8 +219,8 @@ def transact(self, *args, transact_args: Optional[Dict[str, Any]] = None) -> "Et try: tx_hash = self._func(*args).transact(transact_args) - tx_receipt = self.contract.w3.eth.waitForTransactionReceipt(tx_hash) - return EthReceipt(contract=self.contract, tx_receipt=tx_receipt) + w3_tx_receipt = self.contract.w3.eth.waitForTransactionReceipt(tx_hash) + return EthReceipt(contract=self.contract, w3_tx_receipt=w3_tx_receipt) except web3.exceptions.ContractLogicError as ex: raise EthRevertException(str(ex)) from None @@ -230,19 +239,32 @@ def __call__(self, *args, transact_args=None): class EthReceipt: - def __init__(self, contract, tx_receipt): + def __init__(self, contract, w3_tx_receipt): self.contract = contract - self.tx_receipt = tx_receipt + self.w3_tx_receipt = w3_tx_receipt def get_events(self, name: str) -> List[dict]: return self.contract.get_events(tx=self, name=name) def get_cost(self) -> int: - tx = self.contract.w3.eth.get_transaction(self.tx_receipt.transactionHash) + tx = self.contract.w3.eth.get_transaction(self.w3_tx_receipt.transactionHash) gas_price = tx.get("effectiveGasPrice") if gas_price is None: gas_price = tx["gasPrice"] - return self.tx_receipt.gasUsed * gas_price + return self.w3_tx_receipt.gasUsed * gas_price + + @property + def block_hash(self) -> str: + return self.w3_tx_receipt.blockHash.hex() + + +class EthBlock: + def __init__(self, w3_block): + self.w3_block = w3_block + + @property + def timestamp(self) -> int: + return self.w3_block.timestamp class EthRevertException(Exception): diff --git a/src/starkware/python/CMakeLists.txt b/src/starkware/python/CMakeLists.txt index e04e29c8..d4430412 100644 --- a/src/starkware/python/CMakeLists.txt +++ b/src/starkware/python/CMakeLists.txt @@ -6,8 +6,11 @@ python_lib(starkware_python_utils_lib object_utils.py python_dependencies.py utils.py + utils_stub_module.py + utils_stub_module.pyi LIBS + pip_pyyaml pip_sympy ) diff --git a/src/starkware/python/expression_string.py b/src/starkware/python/expression_string.py index fb1a4a16..71423ea9 100644 --- a/src/starkware/python/expression_string.py +++ b/src/starkware/python/expression_string.py @@ -109,11 +109,20 @@ def double_star_pow(self, other): return ExpressionString(f"{self:HIGHEST} ** {other:HIGHEST}", OperatorPrecedence.POW) def __neg__(self): + # Use OperatorPrecedence.LOWEST (even though the actual precedence of the unary minus is + # higher) so that parentheses will be added even when lower-precedence operators are used. + # For example: `(-x) + y`. return ExpressionString(f"-{self:ADDROF}", OperatorPrecedence.LOWEST) def address_of(self): return ExpressionString(f"&{self:ADDROF}", OperatorPrecedence.ADDROF) + def operator_new(self): + # Use OperatorPrecedence.LOWEST (even though the actual precedence of the new operator is + # higher) so that parentheses will be added even when lower-precedence operators are used. + # For example: `(new x) + y`. + return ExpressionString(f"new {self:ADDROF}", OperatorPrecedence.LOWEST) + def prepend(self, txt): """ Prepends the given text to the string, without changing its OperatorPrecedence. diff --git a/src/starkware/python/utils.py b/src/starkware/python/utils.py index ecbdd9f5..58ce97f4 100644 --- a/src/starkware/python/utils.py +++ b/src/starkware/python/utils.py @@ -8,9 +8,26 @@ import subprocess import time from collections import UserDict -from typing import Any, AsyncIterable, Awaitable, Iterable, List, Optional, TypeVar +from typing import ( + Any, + AsyncIterable, + Awaitable, + Dict, + Iterable, + Iterator, + List, + Mapping, + Optional, + TypeVar, +) + +import yaml + +# All functions with stubs are imported from this module. +from starkware.python.utils_stub_module import * # noqa T = TypeVar("T") +NumType = TypeVar("NumType", int, float) HASH_BYTES = 32 @@ -52,6 +69,13 @@ def assert_same_and_get(*args): return args[0] +def assert_exhausted(iterator: Iterator): + """ + Verifies that given iterator is empty. + """ + assert all(False for _ in iterator), "Iterator is not empty." + + def unique(x): """ Removes duplicates while preserving order. @@ -66,7 +90,7 @@ def unique_ordered_union(x, y): return list(dict.fromkeys(list(x) + list(y)).keys()) -def add_counters(x, y): +def add_counters(x: Mapping[T, NumType], y: Mapping[T, NumType]) -> Dict[T, NumType]: """ Given two dicts x, y, returns a dict d s.t. d[a] = d[x] + d[y] @@ -74,7 +98,7 @@ def add_counters(x, y): return {k: x.get(k, 0) + y.get(k, 0) for k in unique_ordered_union(x.keys(), y.keys())} -def sub_counters(x, y): +def sub_counters(x: Mapping[T, NumType], y: Mapping[T, NumType]) -> Dict[T, NumType]: """ Given two dicts x, y, returns a dict d s.t. d[a] = d[x] - d[y] @@ -206,16 +230,6 @@ async def cancel_futures(*futures: asyncio.Future): pass -def safe_zip(*iterables: Iterable[Any]) -> Iterable: - """ - Zips iterables. Makes sure the lengths of all iterables are equal. - """ - sentinel = object() - for combo in itertools.zip_longest(*iterables, fillvalue=sentinel): - assert sentinel not in combo, "Iterables to safe_zip are not equal in length." - yield combo - - def composite(*funcs): """ Returns the composition of all the given functions, which is a function that runs the last @@ -367,3 +381,12 @@ def to_ascii_string(value: str) -> str: Converts the given string to an ascii-encodeable one by replacing non-ascii characters with '?'. """ return value.encode("ascii", "replace").decode("ascii") + + +def update_yaml_file(file_path: str, data: Dict[str, Any]): + """ + Updates yaml file in given path with given data. + """ + with open(file_path, "w") as fp: + fp.write(yaml.dump(data=data, default_flow_style=False, width=400)) + fp.flush() diff --git a/src/starkware/python/utils_stub_module.py b/src/starkware/python/utils_stub_module.py new file mode 100644 index 00000000..2ee0575a --- /dev/null +++ b/src/starkware/python/utils_stub_module.py @@ -0,0 +1,15 @@ +import itertools +from typing import Any, Iterable + +# This file contains functions of utils.py, for which stubs exist in the corresponding *.pyi file. +# It is needed since mypy looks for all definitions in *.pyi files, without fallingback to the *.py. + + +def safe_zip(*iterables: Iterable[Any]) -> Iterable: + """ + Zips iterables. Makes sure the lengths of all iterables are equal. + """ + sentinel = object() + for combo in itertools.zip_longest(*iterables, fillvalue=sentinel): + assert sentinel not in combo, "Iterables to safe_zip are not equal in length." + yield combo diff --git a/src/starkware/python/utils_stub_module.pyi b/src/starkware/python/utils_stub_module.pyi new file mode 100644 index 00000000..0c56eff1 --- /dev/null +++ b/src/starkware/python/utils_stub_module.pyi @@ -0,0 +1,62 @@ +from typing import Any, Generic, Iterable, Iterator, Tuple, TypeVar, overload + + +# Type variables. + +_T_co = TypeVar("_T_co", covariant=True) +_T = TypeVar("_T") +_T1 = TypeVar("_T1") +_T2 = TypeVar("_T2") +_T3 = TypeVar("_T3") +_T4 = TypeVar("_T4") +_T5 = TypeVar("_T5") + + +# Stubs. + +class safe_zip(Iterator[_T_co], Generic[_T_co]): + @overload + def __new__(cls, __iter1: Iterable[_T1]) -> safe_zip[Tuple[_T1]]: ... + + @overload + def __new__( + cls, __iter1: Iterable[_T1], __iter2: Iterable[_T2] + ) -> safe_zip[Tuple[_T1, _T2]]: ... + + @overload + def __new__( + cls, __iter1: Iterable[_T1], __iter2: Iterable[_T2], __iter3: Iterable[_T3] + ) -> safe_zip[Tuple[_T1, _T2, _T3]]: ... + + @overload + def __new__( + cls, + __iter1: Iterable[_T1], + __iter2: Iterable[_T2], + __iter3: Iterable[_T3], + __iter4: Iterable[_T4], + ) -> safe_zip[Tuple[_T1, _T2, _T3, _T4]]: ... + + @overload + def __new__( + cls, + __iter1: Iterable[_T1], + __iter2: Iterable[_T2], + __iter3: Iterable[_T3], + __iter4: Iterable[_T4], + __iter5: Iterable[_T5], + ) -> safe_zip[Tuple[_T1, _T2, _T3, _T4, _T5]]: ... + + @overload + def __new__( + cls, + __iter1: Iterable[_T1], + __iter2: Iterable[_T2], + __iter3: Iterable[_T3], + __iter4: Iterable[_T4], + __iter5: Iterable[_T5], + *iterables: Iterable[Any], + ) -> safe_zip[Tuple[Any, ...]]: ... + + def __iter__(self) -> Iterator[_T_co]: ... + def __next__(self) -> _T_co: ... diff --git a/src/starkware/python/utils_test.py b/src/starkware/python/utils_test.py index 1966416a..ced0e6bd 100644 --- a/src/starkware/python/utils_test.py +++ b/src/starkware/python/utils_test.py @@ -1,12 +1,14 @@ import random import re import string +from itertools import count import pytest from starkware.python.utils import ( WriteOnceDict, all_subclasses, + assert_exhausted, blockify, composite, gather_in_chunks, @@ -149,3 +151,22 @@ def test_to_ascii_str(): converted_string = to_ascii_string(value=string_pattern.format(value=chr(order))) assert converted_string.isascii() assert converted_string == expected_string + + +def test_assert_exhausted(): + # Positive flow. + assert_exhausted(iterator=iter([])) + + # Negative flow. + with pytest.raises( + AssertionError, + match=re.escape("Iterator is not empty."), + ): + assert_exhausted(iterator=iter([1])) + + # Check that infinite iterator fails assertion. + with pytest.raises( + AssertionError, + match=re.escape("Iterator is not empty."), + ): + assert_exhausted(iterator=count(start=0, step=1)) diff --git a/src/starkware/starknet/business_logic/CMakeLists.txt b/src/starkware/starknet/business_logic/CMakeLists.txt index e6491d01..2add5d49 100644 --- a/src/starkware/starknet/business_logic/CMakeLists.txt +++ b/src/starkware/starknet/business_logic/CMakeLists.txt @@ -14,6 +14,7 @@ python_lib(starknet_business_logic_lib starknet_general_config_lib starknet_os_abi_lib starknet_storage_lib + starkware_commitment_tree_facts_lib starkware_config_utils_lib starkware_dataclasses_utils_lib starkware_error_handling_lib @@ -23,6 +24,21 @@ python_lib(starknet_business_logic_lib pip_marshmallow_dataclass ) +python_lib(starknet_business_logic_utils_lib + PREFIX starkware/starknet/business_logic + + FILES + utils.py + + LIBS + cairo_function_runner_lib + starknet_abi_lib + starknet_contract_definition_lib + starknet_definitions_lib + starknet_transaction_hash_lib + starkware_error_handling_lib +) + python_lib(starknet_internal_transaction_interface_lib PREFIX starkware/starknet/business_logic @@ -57,6 +73,7 @@ python_lib(starknet_transaction_execution_objects_lib starknet_contract_definition_lib starknet_definitions_lib starkware_dataclasses_utils_lib + pip_marshmallow pip_marshmallow_dataclass ) @@ -67,23 +84,22 @@ python_lib(starknet_internal_transaction_lib internal_transaction.py LIBS - cairo_function_runner_lib - cairo_relocatable_lib cairo_vm_lib everest_business_logic_lib everest_internal_transaction_lib everest_transaction_lib starknet_abi_lib starknet_business_logic_lib + starknet_business_logic_utils_lib starknet_contract_address_lib starknet_contract_definition_lib starknet_definitions_lib + starknet_execute_entry_point_lib starknet_general_config_lib starknet_internal_transaction_interface_lib starknet_os_abi_lib - starknet_os_utils_lib - starknet_storage_lib starknet_transaction_execution_objects_lib + starknet_transaction_fee_lib starknet_transaction_hash_lib starknet_transaction_lib starkware_config_utils_lib @@ -96,3 +112,59 @@ python_lib(starknet_internal_transaction_lib pip_marshmallow_enum pip_marshmallow_oneofschema ) + +python_lib(starknet_execute_entry_point_base_lib + PREFIX starkware/starknet/business_logic + + FILES + execute_entry_point_base.py + + LIBS + starknet_business_logic_lib + starknet_contract_definition_lib + starknet_definitions_lib + starknet_general_config_lib + starknet_transaction_execution_objects_lib +) + +python_lib(starknet_execute_entry_point_lib + PREFIX starkware/starknet/business_logic + + FILES + execute_entry_point.py + + LIBS + cairo_function_runner_lib + cairo_relocatable_lib + cairo_vm_lib + starknet_abi_lib + starknet_business_logic_lib + starknet_business_logic_utils_lib + starknet_contract_definition_lib + starknet_definitions_lib + starknet_execute_entry_point_base_lib + starknet_general_config_lib + starknet_os_utils_lib + starknet_storage_lib + starknet_transaction_execution_objects_lib + starkware_dataclasses_utils_lib + starkware_error_handling_lib + pip_marshmallow_dataclass +) + +python_lib(starknet_transaction_fee_lib + PREFIX starkware/starknet/business_logic + + FILES + transaction_fee.py + + LIBS + starknet_abi_lib + starknet_business_logic_lib + starknet_contract_definition_lib + starknet_definitions_lib + starknet_execute_entry_point_lib + starknet_general_config_lib + starknet_transaction_execution_objects_lib + starkware_error_handling_lib +) diff --git a/src/starkware/starknet/business_logic/execute_entry_point.py b/src/starkware/starknet/business_logic/execute_entry_point.py new file mode 100644 index 00000000..11eb5312 --- /dev/null +++ b/src/starkware/starknet/business_logic/execute_entry_point.py @@ -0,0 +1,295 @@ +import asyncio +import functools +import logging +from typing import List, Tuple, cast + +import marshmallow_dataclass + +from starkware.cairo.common.cairo_function_runner import CairoFunctionRunner +from starkware.cairo.lang.vm.cairo_pie import ExecutionResources +from starkware.cairo.lang.vm.relocatable import RelocatableValue +from starkware.cairo.lang.vm.security import SecurityError +from starkware.cairo.lang.vm.utils import ResourcesError +from starkware.cairo.lang.vm.vm_exceptions import HintException, VmException, VmExceptionBase +from starkware.starknet.business_logic.execute_entry_point_base import ExecuteEntryPointBase +from starkware.starknet.business_logic.state import CarriedState +from starkware.starknet.business_logic.transaction_execution_objects import ( + CallInfo, + TransactionExecutionContext, +) +from starkware.starknet.business_logic.utils import get_return_values +from starkware.starknet.core.os import os_utils, syscall_utils +from starkware.starknet.definitions import fields +from starkware.starknet.definitions.error_codes import StarknetErrorCode +from starkware.starknet.definitions.general_config import StarknetGeneralConfig +from starkware.starknet.public import abi as starknet_abi +from starkware.starknet.services.api.contract_definition import ( + ContractDefinition, + ContractEntryPoint, +) +from starkware.starknet.storage.starknet_storage import BusinessLogicStarknetStorage +from starkware.starkware_utils.error_handling import ( + StarkException, + stark_assert, + wrap_with_stark_exception, +) +from starkware.starkware_utils.validated_dataclass import ValidatedMarshmallowDataclass + +logger = logging.getLogger(__name__) + + +@marshmallow_dataclass.dataclass(frozen=True) +class ExecuteEntryPoint(ValidatedMarshmallowDataclass, ExecuteEntryPointBase): + """ + Represents a Cairo entry point execution of a StarkNet contract. + """ + + async def execute( + self, + state: CarriedState, + general_config: StarknetGeneralConfig, + tx_execution_context: TransactionExecutionContext, + ) -> CallInfo: + """ + Executes the selected entry point with the given calldata in the specified contract. + The information collected from this run (number of steps required, modifications to the + contract storage, etc.) is saved on the carried state argument. + Returns a CallInfo object that represents the execution. + """ + # Pass the running loop before entering to it. It will be used to run asynchronous + # tasks, such as fetching data from storage. + loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() + sync_execute = functools.partial( + self.sync_execute, + state=state, + general_config=general_config, + loop=loop, + tx_execution_context=tx_execution_context, + ) + + return await loop.run_in_executor( + executor=None, # Runs on the default executor. + func=sync_execute, + ) + + def sync_execute( + self, + state: CarriedState, + general_config: StarknetGeneralConfig, + loop: asyncio.AbstractEventLoop, + tx_execution_context: TransactionExecutionContext, + ) -> CallInfo: + """ + Synchronous version of execute_entry_point with a given TransactionExecutionContext object; + needed since this function also runs inside Cairo hints (when processing internal contract + calls). + Should be called from whithin the given loop. + """ + previous_cairo_usage = state.cairo_usage + + runner, syscall_handler = self._run( + state=state, + general_config=general_config, + loop=loop, + tx_execution_context=tx_execution_context, + ) + + # Apply modifications to the contract storage. + state.update_contract_storage( + contract_address=self.contract_address, + modifications=syscall_handler.starknet_storage.get_modifications(), + ) + + # Update resources usage (for bouncer). + state.cairo_usage += runner.get_execution_resources() + + # Build and return call info. + return self._build_call_info( + previous_cairo_usage=previous_cairo_usage, + syscall_handler=syscall_handler, + retdata=get_return_values(runner=runner), + ) + + def _run( + self, + state: CarriedState, + general_config: StarknetGeneralConfig, + loop: asyncio.AbstractEventLoop, + tx_execution_context: TransactionExecutionContext, + ) -> Tuple[CairoFunctionRunner, syscall_utils.BusinessLogicSysCallHandler]: + """ + Runs the selected entry point with the given calldata in the code of the contract deployed + at self.code_address. + The execution is done in the context (e.g., storage) of the contract at + self.contract_address. + Returns the corresponding CairoFunctionRunner and BusinessLogicSysCallHandler in order to + retrieve the execution information. + """ + # Extract pre-fetched contract code from carried state. + code_contract_state = state.contract_states[self.code_address].state + code_contract_state.assert_initialized(contract_address=self.code_address) + + # Prepare input for Cairo function runner. + contract_definition = state.contract_definitions[code_contract_state.contract_hash] + contract_definition.validate() + entry_point = self._get_selected_entry_point(contract_definition=contract_definition) + + # Run the specified contract entry point with given calldata. + with wrap_with_stark_exception(code=StarknetErrorCode.SECURITY_ERROR): + runner = CairoFunctionRunner(program=contract_definition.program, layout="all") + os_context = os_utils.prepare_os_context(runner=runner) + + # Extract pre-fetched contract state from carried state. + pre_run_contract_carried_state = state.contract_states[self.contract_address] + contract_state = pre_run_contract_carried_state.state + contract_state.assert_initialized(contract_address=self.contract_address) + + starknet_storage = BusinessLogicStarknetStorage( + commitment_tree=contract_state.storage_commitment_tree, + ffc=state.ffc, + # Note that pending_modifications might be modified during the run as a result of an + # internal call. + pending_modifications=pre_run_contract_carried_state.storage_updates.copy(), + loop=loop, + ) + + initial_syscall_ptr = cast(RelocatableValue, os_context[starknet_abi.SYSCALL_PTR_OFFSET]) + syscall_handler = syscall_utils.BusinessLogicSysCallHandler( + execute_entry_point_cls=ExecuteEntryPoint, + tx_execution_context=tx_execution_context, + state=state, + caller_address=self.caller_address, + contract_address=self.contract_address, + starknet_storage=starknet_storage, + general_config=general_config, + initial_syscall_ptr=initial_syscall_ptr, + ) + + # Positional arguments are passed to *args in the 'run_from_entrypoint' function. + entry_points_args = [ + self.entry_point_selector, + os_context, + len(self.calldata), + self.calldata, + ] + + try: + runner.run_from_entrypoint( + entry_point.offset, + *entry_points_args, + hint_locals={ + "__storage": starknet_storage, + "syscall_handler": syscall_handler, + }, + static_locals={ + "__find_element_max_size": 2 ** 20, + "__squash_dict_max_size": 2 ** 20, + "__keccak_max_size": 2 ** 20, + "__usort_max_size": 2 ** 20, + }, + run_resources=tx_execution_context.run_resources, + verify_secure=True, + ) + except VmException as exception: + code = StarknetErrorCode.TRANSACTION_FAILED + if isinstance(exception.inner_exc, HintException): + hint_exception = exception.inner_exc + + if isinstance(hint_exception.inner_exc, syscall_utils.HandlerException): + stark_exception = hint_exception.inner_exc.stark_exception + code = stark_exception.code + called_contract_address = hint_exception.inner_exc.called_contract_address + message_prefix = ( + f"Error in the called contract ({hex(called_contract_address)}):\n" + ) + # Override python's traceback and keep the Cairo one of the inner exception. + exception.notes = [message_prefix + str(stark_exception.message)] + + if isinstance(exception.inner_exc, ResourcesError): + code = StarknetErrorCode.OUT_OF_RESOURCES + + raise StarkException(code=code, message=str(exception)) + except VmExceptionBase as exception: + raise StarkException(code=StarknetErrorCode.TRANSACTION_FAILED, message=str(exception)) + except SecurityError as exception: + raise StarkException(code=StarknetErrorCode.SECURITY_ERROR, message=str(exception)) + except Exception: + logger.error("Got an unexpected exception.", exc_info=True) + raise StarkException( + code=StarknetErrorCode.UNEXPECTED_FAILURE, + message="Got an unexpected exception during the execution of the transaction.", + ) + + # Complete handler validations. + os_utils.validate_and_process_os_context( + runner=runner, + syscall_handler=syscall_handler, + initial_os_context=os_context, + ) + + # When execution starts the stack holds entry_points_args + [ret_fp, ret_pc]. + args_ptr = runner.initial_fp - (len(entry_points_args) + 2) + + # The arguments are touched by the OS and should not be counted as holes, mark them + # as accessed. + assert isinstance(args_ptr, RelocatableValue) # Downcast. + runner.mark_as_accessed(address=args_ptr, size=len(entry_points_args)) + + return runner, syscall_handler + + def _get_selected_entry_point( + self, contract_definition: ContractDefinition + ) -> ContractEntryPoint: + """ + Returns the entry point with selector corresponding with self.entry_point_selector. + """ + entry_points = contract_definition.entry_points_by_type[self.entry_point_type] + filtered_entry_points = list( + filter( + lambda ep: ep.selector == self.entry_point_selector, + entry_points, + ) + ) + + if len(filtered_entry_points) == 0 and len(entry_points) > 0: + first_entry_point = entry_points[0] + if first_entry_point.selector == starknet_abi.DEFAULT_ENTRY_POINT_SELECTOR: + return first_entry_point + + selector_formatter = fields.EntryPointSelectorField.format + address_formatter = fields.ContractAddressField.format + # Non-unique entry points are not possible in a ContractDefinition object, thus + # len(filtered_entry_points) <= 1. + stark_assert( + len(filtered_entry_points) == 1, + code=StarknetErrorCode.ENTRY_POINT_NOT_FOUND_IN_CONTRACT, + message=( + f"Entry point {selector_formatter(self.entry_point_selector)} not found in contract" + f" with address {address_formatter(self.contract_address)}." + ), + ) + + (entry_point,) = filtered_entry_points + return entry_point + + def _build_call_info( + self, + previous_cairo_usage: ExecutionResources, + syscall_handler: syscall_utils.BusinessLogicSysCallHandler, + retdata: List[int], + ) -> CallInfo: + return CallInfo( + caller_address=self.caller_address, + contract_address=self.contract_address, + code_address=self.code_address, + entry_point_selector=self.entry_point_selector, + entry_point_type=self.entry_point_type, + calldata=self.calldata, + retdata=retdata, + execution_resources=syscall_handler.state.cairo_usage - previous_cairo_usage, + events=syscall_handler.events, + l2_to_l1_messages=syscall_handler.l2_to_l1_messages, + storage_read_values=syscall_handler.starknet_storage.read_values, + accessed_storage_keys=syscall_handler.starknet_storage.accessed_addresses, + internal_calls=syscall_handler.internal_calls, + ) diff --git a/src/starkware/starknet/business_logic/execute_entry_point_base.py b/src/starkware/starknet/business_logic/execute_entry_point_base.py new file mode 100644 index 00000000..ef8a2ae4 --- /dev/null +++ b/src/starkware/starknet/business_logic/execute_entry_point_base.py @@ -0,0 +1,57 @@ +import asyncio +import dataclasses +from abc import ABC, abstractmethod +from dataclasses import field +from typing import List, TypeVar + +from starkware.starknet.business_logic.state import CarriedState, StateSelector +from starkware.starknet.business_logic.transaction_execution_objects import ( + CallInfo, + TransactionExecutionContext, +) +from starkware.starknet.definitions import fields +from starkware.starknet.definitions.general_config import StarknetGeneralConfig +from starkware.starknet.services.api.contract_definition import EntryPointType + +TExecuteEntryPoint = TypeVar("TExecuteEntryPoint", bound="ExecuteEntryPointBase") + + +# Mypy has a problem with dataclasses that contain unimplemented abstract methods. +# See https://github.com/python/mypy/issues/5374 for details on this problem. +@dataclasses.dataclass(frozen=True) # type: ignore[misc] +class ExecuteEntryPointBase(ABC): + """ + Represents a StarkNet contract call. This interface is meant to prevent a cyclic dependency + with the BusinessLogicSyscallHandler. + """ + + contract_address: int = field(metadata=fields.contract_address_metadata) + # The address that holds the code to execute. + # It may differ from contract_address in the case of delegate call. + code_address: int = field(metadata=fields.contract_address_metadata) + entry_point_selector: int = field(metadata=fields.entry_point_selector_metadata) + # The decorator type of the called function. Note that a single function may be decorated with + # multiple decorators and this member specifies which one. + entry_point_type: EntryPointType + calldata: List[int] = field(metadata=fields.call_data_metadata) + # Caller address is zero for external calls and the caller (contract) address for composed ones. + caller_address: int = field(metadata=fields.caller_address_metadata) + + def get_call_state_selector(self) -> StateSelector: + """ + Returns the state selector of the call (i.e., subset of state commitment tree leaves + it affects). + """ + return StateSelector(contract_addresses={self.contract_address, self.code_address}) + + @abstractmethod + def sync_execute( + self, + state: CarriedState, + general_config: StarknetGeneralConfig, + loop: asyncio.AbstractEventLoop, + tx_execution_context: TransactionExecutionContext, + ) -> CallInfo: + """ + Executes the entry point. Should be called from within the given loop. + """ diff --git a/src/starkware/starknet/business_logic/internal_transaction.py b/src/starkware/starknet/business_logic/internal_transaction.py index ee597b77..ff663a5c 100644 --- a/src/starkware/starknet/business_logic/internal_transaction.py +++ b/src/starkware/starknet/business_logic/internal_transaction.py @@ -1,10 +1,8 @@ -import asyncio import dataclasses -import functools import logging from abc import abstractmethod from dataclasses import field -from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, cast +from typing import Any, ClassVar, Dict, List, Optional, Type import marshmallow import marshmallow_dataclass @@ -13,17 +11,9 @@ from services.everest.api.gateway.transaction import EverestTransaction from services.everest.business_logic.internal_transaction import EverestInternalTransaction from services.everest.business_logic.state import CarriedStateBase -from starkware.cairo.common.cairo_function_runner import CairoFunctionRunner from starkware.cairo.lang.vm.cairo_pie import ExecutionResources -from starkware.cairo.lang.vm.relocatable import RelocatableValue -from starkware.cairo.lang.vm.utils import ResourcesError -from starkware.cairo.lang.vm.vm_exceptions import ( - HintException, - SecurityError, - VmException, - VmExceptionBase, -) from starkware.python.utils import to_bytes +from starkware.starknet.business_logic.execute_entry_point import ExecuteEntryPoint from starkware.starknet.business_logic.internal_transaction_interface import ( InternalStateTransaction, ) @@ -34,42 +24,33 @@ ContractState, ) from starkware.starknet.business_logic.transaction_execution_objects import ( - ContractCall, + CallInfo, TransactionExecutionContext, TransactionExecutionInfo, ) -from starkware.starknet.core.os import os_utils, syscall_utils +from starkware.starknet.business_logic.transaction_fee import charge_fee +from starkware.starknet.business_logic.utils import preprocess_invoke_function_fields from starkware.starknet.core.os.contract_hash import compute_contract_hash from starkware.starknet.core.os.transaction_hash import ( - TransactionHashPrefix, calculate_deploy_transaction_hash, calculate_transaction_hash_common, ) -from starkware.starknet.definitions import fields +from starkware.starknet.definitions import constants, fields from starkware.starknet.definitions.error_codes import StarknetErrorCode from starkware.starknet.definitions.general_config import StarknetGeneralConfig from starkware.starknet.definitions.transaction_type import TransactionType -from starkware.starknet.public.abi import ( - DEFAULT_ENTRY_POINT_SELECTOR, - SYSCALL_PTR_OFFSET, - get_selector_from_name, -) +from starkware.starknet.public import abi as starknet_abi from starkware.starknet.services.api.contract_definition import ( + CONSTRUCTOR_SELECTOR, ContractDefinition, - ContractEntryPoint, EntryPointType, ) from starkware.starknet.services.api.gateway.contract_address import ( calculate_contract_address_from_hash, ) from starkware.starknet.services.api.gateway.transaction import Deploy, InvokeFunction, Transaction -from starkware.starknet.storage.starknet_storage import BusinessLogicStarknetStorage from starkware.starkware_utils.config_base import Config -from starkware.starkware_utils.error_handling import ( - StarkException, - stark_assert, - wrap_with_stark_exception, -) +from starkware.starkware_utils.error_handling import stark_assert from starkware.storage.storage import FactFetchingContext, Storage logger = logging.getLogger(__name__) @@ -122,7 +103,7 @@ def __init_subclass__(cls, **kwargs): # Record only the first class with this related_external_type. recorded_cls = InternalTransaction.external_to_internal_cls.setdefault( - cls.related_external_cls, cls + cls.related_external_cls, cls # type: ignore[arg-type] ) # Check that this class is indeed that class or a subclass of it. @@ -186,15 +167,6 @@ async def _apply_specific_state_updates( ) -> TransactionExecutionInfo: pass - def _synchronous_apply_specific_state_updates( - self, - state: CarriedState, - general_config: StarknetGeneralConfig, - loop: asyncio.AbstractEventLoop, - tx_execution_context: TransactionExecutionContext, - ) -> TransactionExecutionInfo: - pass - class SyntheticTransaction(InternalStateTransaction): """ @@ -243,18 +215,6 @@ async def _apply_specific_state_updates( def get_state_selector(self, general_config: Config) -> StateSelector: return StateSelector.empty() - def _synchronous_apply_specific_state_updates( - self, - state: CarriedState, - general_config: StarknetGeneralConfig, - loop: asyncio.AbstractEventLoop, - tx_execution_context: TransactionExecutionContext, - ) -> Optional[TransactionExecutionInfo]: - """ - This method is not supported. - """ - raise NotImplementedError - @marshmallow_dataclass.dataclass(frozen=True) class InternalDeploy(InternalTransaction): @@ -327,7 +287,7 @@ async def create_for_testing( contract_definition: ContractDefinition, contract_address_salt: int, constructor_calldata: List[int], - chain_id: int, + chain_id: Optional[int] = None, ) -> "InternalDeploy": """ Creates an InternalDeploy transaction and writes its contract definition to the DB. @@ -339,7 +299,7 @@ async def create_for_testing( contract_address_salt=contract_address_salt, contract_definition=contract_definition, constructor_calldata=constructor_calldata, - chain_id=chain_id, + chain_id=0 if chain_id is None else chain_id, ) return tx @@ -369,9 +329,6 @@ def get_state_selector(self, general_config: Config) -> StateSelector: Returns the state selector of the transaction (i.e., subset of state commitment tree leaves it affects). """ - # Downcast arguments to application-specific types. - assert isinstance(general_config, StarknetGeneralConfig) - return StateSelector(contract_addresses={self.contract_address}) async def _apply_specific_state_updates( @@ -425,48 +382,44 @@ async def invoke_constructor( code=StarknetErrorCode.TRANSACTION_FAILED, message="Cannot pass calldata to a contract with no constructor.", ) - return TransactionExecutionInfo.create( - call_info=ContractCall.empty(to_address=self.contract_address) + return TransactionExecutionInfo( + call_info=CallInfo.empty(contract_address=self.contract_address), + fee_transfer_info=None, + actual_fee=0, ) - tx = InternalInvokeFunction( + call = ExecuteEntryPoint( contract_address=self.contract_address, code_address=self.contract_address, - entry_point_selector=get_selector_from_name("constructor"), + entry_point_selector=CONSTRUCTOR_SELECTOR, entry_point_type=EntryPointType.CONSTRUCTOR, calldata=self.constructor_calldata, - signature=[], - hash_value=0, caller_address=0, - nonce=None, ) - return await tx._apply_specific_state_updates(state=state, general_config=general_config) + tx_execution_context = TransactionExecutionContext.create_for_call( + account_contract_address=0, + n_steps=general_config.invoke_tx_max_n_steps, + ) + call_info = await call.execute( + state=state, general_config=general_config, tx_execution_context=tx_execution_context + ) + return TransactionExecutionInfo(call_info=call_info, fee_transfer_info=None, actual_fee=0) @marshmallow_dataclass.dataclass(frozen=True) -class InternalInvokeFunction(InternalTransaction): +class InternalInvokeFunction(InternalTransaction, ExecuteEntryPoint): """ Represents an internal transaction in the StarkNet network that is an invocation of a Cairo contract function. """ # For fields that are shared with InvokeFunction, see documentation there. - contract_address: int = field(metadata=fields.contract_address_metadata) - # The address that holds the code to execute. - # It may differ from contract_address in the case of delegate call. - code_address: int = field(metadata=fields.contract_address_metadata) - entry_point_selector: int = field(metadata=fields.entry_point_selector_metadata) - # The decorator type of the called function. Note that a single function may be decorated with - # multiple decorators and this member specifies which one. - entry_point_type: EntryPointType - calldata: List[int] = field(metadata=fields.call_data_metadata) - signature: List[int] = field(metadata=fields.signature_metadata) + max_fee: int = field(metadata=fields.fee_metadata) + version: int = field(metadata=fields.tx_version_metadata) # A unique identifier of the transaction in the StarkNet network. hash_value: int = field(metadata=fields.transaction_hash_metadata) - # Caller address is zero for external calls and the caller (contract) address for composed ones. - caller_address: int = field(metadata=fields.caller_address_metadata) - + signature: List[int] = field(metadata=fields.signature_metadata) # A unique nonce, added by the StarkNet core contract on L1. # This nonce is used to make the hash_value of transactions that service L1 messages unique. # This field may be set only when entry_point_type is EntryPointType.L1_HANDLER. @@ -482,25 +435,26 @@ def create_for_testing( contract_address: int, calldata: List[int], entry_point_selector: int, - code_address: Optional[int] = None, + max_fee: Optional[int] = None, entry_point_type: Optional[EntryPointType] = None, signature: Optional[List[int]] = None, - hash_value: Optional[int] = None, caller_address: Optional[int] = None, nonce: Optional[int] = None, + chain_id: Optional[int] = None, ): - return cls( + return cls.create( contract_address=contract_address, - code_address=contract_address if code_address is None else code_address, entry_point_selector=entry_point_selector, + max_fee=0 if max_fee is None else max_fee, + version=constants.TRANSACTION_VERSION, entry_point_type=( EntryPointType.EXTERNAL if entry_point_type is None else entry_point_type ), calldata=calldata, signature=[] if signature is None else signature, - hash_value=0 if hash_value is None else hash_value, caller_address=0 if caller_address is None else caller_address, nonce=nonce, + chain_id=0 if chain_id is None else chain_id, ) @classmethod @@ -509,47 +463,53 @@ def _specific_from_external( ) -> "InternalInvokeFunction": assert isinstance(external_tx, InvokeFunction) return cls.create( - general_config=general_config, contract_address=external_tx.contract_address, entry_point_selector=external_tx.entry_point_selector, + max_fee=external_tx.max_fee, entry_point_type=EntryPointType.EXTERNAL, calldata=external_tx.calldata, signature=external_tx.signature, nonce=None, + chain_id=general_config.chain_id.value, + version=external_tx.version, ) @classmethod def create( cls, - general_config: StarknetGeneralConfig, contract_address: int, entry_point_selector: int, + max_fee: int, entry_point_type: EntryPointType, calldata: List[int], signature: List[int], nonce: Optional[int], + chain_id: int, + version: Optional[int] = None, # The caller_address of an external transaction or L1 handler is always 0. # The caller_address is passed as paramater to allow the testing framework to initiate # transactions with a user specified caller_address. caller_address: int = 0, ) -> "InternalInvokeFunction": - if entry_point_type is EntryPointType.EXTERNAL: - tx_hash_prefix = TransactionHashPrefix.INVOKE - assert nonce is None, "An InvokeFunction transaction cannot have a nonce." - additional_data = [] - elif entry_point_type is EntryPointType.L1_HANDLER: - tx_hash_prefix = TransactionHashPrefix.L1_HANDLER - assert nonce is not None, "An L1 handler transaction should must have a nonce." - additional_data = [nonce] - else: - raise NotImplementedError(f"Entry point type {entry_point_type.name} is not supported.") + if version is None: + version = constants.TRANSACTION_VERSION + + tx_hash_prefix, additional_data = preprocess_invoke_function_fields( + entry_point_type=entry_point_type, + entry_point_selector=entry_point_selector, + message_from_l1_nonce=nonce, + max_fee=max_fee, + version=version, + ) hash_value = calculate_transaction_hash_common( tx_hash_prefix=tx_hash_prefix, + version=version, contract_address=contract_address, entry_point_selector=entry_point_selector, calldata=calldata, - chain_id=general_config.chain_id.value, + max_fee=max_fee, + chain_id=chain_id, additional_data=additional_data, ) @@ -557,6 +517,8 @@ def create( contract_address=contract_address, code_address=contract_address, entry_point_selector=entry_point_selector, + max_fee=max_fee, + version=version, entry_point_type=entry_point_type, calldata=calldata, signature=signature, @@ -579,6 +541,8 @@ def to_external(self) -> InvokeFunction: contract_address=self.contract_address, entry_point_selector=self.entry_point_selector, calldata=self.calldata, + max_fee=self.max_fee, + version=self.version, signature=self.signature, ) @@ -587,317 +551,58 @@ def get_state_selector(self, general_config: Config) -> StateSelector: Returns the state selector of the transaction (i.e., subset of state commitment tree leaves it affects). """ + call_selector = self.get_call_state_selector() + if self.max_fee == 0: + return call_selector + # Downcast arguments to application-specific types. assert isinstance(general_config, StarknetGeneralConfig) - - return StateSelector(contract_addresses={self.contract_address, self.code_address}) + return call_selector | StateSelector(contract_addresses={general_config.fee_token_address}) async def _apply_specific_state_updates( self, state: CarriedState, general_config: StarknetGeneralConfig ) -> TransactionExecutionInfo: """ - Applies self to 'state' by running _synchronous_apply_specific_state_updates. - This is the asynchronous version of the method below. - """ - account_contract_address = ( - 0 if self.entry_point_type is EntryPointType.CONSTRUCTOR else self.contract_address - ) - - # Pass the running loop before entering to it. It will be used to run asynchronous - # tasks, such as fetching data from storage. - loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() - execute_contract_function = functools.partial( - self.execute_contract_function, - state=state, - general_config=general_config, - loop=loop, - tx_execution_context=TransactionExecutionContext.create( - account_contract_address=account_contract_address, - n_steps=general_config.invoke_tx_max_n_steps, - ), - ) - - execution_info = await loop.run_in_executor( - executor=None, # Runs on the default executor. - func=execute_contract_function, - ) - - return execution_info - - def execute_contract_function( - self, - state: CarriedState, - general_config: StarknetGeneralConfig, - loop: asyncio.AbstractEventLoop, - tx_execution_context: TransactionExecutionContext, - ) -> TransactionExecutionInfo: - """ - Runs the selected entry point with the given calldata in the contract specified by the - transaction. - The information collected from this run (number of steps required, modifications to the - contract storage, etc.) is saved on the carried state argument. - In addition, builds and return the specific transaction execution information, to be used - by the StarkNet OS run in the GpsAmbassador, and by the FeederGateway. - - This function also runs inside Cairo hints (when processing internal contract calls), - thus must be synchronous. + Applies self to 'state' by executing the entry point and charging fee for it (if needed). """ - previous_cairo_usage = state.cairo_usage - - runner, syscall_handler = self._run( - state=state, - general_config=general_config, - loop=loop, - caller_address=self.caller_address, - tx_execution_context=tx_execution_context, - ) + call_info = await self.execute_tx(state=state, general_config=general_config) - # Apply modifications to the contract storage. - state.update_contract_storage( - contract_address=self.contract_address, - modifications=syscall_handler.starknet_storage.get_modifications(), - ) - - # Update resources usage (for bouncer). - state.cairo_usage += runner.get_execution_resources() - - # Build and return transaction execution info. - return TransactionExecutionInfo( - call_info=self._build_call_info( - previous_cairo_usage=previous_cairo_usage, syscall_handler=syscall_handler - ), - l2_to_l1_messages=syscall_handler.l2_to_l1_messages, - retdata=self._get_return_values(runner=runner), - internal_calls=syscall_handler.internal_calls, - ) - - def _run( - self, - state: CarriedState, - general_config: StarknetGeneralConfig, - loop: asyncio.AbstractEventLoop, - caller_address: int, - tx_execution_context: TransactionExecutionContext, - ) -> Tuple[CairoFunctionRunner, syscall_utils.BusinessLogicSysCallHandler]: - """ - Runs the selected entry point with the given calldata in the code of the contract deployed - at self.code_address. - The execution is done in the context (e.g., storage) of the contract at - self.contract_address. - Returns the corresponding CairoFunctionRunner and BusinessLogicSysCallHandler in order to - retrieve the execution information. - """ - # Extract pre-fetched contract code from carried state. - code_contract_state = state.contract_states[self.code_address].state - code_contract_state.assert_initialized(contract_address=self.code_address) - - # Prepare input for Cairo function runner. - contract_definition = state.contract_definitions[code_contract_state.contract_hash] - contract_definition.validate() - entry_point = self._get_selected_entry_point(contract_definition=contract_definition) - - # Run the specified contract entry point with given calldata. - with wrap_with_stark_exception(code=StarknetErrorCode.SECURITY_ERROR): - runner = CairoFunctionRunner(program=contract_definition.program, layout="all") - os_context = os_utils.prepare_os_context(runner=runner) - - # Extract pre-fetched contract state from carried state. - pre_run_contract_carried_state = state.contract_states[self.contract_address] - contract_state = pre_run_contract_carried_state.state - contract_state.assert_initialized(contract_address=self.contract_address) - - starknet_storage = BusinessLogicStarknetStorage( - commitment_tree=contract_state.storage_commitment_tree, - ffc=state.ffc, - # Note that pending_modifications might be modified during the run as a result of an - # internal call. - pending_modifications=pre_run_contract_carried_state.storage_updates.copy(), - loop=loop, - ) - - initial_syscall_ptr = cast(RelocatableValue, os_context[SYSCALL_PTR_OFFSET]) - syscall_handler = syscall_utils.BusinessLogicSysCallHandler( - tx_execution_context=tx_execution_context, - state=state, - caller_address=caller_address, - contract_address=self.contract_address, - signature=self.signature, - starknet_storage=starknet_storage, - general_config=general_config, - initial_syscall_ptr=initial_syscall_ptr, - ) - - # Positional arguments are passed to *args in the 'run_from_entrypoint' function. - - entry_points_args = [ - self.entry_point_selector, - os_context, - len(self.calldata), - self.calldata, - ] - - try: - runner.run_from_entrypoint( - entry_point.offset, - *entry_points_args, - hint_locals={ - "__storage": starknet_storage, - "syscall_handler": syscall_handler, - }, - static_locals={ - "__find_element_max_size": 2 ** 20, - "__squash_dict_max_size": 2 ** 20, - "__keccak_max_size": 2 ** 20, - }, - run_resources=tx_execution_context.run_resources, - verify_secure=True, - ) - except VmException as exception: - code = StarknetErrorCode.TRANSACTION_FAILED - if isinstance(exception.inner_exc, HintException): - hint_exception = exception.inner_exc - - if isinstance(hint_exception.inner_exc, syscall_utils.HandlerException): - stark_exception = hint_exception.inner_exc.stark_exception - code = stark_exception.code - called_contract_address = hint_exception.inner_exc.called_contract_address - message_prefix = ( - f"Error in the called contract ({hex(called_contract_address)}):\n" - ) - # Override python's traceback and keep the Cairo one of the inner exception. - exception.notes = [message_prefix + str(stark_exception.message)] - - if isinstance(exception.inner_exc, ResourcesError): - code = StarknetErrorCode.OUT_OF_RESOURCES - - raise StarkException(code=code, message=str(exception)) - except VmExceptionBase as exception: - raise StarkException(code=StarknetErrorCode.TRANSACTION_FAILED, message=str(exception)) - except SecurityError as exception: - raise StarkException(code=StarknetErrorCode.SECURITY_ERROR, message=str(exception)) - except Exception: - logger.error("Got an unexpected exception.", exc_info=True) - raise StarkException( - code=StarknetErrorCode.UNEXPECTED_FAILURE, - message="Got an unexpected exception during the execution of the transaction.", - ) - - # Complete handler validations. - os_utils.validate_and_process_os_context( - runner=runner, - syscall_handler=syscall_handler, - initial_os_context=os_context, - ) - - # When execution starts the stack holds entry_points_args + [ret_fp, ret_pc]. - args_ptr = runner.initial_fp - (len(entry_points_args) + 2) - - # The arguments are touched by the OS and should not be counted as holes, mark them - # as accessed. - runner.mark_as_accessed(address=args_ptr, size=len(entry_points_args)) - - return runner, syscall_handler - - def _get_selected_entry_point( - self, contract_definition: ContractDefinition - ) -> ContractEntryPoint: - """ - Returns the entry point with selector corresponding with self.entry_point_selector. - """ - - entry_points = contract_definition.entry_points_by_type[self.entry_point_type] - filtered_entry_points = list( - filter( - lambda ep: ep.selector == self.entry_point_selector, - entry_points, + fee_transfer_info: Optional[CallInfo] = None + actual_fee = 0 + # Fee charging is not enforced yet, one can skip that by setting max_fee=0. + if self.max_fee > 0: + # Should always pass on regular flows (verified in the create() method). + assert self.entry_point_selector == starknet_abi.EXECUTE_ENTRY_POINT_SELECTOR + fee_transfer_info, actual_fee = await charge_fee( + general_config=general_config, + state=state, + account_contract_address=self.contract_address, + execution_resources=call_info.execution_resources.to_dict(), + max_fee=self.max_fee, ) - ) - if len(filtered_entry_points) == 0 and len(entry_points) > 0: - ep0 = entry_points[0] - if ep0.selector == DEFAULT_ENTRY_POINT_SELECTOR: - return ep0 - - selector_formatter = fields.EntryPointSelectorField.format - address_formatter = fields.ContractAddressField.format - # Non-unique entry points are not possible in a ContractDefinition object, thus - # len(filtered_entry_points) <= 1. - stark_assert( - len(filtered_entry_points) == 1, - code=StarknetErrorCode.ENTRY_POINT_NOT_FOUND_IN_CONTRACT, - message=( - f"Entry point {selector_formatter(self.entry_point_selector)} not found in contract" - f" with address {address_formatter(self.contract_address)}." - ), + return TransactionExecutionInfo( + call_info=call_info, fee_transfer_info=fee_transfer_info, actual_fee=actual_fee ) - (entry_point,) = filtered_entry_points - return entry_point - - async def call(self, state: CarriedState, general_config: StarknetGeneralConfig) -> List[int]: + async def execute_tx( + self, state: CarriedState, general_config: StarknetGeneralConfig + ) -> CallInfo: """ - Runs the selected entry point with the given calldata in the contract specified by the - transaction. - Returns the return data. - Note that this function modifies the state. + Builds the transaction execution context and executes the entry point. + Returns the CallInfo. """ - # Pass the running loop before entering to it. It will be used to run asynchronous - # tasks, such as fetching data from storage. - loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() - _run = functools.partial( - self._run, - state=state, - general_config=general_config, - loop=loop, - caller_address=self.caller_address, - tx_execution_context=TransactionExecutionContext.create( - account_contract_address=self.contract_address, - n_steps=general_config.invoke_tx_max_n_steps, - ), - ) - - runner, _ = await loop.run_in_executor(executor=None, func=_run) - return self._get_return_values(runner=runner) - - def _build_call_info( - self, - previous_cairo_usage: ExecutionResources, - syscall_handler: syscall_utils.BusinessLogicSysCallHandler, - ) -> ContractCall: - return ContractCall( - from_address=self.caller_address, - to_address=self.contract_address, - code_address=self.code_address, - entry_point_selector=self.entry_point_selector, - entry_point_type=self.entry_point_type, - calldata=self.calldata, + tx_execution_context = TransactionExecutionContext.create( + account_contract_address=self.contract_address, + transaction_hash=self.hash_value, signature=self.signature, - cairo_usage=syscall_handler.state.cairo_usage - previous_cairo_usage, - events=syscall_handler.events, - l2_to_l1_messages=[], - internal_call_responses=syscall_handler.internal_call_responses, - storage_read_values=syscall_handler.starknet_storage.read_values, - storage_accessed_addresses=syscall_handler.starknet_storage.accessed_addresses, + max_fee=self.max_fee, + n_steps=general_config.invoke_tx_max_n_steps, ) - - def _get_return_values(self, runner: CairoFunctionRunner) -> List[int]: - with wrap_with_stark_exception( - code=StarknetErrorCode.INVALID_RETURN_DATA, - message="Error extracting return data in call().", - logger=logger, - exception_types=[Exception], - ): - ret_data_size, ret_data_ptr = runner.get_return_values(2) - values = runner.memory.get_range(ret_data_ptr, ret_data_size) - - stark_assert( - all(isinstance(value, int) for value in values), - code=StarknetErrorCode.INVALID_RETURN_DATA, - message="Return data expected to be non-relocatable.", + return await self.execute( + state=state, general_config=general_config, tx_execution_context=tx_execution_context ) - return cast(List[int], values) - class InternalTransactionSchema(OneOfSchema): """ diff --git a/src/starkware/starknet/business_logic/state.py b/src/starkware/starknet/business_logic/state.py index a93a9529..bbcceb7c 100644 --- a/src/starkware/starknet/business_logic/state.py +++ b/src/starkware/starknet/business_logic/state.py @@ -425,7 +425,7 @@ async def apply_state_updates( for contract_state in current_carried_state.contract_states.values() ) ) - contract_states = ChainMap( + contract_states: typing.ChainMap[int, ContractCarriedState] = ChainMap( dict(safe_zip(current_carried_state.contract_states.keys(), updated_contract_states)) ) diff --git a/src/starkware/starknet/business_logic/state_objects.py b/src/starkware/starknet/business_logic/state_objects.py index dbbe1b88..9712ef5b 100644 --- a/src/starkware/starknet/business_logic/state_objects.py +++ b/src/starkware/starknet/business_logic/state_objects.py @@ -10,6 +10,7 @@ from starkware.starknet.definitions.error_codes import StarknetErrorCode from starkware.starknet.services.api.contract_definition import ContractDefinition from starkware.starknet.storage.starknet_storage import StorageLeaf +from starkware.starkware_utils.commitment_tree.leaf_fact import LeafFact from starkware.starkware_utils.commitment_tree.patricia_tree.nodes import EmptyNodeFact from starkware.starkware_utils.commitment_tree.patricia_tree.patricia_tree import PatriciaTree from starkware.starkware_utils.error_handling import stark_assert @@ -34,7 +35,7 @@ def _hash(self, hash_func: HashFunctionType) -> bytes: @marshmallow_dataclass.dataclass(frozen=True) -class ContractState(ValidatedMarshmallowDataclass, Fact): +class ContractState(ValidatedMarshmallowDataclass, LeafFact): """ Represents the state of a single contract (sub-commitment tree) in the full StarkNet state commitment tree. @@ -67,21 +68,19 @@ async def empty( @property def is_empty(self) -> bool: - return ( - self.contract_hash == self.UNINITIALIZED_CONTRACT_HASH - and self.storage_commitment_tree.root == EmptyNodeFact.EMPTY_NODE_HASH - ) + return not self.initialized def _hash(self, hash_func: HashFunctionType) -> bytes: """ Computes the hash of the node containing the contract's information, including the contract definition and storage. """ - CONTRACT_STATE_HASH_VERSION = 0 - RESERVED = 0 if self.is_empty: return EmptyNodeFact.EMPTY_NODE_HASH + CONTRACT_STATE_HASH_VERSION = 0 + RESERVED = 0 + # Set hash_value = H(H(contract_hash, storage_root), RESERVED). hash_value = hash_func(self.contract_hash, self.storage_commitment_tree.root) hash_value = hash_func(hash_value, to_bytes(RESERVED)) @@ -115,7 +114,13 @@ async def fetch_contract_definitions( @property def initialized(self) -> bool: - return self.contract_hash != ContractState.UNINITIALIZED_CONTRACT_HASH + uninitialized = self.contract_hash == self.UNINITIALIZED_CONTRACT_HASH + if uninitialized: + assert ( + self.storage_commitment_tree.root == EmptyNodeFact.EMPTY_NODE_HASH + ), "Contract storage commitment root must be empty if contract hash is uninitialized." + + return not uninitialized def assert_initialized(self, contract_address: int): """ diff --git a/src/starkware/starknet/business_logic/transaction_execution_objects.py b/src/starkware/starknet/business_logic/transaction_execution_objects.py index ebdacfa3..04336bad 100644 --- a/src/starkware/starknet/business_logic/transaction_execution_objects.py +++ b/src/starkware/starknet/business_logic/transaction_execution_objects.py @@ -3,9 +3,11 @@ import logging import operator from dataclasses import field -from typing import List, Optional, Set, cast +from typing import Any, Dict, Iterator, List, Optional, Set, cast +import marshmallow.fields as mfields import marshmallow_dataclass +from marshmallow.decorators import pre_load from services.everest.business_logic.internal_transaction import EverestTransactionExecutionInfo from services.everest.definitions import fields as everest_fields @@ -15,7 +17,11 @@ from starkware.starknet.definitions import fields from starkware.starknet.services.api.contract_definition import EntryPointType from starkware.starkware_utils.marshmallow_dataclass_fields import SetField -from starkware.starkware_utils.validated_dataclass import ValidatedDataclass +from starkware.starkware_utils.serializable_dataclass import SerializableMarshmallowDataclass +from starkware.starkware_utils.validated_dataclass import ( + ValidatedDataclass, + ValidatedMarshmallowDataclass, +) from starkware.starkware_utils.validated_fields import sequential_id_metadata logger = logging.getLogger(__name__) @@ -31,21 +37,55 @@ class TransactionExecutionContext(ValidatedDataclass): account_contract_address: int = field( metadata=fields.AddressField.metadata(field_name="account_contract_address") ) + # The hash of the transaction. + transaction_hash: int = field(metadata=fields.transaction_hash_metadata) + # The signature of the transaction. + signature: List[int] = field(metadata=fields.signature_metadata) + max_fee: int = field(metadata=fields.fee_metadata) run_resources: RunResources # Used for tracking global events order. n_emitted_events: int = field(metadata=sequential_id_metadata("Number of emitted events")) + # Used for tracking global L2-to-L1 messages order. + n_sent_messages: int = field(metadata=sequential_id_metadata("Number of messages sent to L1")) @classmethod - def create(cls, account_contract_address: int, n_steps: int) -> "TransactionExecutionContext": + def create( + cls, + account_contract_address: int, + transaction_hash: int, + signature: List[int], + max_fee: int, + n_steps: int, + ) -> "TransactionExecutionContext": return cls( account_contract_address=account_contract_address, + transaction_hash=transaction_hash, + signature=signature, + max_fee=max_fee, run_resources=RunResources(n_steps=n_steps), n_emitted_events=0, + n_sent_messages=0, + ) + + @classmethod + def create_for_call( + cls, account_contract_address: int, n_steps: int + ) -> "TransactionExecutionContext": + """ + Creates a context for transaction execution. To be used when executing an entry point + without a concrete InternalInvokeFunction object. + """ + return cls.create( + account_contract_address=account_contract_address, + n_steps=n_steps, + signature=[], + transaction_hash=0, + max_fee=0, ) @dataclasses.dataclass(frozen=True) -class OrderedEventContent(ValidatedDataclass): +class OrderedEvent(ValidatedDataclass): """ Contains the raw content of an event, without the context its origin (emitting contract, etc.) along with its order in the transaction execution. @@ -68,12 +108,12 @@ class Event(ValidatedDataclass): # Emitting contract address. from_address: int = field(metadata=fields.contract_address_metadata) # The keys by which the event will be indexed. - keys: List[int] = field(metadata=fields.felt_list_metadata) + keys: List[int] = field(metadata=fields.felt_as_hex_list_metadata) # The data of the event. - data: List[int] = field(metadata=fields.felt_list_metadata) + data: List[int] = field(metadata=fields.felt_as_hex_list_metadata) @classmethod - def create(cls, event_content: OrderedEventContent, emitting_contract_address: int): + def create(cls, event_content: OrderedEvent, emitting_contract_address: int): return cls( from_address=emitting_contract_address, keys=event_content.keys, @@ -81,6 +121,18 @@ def create(cls, event_content: OrderedEventContent, emitting_contract_address: i ) +@dataclasses.dataclass(frozen=True) +class OrderedL2ToL1Message(ValidatedDataclass): + """ + A class containing the raw content of a L2-to-L1 message, without the context its origin + (the sending contract, etc.) along with its order in the transaction execution. + """ + + order: int = field(metadata=sequential_id_metadata("L2-to-L1 message order")) + to_address: int = field(metadata=everest_fields.EthAddressIntField.metadata("to_address")) + payload: List[int] = field(metadata=fields.felt_list_metadata) + + @dataclasses.dataclass(frozen=True) class L2ToL1MessageInfo(ValidatedDataclass): """ @@ -91,6 +143,191 @@ class L2ToL1MessageInfo(ValidatedDataclass): to_address: int = field(metadata=everest_fields.EthAddressIntField.metadata("to_address")) payload: List[int] = field(metadata=fields.felt_list_metadata) + @classmethod + def create(cls, message_content: OrderedL2ToL1Message, sending_contract_address: int): + return cls( + from_address=sending_contract_address, + to_address=message_content.to_address, + payload=message_content.payload, + ) + + +# NOTE: This dataclass isn't validated due to a forward-declaration issue. +@marshmallow_dataclass.dataclass(frozen=True) +class CallInfo(SerializableMarshmallowDataclass): + """ + Represents a contract call, either internal or external. + Holds the information needed for the execution of the represented contract call by the OS. + No need for validations here, as the fields are taken from validated objects. + """ + + # Static info. + + caller_address: int # Should be zero if the call represents an external transaction. + contract_address: int + # The address that holds the executed code; relevant just for delegate calls, where it may + # differ from the code of the to_address contract. + code_address: Optional[int] + entry_point_selector: Optional[int] + entry_point_type: Optional[EntryPointType] + calldata: List[int] + # Execution info. + retdata: List[int] + execution_resources: ExecutionResources + # Note that the order starts from a transaction-global offset. + events: List[OrderedEvent] + l2_to_l1_messages: List[OrderedL2ToL1Message] + + # Information kept for the StarkNet OS run in the GpsAmbassador. + + # A list of values read from storage by this call, **excluding** readings from nested calls. + storage_read_values: List[int] + # A set of storage keys accessed by this call, **excluding** keys from nested calls; + # kept in order to calculate and prepare the commitment tree facts before the StarkNet OS run. + accessed_storage_keys: Set[int] = field( + metadata=dict( + marshmallow_field=SetField( + everest_fields.felt_metadata("storage_accessed_address")["marshmallow_field"] + ) + ) + ) + + # Internal calls made by this call. + + internal_calls: List["CallInfo"] = field( + metadata=dict(marshmallow_field=mfields.List(mfields.Nested(lambda: CallInfo.Schema()))) + ) + + def get_state_selector(self) -> StateSelector: + code_address = self.contract_address if self.code_address is None else self.code_address + selector = StateSelector(contract_addresses={self.contract_address, code_address}) + return functools.reduce( + StateSelector.__or__, + (call.get_state_selector() for call in self.internal_calls), + selector, + ) + + def gen_call_topology(self) -> Iterator["CallInfo"]: + """ + Yields the contract calls in DFS (preorder). + """ + yield self + for call in self.internal_calls: + yield from call.gen_call_topology() + + @classmethod + def empty(cls, contract_address: int) -> "CallInfo": + return cls( + caller_address=0, + contract_address=contract_address, + code_address=None, + entry_point_type=None, + entry_point_selector=None, + calldata=[], + retdata=[], + execution_resources=ExecutionResources.empty(), + events=[], + l2_to_l1_messages=[], + storage_read_values=[], + accessed_storage_keys=set(), + internal_calls=[], + ) + + +@marshmallow_dataclass.dataclass(frozen=True) +class TransactionExecutionInfo(EverestTransactionExecutionInfo): + """ + Contains the information gathered by the execution of a transation. Main usages: + 1. Supplies hints for the OS run on the corresponding transaction; e.g., internal call results. + 2. Stores useful information for users; e.g., L2-to-L1 messages and emitted events. + """ + + call_info: CallInfo + # Fee transfer call info, executed by the BE for external InvokeFunction transactions; + # Optional since currently Deploy transactions do not have fee (and backward compatibility). + fee_transfer_info: Optional[CallInfo] + actual_fee: int = field(metadata=fields.FeeField.metadata(field_name="actual_fee")) + + def get_state_selector(self) -> StateSelector: + call_info_selector = self.call_info.get_state_selector() + if self.fee_transfer_info is None: + return call_info_selector + + return call_info_selector | self.fee_transfer_info.get_state_selector() + + def gen_call_iterator(self) -> Iterator[CallInfo]: + """ + Yields the contract calls in the order that they are going to be executed in the OS. + (Preorder of the original call tree followed by the preorder of the call tree that was + generated while charging the fee). + """ + yield from self.call_info.gen_call_topology() + if self.fee_transfer_info is None: + return + + yield from self.fee_transfer_info.gen_call_topology() + + def get_sorted_events(self) -> List[Event]: + """ + Returns a list of StarkNet Event objects collected during the execution, sorted by the order + in which they were emitted. + """ + n_events = sum(len(call.events) for call in self.call_info.gen_call_topology()) + starknet_events: List[Optional[Event]] = [None] * n_events + + for call in self.call_info.gen_call_topology(): + for ordered_event_content in call.events: + # Convert OrderedEvent -> Event. I.e., add emitting contract address + # and remove the order. + starknet_events[ordered_event_content.order] = Event.create( + emitting_contract_address=call.contract_address, + event_content=ordered_event_content, + ) + + assert all( + starknet_event is not None for starknet_event in starknet_events + ), "Unexpected holes in the event order." + + return cast(List[Event], starknet_events) + + def get_sorted_l2_to_l1_messages(self) -> List[L2ToL1MessageInfo]: + """ + Returns a list of StarkNet L2ToL1MessageInfo objects collected during the execution, sorted + by the order in which they were sent. + """ + n_messages = sum(len(call.l2_to_l1_messages) for call in self.call_info.gen_call_topology()) + starknet_l2_to_l1_messages: List[Optional[L2ToL1MessageInfo]] = [None] * n_messages + + for call in self.call_info.gen_call_topology(): + for ordered_message_content in call.l2_to_l1_messages: + # Convert OrderedL2ToL1Message -> L2ToL1MessageInfo. I.e., add sending + # contract address and remove the order. + starknet_l2_to_l1_messages[ + ordered_message_content.order + ] = L2ToL1MessageInfo.create( + sending_contract_address=call.contract_address, + message_content=ordered_message_content, + ) + + assert all( + message is not None for message in starknet_l2_to_l1_messages + ), "Unexpected holes in the L2-to-L1 message order." + + return cast(List[L2ToL1MessageInfo], starknet_l2_to_l1_messages) + + @staticmethod + def get_state_selector_of_many( + execution_infos: List["TransactionExecutionInfo"], + ) -> StateSelector: + return functools.reduce( + operator.__or__, + (execution_info.get_state_selector() for execution_info in execution_infos), + StateSelector.empty(), + ) + + +# Deprecated classes. + @dataclasses.dataclass(frozen=True) class ContractCallResponse(ValidatedDataclass): @@ -101,12 +338,11 @@ class ContractCallResponse(ValidatedDataclass): retdata: List[int] -@dataclasses.dataclass(frozen=True) -class ContractCall(ValidatedDataclass): +@marshmallow_dataclass.dataclass(frozen=True) +class ContractCall(ValidatedMarshmallowDataclass): """ Represents a contract call, either internal or external. Holds the information needed for the execution of the represented contract call by the OS. - The addresses are of L2 contracts. No need for validations here, as the fields are taken from validated objects. """ @@ -116,7 +352,7 @@ class ContractCall(ValidatedDataclass): to_address: int # The called contract address. # The address that holds the executed code; relevant just for delegate calls, where it may # differ from the code of the to_address contract. - code_address: Optional[int] = field(metadata=fields.optional_contract_address_metadata) + code_address: Optional[int] = field(metadata=fields.optional_code_address_metadata) entry_point_selector: Optional[int] = field(metadata=dict(load_default=None, required=False)) entry_point_type: Optional[EntryPointType] = field( metadata=dict(load_default=None, required=False) @@ -128,8 +364,8 @@ class ContractCall(ValidatedDataclass): cairo_usage: ExecutionResources # Note that the order starts from a transaction-global offset. - events: List[OrderedEventContent] = field(metadata=dict(load_default=list, required=False)) - l2_to_l1_messages: List[L2ToL1MessageInfo] = field( + events: List[OrderedEvent] = field(metadata=dict(load_default=list, required=False)) + l2_to_l1_messages: List[OrderedL2ToL1Message] = field( metadata=dict(load_default=list, required=False) ) @@ -175,44 +411,53 @@ def state_selector(self) -> StateSelector: @marshmallow_dataclass.dataclass(frozen=True) -class TransactionExecutionInfo(EverestTransactionExecutionInfo): +class TransactionExecutionInfoDeprecated(EverestTransactionExecutionInfo): """ Contains the information gathered by the execution of a transation. Main uses: 1. Supplies hints for the OS run on the corresponding transaction; e.g., internal call results. 2. Stores useful information for users; e.g., L2-to-L1 messages it sent and emitted events. """ - l2_to_l1_messages: List[L2ToL1MessageInfo] + call_info: ContractCall # The retdata of the main transaction. retdata: List[int] - call_info: ContractCall # The internal contract calls; arranged in DFS order, which is the order they are invoked by the # OS. internal_calls: List[ContractCall] + actual_fee: int = field(metadata=fields.fee_metadata) @classmethod def create( - cls, call_info: ContractCall, internal_calls: Optional[List[ContractCall]] = None - ) -> "TransactionExecutionInfo": + cls, + call_info: ContractCall, + internal_calls: Optional[List[ContractCall]] = None, + actual_fee: int = 0, + ) -> "TransactionExecutionInfoDeprecated": return cls( - l2_to_l1_messages=[], retdata=[], call_info=call_info, internal_calls=[] if internal_calls is None else internal_calls, + actual_fee=actual_fee, ) @property def contract_calls(self) -> List[ContractCall]: return [self.call_info, *self.internal_calls] - @property - def state_selector(self) -> StateSelector: + def get_state_selector(self) -> StateSelector: return functools.reduce( operator.__or__, (contract_call.state_selector for contract_call in self.contract_calls), StateSelector.empty(), ) + @pre_load + def remove_l2_to_l1_messages( + self, data: Dict[str, Any], many: bool, **kwargs + ) -> Dict[str, Any]: + data.pop("l2_to_l1_messages", None) + return data + def get_sorted_events(self) -> List[Event]: """ Returns a list of StarkNet Event objects collected during the execution, sorted by the order @@ -223,8 +468,8 @@ def get_sorted_events(self) -> List[Event]: for contract_call in self.contract_calls: for ordered_event_content in contract_call.events: - # Convert OrderedEventContent -> Event. I.e., add emitting contract address - # and remove order. + # Convert OrderedEvent -> Event. I.e., add emitting contract address + # and remove the order. starknet_events[ordered_event_content.order] = Event.create( emitting_contract_address=contract_call.to_address, event_content=ordered_event_content, @@ -233,14 +478,42 @@ def get_sorted_events(self) -> List[Event]: assert all( starknet_event is not None for starknet_event in starknet_events ), "Unexpected holes in the event order." + return cast(List[Event], starknet_events) + def get_sorted_l2_to_l1_messages(self) -> List[L2ToL1MessageInfo]: + """ + Returns a list of StarkNet L2ToL1MessageInfo objects collected during the execution, sorted + by the order in which they were sent. + """ + n_messages = sum( + len(contract_call.l2_to_l1_messages) for contract_call in self.contract_calls + ) + starknet_l2_to_l1_messages: List[Optional[L2ToL1MessageInfo]] = [None] * n_messages + + for contract_call in self.contract_calls: + for ordered_message_content in contract_call.l2_to_l1_messages: + # Convert OrderedL2ToL1Message -> L2ToL1MessageInfo. I.e., add sending + # contract address and remove the order. + starknet_l2_to_l1_messages[ + ordered_message_content.order + ] = L2ToL1MessageInfo.create( + sending_contract_address=contract_call.to_address, + message_content=ordered_message_content, + ) + + assert all( + message is not None for message in starknet_l2_to_l1_messages + ), "Unexpected holes in the L2-to-L1 message order." + + return cast(List[L2ToL1MessageInfo], starknet_l2_to_l1_messages) + @staticmethod def get_state_selector_of_many( - execution_infos: List["TransactionExecutionInfo"], + execution_infos: List["TransactionExecutionInfoDeprecated"], ) -> StateSelector: return functools.reduce( operator.__or__, - (execution_info.state_selector for execution_info in execution_infos), + (execution_info.get_state_selector() for execution_info in execution_infos), StateSelector.empty(), ) diff --git a/src/starkware/starknet/business_logic/transaction_fee.py b/src/starkware/starknet/business_logic/transaction_fee.py new file mode 100644 index 00000000..fd2210fc --- /dev/null +++ b/src/starkware/starknet/business_logic/transaction_fee.py @@ -0,0 +1,88 @@ +import math +from typing import Dict, Tuple + +from starkware.starknet.business_logic.execute_entry_point import ExecuteEntryPoint +from starkware.starknet.business_logic.state import CarriedState +from starkware.starknet.business_logic.transaction_execution_objects import ( + CallInfo, + TransactionExecutionContext, +) +from starkware.starknet.definitions.error_codes import StarknetErrorCode +from starkware.starknet.definitions.general_config import StarknetGeneralConfig +from starkware.starknet.public import abi as starknet_abi +from starkware.starknet.services.api.contract_definition import EntryPointType +from starkware.starkware_utils.error_handling import StarkException, stark_assert_le + + +async def charge_fee( + general_config: StarknetGeneralConfig, + state: CarriedState, + account_contract_address: int, + execution_resources: Dict[str, int], + max_fee: int, +) -> Tuple[CallInfo, int]: + """ + Calculates the actual fee from the given execution resources and transfers the amount from the + caller account to the sequencer. Returns the resulting CallInfo of the transfer call and the + actual fee. + """ + actual_fee = calculate_tx_fee_by_cairo_usage( + general_config=general_config, + cairo_resource_usage=execution_resources, + l1_gas_usage=0, + ) + + stark_assert_le( + actual_fee, + max_fee, + code=StarknetErrorCode.FEE_TRANSFER_FAILURE, + message="Actual fee exceeded max fee.", + ) + + tx_execution_context = TransactionExecutionContext.create_for_call( + account_contract_address=account_contract_address, + n_steps=general_config.invoke_tx_max_n_steps, + ) + + fee_token_address = general_config.fee_token_address + fee_transfer_call = ExecuteEntryPoint( + caller_address=account_contract_address, # The account contract address. + contract_address=fee_token_address, + code_address=fee_token_address, + entry_point_selector=starknet_abi.TRANSFER_ENTRY_POINT_SELECTOR, + entry_point_type=EntryPointType.EXTERNAL, + calldata=[general_config.sequencer_address, actual_fee, 0], # Recipient, amount (128-bit). + ) + try: + fee_transfer_info = await fee_transfer_call.execute( + state=state, general_config=general_config, tx_execution_context=tx_execution_context + ) + except StarkException as exception: + raise StarkException(code=StarknetErrorCode.FEE_TRANSFER_FAILURE, message=str(exception)) + + return fee_transfer_info, actual_fee + + +def calculate_tx_fee_by_cairo_usage( + general_config: StarknetGeneralConfig, cairo_resource_usage: Dict[str, int], l1_gas_usage: int +) -> int: + """ + Calculates the transaction fee by considering the heaviest Cairo resource (in terms of L1 gas), + as the size of a proof is determined similarly - by the (normalized) largest segment. + We add to that the given l1_gas_usage (which may include, for example, the direct cost of + L2-to-L1 messages) and multiply by the L1 gas price. + """ + cairo_resource_fee_weights = general_config.cairo_resource_fee_weights + cairo_resource_names = set(cairo_resource_usage.keys()) + assert cairo_resource_names.issubset( + cairo_resource_fee_weights.keys() + ), "Cairo resource names must be contained in fee weights dict." + + # Convert Cairo usage to L1 gas usage. + cairo_l1_gas_usage = max( + cairo_resource_fee_weights[key] * cairo_resource_usage.get(key, 0) + for key in cairo_resource_fee_weights + ) + + total_l1_gas_usage = cairo_l1_gas_usage + l1_gas_usage + return math.ceil(total_l1_gas_usage * general_config.gas_price) diff --git a/src/starkware/starknet/business_logic/utils.py b/src/starkware/starknet/business_logic/utils.py new file mode 100644 index 00000000..c1a6332d --- /dev/null +++ b/src/starkware/starknet/business_logic/utils.py @@ -0,0 +1,78 @@ +import logging +from typing import List, Optional, Tuple, cast + +from starkware.cairo.common.cairo_function_runner import CairoFunctionRunner +from starkware.starknet.core.os.transaction_hash import TransactionHashPrefix +from starkware.starknet.definitions import constants +from starkware.starknet.definitions.error_codes import StarknetErrorCode +from starkware.starknet.public import abi as starknet_abi +from starkware.starknet.services.api.contract_definition import EntryPointType +from starkware.starkware_utils.error_handling import stark_assert, wrap_with_stark_exception + +logger = logging.getLogger(__name__) + + +def get_return_values(runner: CairoFunctionRunner) -> List[int]: + """ + Extracts the return values of a StarkNet contract function from the Cairo runner. + """ + with wrap_with_stark_exception( + code=StarknetErrorCode.INVALID_RETURN_DATA, + message="Error extracting return data.", + logger=logger, + exception_types=[Exception], + ): + ret_data_size, ret_data_ptr = runner.get_return_values(2) + values = runner.memory.get_range(ret_data_ptr, ret_data_size) + + stark_assert( + all(isinstance(value, int) for value in values), + code=StarknetErrorCode.INVALID_RETURN_DATA, + message="Return data expected to be non-relocatable.", + ) + + return cast(List[int], values) + + +def preprocess_invoke_function_fields( + entry_point_type: EntryPointType, + entry_point_selector: int, + message_from_l1_nonce: Optional[int], + max_fee: int, + version: int, +) -> Tuple[TransactionHashPrefix, List[int]]: + """ + Performs validation on fields related to function invocation transaction. + Deduces and returns entry point type-related fields required for hash calculation of + InvokeFunction transaction. + """ + # Validate version. + assert ( + version == constants.TRANSACTION_VERSION + ), f"Transaction version {version} is not supported." + + # Validate entry point type-related fields. + if entry_point_type is EntryPointType.EXTERNAL: + assert message_from_l1_nonce is None, "An InvokeFunction transaction cannot have a nonce." + if max_fee != 0: + stark_assert( + entry_point_selector == starknet_abi.EXECUTE_ENTRY_POINT_SELECTOR, + code=StarknetErrorCode.UNSUPPORTED_SELECTOR_FOR_FEE, + message=( + "Transactions with positive fee should go through the " + f"{starknet_abi.EXECUTE_ENTRY_POINT_NAME} entrypoint." + ), + ) + + tx_hash_prefix = TransactionHashPrefix.INVOKE + additional_data = [] + elif entry_point_type is EntryPointType.L1_HANDLER: + assert message_from_l1_nonce is not None, "An L1 handler transaction must have a nonce." + assert max_fee == 0, "An L1 handler transaction must have max_fee=0." + + tx_hash_prefix = TransactionHashPrefix.L1_HANDLER + additional_data = [message_from_l1_nonce] + else: + raise NotImplementedError(f"Entry point type {entry_point_type.name} is not supported.") + + return tx_hash_prefix, additional_data diff --git a/src/starkware/starknet/cli/CMakeLists.txt b/src/starkware/starknet/cli/CMakeLists.txt index 6d38fe5a..533e01d3 100644 --- a/src/starkware/starknet/cli/CMakeLists.txt +++ b/src/starkware/starknet/cli/CMakeLists.txt @@ -20,9 +20,12 @@ python_lib(starknet_cli_lib starknet_definitions_lib starknet_feeder_gateway_client_lib starknet_gateway_client_lib + starknet_general_config_lib starknet_transaction_lib starknet_wallets_lib starkware_error_handling_lib + starkware_python_utils_lib + pip_web3 ) python_venv(starknet_cli_venv diff --git a/src/starkware/starknet/cli/starknet_cli.py b/src/starkware/starknet/cli/starknet_cli.py index 3fe7e5a7..057bb095 100755 --- a/src/starkware/starknet/cli/starknet_cli.py +++ b/src/starkware/starknet/cli/starknet_cli.py @@ -2,12 +2,16 @@ import argparse import asyncio +import dataclasses import functools import json +import math import os import sys from typing import Any, Dict, List, Optional +from web3 import Web3 + from services.everest.definitions import fields as everest_fields from services.external_api.base_client import RetryConfig from starkware.cairo.lang.compiler.ast.cairo_types import TypeFelt, TypePointer @@ -19,8 +23,10 @@ from starkware.cairo.lang.tracer.tracer_data import field_element_repr from starkware.cairo.lang.version import __version__ from starkware.cairo.lang.vm.crypto import get_crypto_lib_context_manager +from starkware.python.utils import from_bytes from starkware.starknet.cli.reconstruct_starknet_traceback import reconstruct_starknet_traceback -from starkware.starknet.definitions import fields +from starkware.starknet.definitions import constants, fields +from starkware.starknet.definitions.general_config import StarknetChainId from starkware.starknet.public.abi import get_selector_from_name from starkware.starknet.public.abi_structs import identifier_manager_from_abi from starkware.starknet.services.api.contract_definition import ContractDefinition @@ -37,6 +43,24 @@ "alpha-mainnet": "alpha-mainnet.starknet.io", } +CHAIN_IDS = { + "alpha-goerli": StarknetChainId.TESTNET.value, + "alpha-mainnet": StarknetChainId.MAINNET.value, +} + +FEE_MARGIN_OF_ESTIMATION = 1.1 + + +@dataclasses.dataclass +class InvokeFunctionArgs: + address: int + selector: int + calldata: List[int] + signature: List[int] + + +# Utilities. + def felt_formatter(hex_felt: str) -> str: return field_element_repr(val=int(hex_felt, 16), prime=everest_fields.FeltField.upper_bound) @@ -67,6 +91,18 @@ def get_arg_value(args, arg_name: str, environment_var: str) -> str: return value +def get_chain_id(args) -> int: + chain_id = get_arg_value(args=args, arg_name="chain_id", environment_var="STARKNET_CHAIN_ID") + + if chain_id.startswith("0x"): + chain_id_int = int(chain_id, 16) + else: + chain_id_int = from_bytes(chain_id.encode()) + + assert chain_id_int in CHAIN_IDS.values(), f"Unsupported chain ID: {chain_id}." + return chain_id_int + + def get_network_id(args) -> str: """ Returns a textual identifier of the network. Used for account management. @@ -225,6 +261,132 @@ async def load_account( return await account_class.create(starknet_context=starknet_context, account_name=account_name) +def handle_network_param(args): + """ + Gives default values to the gateways if the network parameter is set. + """ + network = get_network(args) + if network is not None: + if network not in NETWORKS: + networks_str = ", ".join(NETWORKS.keys()) + print( + f"Unknown network '{network}'. Supported networks: {networks_str}.", + file=sys.stderr, + ) + return 1 + + dns = NETWORKS[network] + if args.gateway_url is None: + args.gateway_url = f"https://{dns}/gateway" + + if args.feeder_gateway_url is None: + args.feeder_gateway_url = f"https://{dns}/feeder_gateway" + + if args.network_id is None: + args.network_id = network + + if args.chain_id is None: + args.chain_id = hex(CHAIN_IDS[network]) + + return 0 + + +def parse_invoke_tx_args(args: argparse.Namespace) -> InvokeFunctionArgs: + """ + Parses the arguments and validates that the function name is in the abi. + """ + inputs = cast_to_felts(values=args.inputs) + + abi = json.load(args.abi) + for abi_entry in abi: + if abi_entry["type"] == "function" and abi_entry["name"] == args.function: + validate_arguments( + inputs=inputs, + abi_entry=abi_entry, + identifier_manager=identifier_manager_from_abi(abi=abi), + ) + break + else: + raise Exception(f"Function {args.function} not found.") + + return InvokeFunctionArgs( + signature=cast_to_felts(values=args.signature), + address=parse_address(args.address), + selector=get_selector_from_name(args.function), + calldata=inputs, + ) + + +async def create_invoke_tx( + args: argparse.Namespace, invoke_tx_args: InvokeFunctionArgs, max_fee: int, has_wallet: bool +) -> InvokeFunction: + """ + Creates and returns an InvokeFunction transaction with the given parameters. + If a wallet provider was provided in args, that transaction will be wrapped and signed. + """ + version = constants.TRANSACTION_VERSION + + if has_wallet: + account = await load_account_from_args(args=args) + assert invoke_tx_args.signature == [], ( + "Signature cannot be passed explicitly when using an account contract. " + "Consider making a direct contract call using --no_wallet." + ) + wrapped_method = await account.sign_invoke_transaction( + contract_address=invoke_tx_args.address, + selector=invoke_tx_args.selector, + calldata=invoke_tx_args.calldata, + chain_id=get_chain_id(args), + max_fee=max_fee, + nonce=args.nonce, + ) + tx = InvokeFunction( + contract_address=wrapped_method.address, + entry_point_selector=wrapped_method.selector, + calldata=wrapped_method.calldata, + max_fee=wrapped_method.max_fee, + version=version, + signature=wrapped_method.signature, + ) + else: + assert args.nonce is None, "--nonce cannot be used in direct calls." + tx = InvokeFunction( + contract_address=invoke_tx_args.address, + entry_point_selector=invoke_tx_args.selector, + calldata=invoke_tx_args.calldata, + max_fee=max_fee, + version=version, + signature=invoke_tx_args.signature, + ) + return tx + + +async def estimate_fee_inner( + args: argparse.Namespace, + invoke_tx_args: InvokeFunctionArgs, + has_wallet: bool, + has_block_info: bool, +) -> Dict[str, Any]: + """ + Estimates the fee of a transaction with the given parameters. + Returns a response of the form: + {"amount": , "unit": "wei"} + """ + tx = await create_invoke_tx( + args=args, invoke_tx_args=invoke_tx_args, max_fee=0, has_wallet=has_wallet + ) + feeder_client = get_feeder_gateway_client(args=args) + block_hash = args.block_hash if has_block_info else None + block_number = args.block_number if has_block_info else None + response = await feeder_client.estimate_fee( + invoke_tx=tx, block_hash=block_hash, block_number=block_number + ) + return response + + +# Subparsers. + + async def deploy(args, command_args): parser = argparse.ArgumentParser(description="Sends a deploy transaction to StarkNet.") parser.add_argument( @@ -310,85 +472,35 @@ async def deploy_account(args, command_args): await account.deploy() -async def invoke_or_call(args, command_args, call: bool): +async def invoke_or_call(args: argparse.Namespace, command_args: List[str], call: bool): parser = argparse.ArgumentParser(description="Sends an invoke transaction to StarkNet.") + add_invoke_tx_arguments(parser=parser, call=call) parser.add_argument( - "--address", type=str, required=True, help="The address of the invoked contract." - ) - parser.add_argument( - "--abi", type=argparse.FileType("r"), required=True, help="The Cairo contract ABI." - ) - parser.add_argument( - "--function", type=str, required=True, help="The name of the invoked function." + "--max_fee", type=int, help="The maximal fee to be paid for the function invocation." ) - parser.add_argument( - "--inputs", type=str, nargs="*", default=[], help="The inputs to the invoked function." - ) - parser.add_argument( - "--nonce", - type=int, - help=( - "Allows to explicitly specify the transaction nonce. " - "If not specified, the current nonce of the account contract " - "(as returned from StarkNet) will be used." - ), - ) - parser.add_argument( - "--signature", - type=str, - nargs="*", - default=[], - help="The signature information for the invoked function.", - ) - if call: - add_block_identifier_argument( - parser=parser, block_role_description="be used as the context for the call operation" - ) parser.parse_args(command_args, namespace=args) - inputs = cast_to_felts(values=args.inputs) - signature = cast_to_felts(values=args.signature) - - abi = json.load(args.abi) + invoke_tx_args = parse_invoke_tx_args(args=args) + address = invoke_tx_args.address - address = parse_address(args.address) - for abi_entry in abi: - if abi_entry["type"] == "function" and abi_entry["name"] == args.function: - validate_arguments( - inputs=inputs, - abi_entry=abi_entry, - identifier_manager=identifier_manager_from_abi(abi=abi), + has_wallet = get_wallet_provider(args=args) is not None + max_fee = args.max_fee + if max_fee is None: + if has_wallet: + fee_info = await estimate_fee_inner( + args=args, invoke_tx_args=invoke_tx_args, has_wallet=has_wallet, has_block_info=call ) - break - else: - raise Exception(f"Function {args.function} not found.") - selector = get_selector_from_name(args.function) - calldata = inputs + max_fee = math.ceil(fee_info["amount"] * FEE_MARGIN_OF_ESTIMATION) + max_fee_eth = float(Web3.fromWei(max_fee, "ether")) - if get_wallet_provider(args) is not None: - account = await load_account_from_args(args) - assert signature == [], ( - "Signature cannot be passed explicitly when using an account contract. " - "Consider making a direct contract call using --no_wallet." - ) - wrapped_method = await account.sign_invoke_transaction( - contract_address=address, selector=selector, calldata=calldata, nonce=args.nonce - ) - tx = InvokeFunction( - contract_address=wrapped_method.address, - entry_point_selector=wrapped_method.selector, - calldata=wrapped_method.calldata, - signature=wrapped_method.signature, - ) - else: - assert args.nonce is None, "--nonce cannot be used in direct calls." - tx = InvokeFunction( - contract_address=address, - entry_point_selector=selector, - calldata=calldata, - signature=signature, - ) + print(f"Sending the transaction with max_fee: {max_fee_eth:.6f} ETH.") + else: + max_fee = 0 + + tx = await create_invoke_tx( + args=args, invoke_tx_args=invoke_tx_args, max_fee=max_fee, has_wallet=has_wallet + ) gateway_response: dict if call: @@ -412,6 +524,23 @@ async def invoke_or_call(args, command_args, call: bool): ) +async def estimate_fee(args: argparse.Namespace, command_args: List[str]): + parser = argparse.ArgumentParser(description="Estimates the fee of a transaction.") + add_invoke_tx_arguments(parser=parser, call=True) + + parser.parse_args(command_args, namespace=args) + invoke_tx_args = parse_invoke_tx_args(args=args) + has_wallet = get_wallet_provider(args=args) is not None + + fee_info = await estimate_fee_inner( + args=args, invoke_tx_args=invoke_tx_args, has_wallet=has_wallet, has_block_info=True + ) + + fee_wei = fee_info["amount"] + fee_eth = float(Web3.fromWei(fee_wei, "ether")) + print(f"The estimated fee is: {fee_wei} WEI ({fee_eth:.6f} ETH).") + + async def tx_status(args, command_args): parser = argparse.ArgumentParser( description="Queries the status of a transaction given its ID." @@ -484,8 +613,8 @@ async def get_transaction(args, command_args): print(tx_info.dumps(indent=4, sort_keys=True)) -async def get_transaction_receipt(args, command_args): - parser = argparse.ArgumentParser(description="Outputs the transaction receipt given its ID.") +async def get_transaction_trace(args, command_args): + parser = argparse.ArgumentParser(description="Outputs the transaction trace given its ID.") parser.add_argument( "--hash", type=str, required=True, help="The hash of the transaction to query." ) @@ -493,35 +622,21 @@ async def get_transaction_receipt(args, command_args): feeder_gateway_client = get_feeder_gateway_client(args) - tx_receipt = await feeder_gateway_client.get_transaction_receipt(tx_hash=args.hash) - print(tx_receipt.dumps(indent=4, sort_keys=True)) + tx_trace = await feeder_gateway_client.get_transaction_trace(tx_hash=args.hash) + print(tx_trace.dumps(indent=4, sort_keys=True)) -def handle_network_param(args): - """ - Gives default values to the gateways if the network parameter is set. - """ - network = get_network(args) - if network is not None: - if network not in NETWORKS: - networks_str = ", ".join(NETWORKS.keys()) - print( - f"Unknown network '{network}'. Supported networks: {networks_str}.", - file=sys.stderr, - ) - return 1 - - dns = NETWORKS[network] - if args.gateway_url is None: - args.gateway_url = f"https://{dns}/gateway" - - if args.feeder_gateway_url is None: - args.feeder_gateway_url = f"https://{dns}/feeder_gateway" +async def get_transaction_receipt(args, command_args): + parser = argparse.ArgumentParser(description="Outputs the transaction receipt given its ID.") + parser.add_argument( + "--hash", type=str, required=True, help="The hash of the transaction to query." + ) + parser.parse_args(command_args, namespace=args) - if args.network_id is None: - args.network_id = network + feeder_gateway_client = get_feeder_gateway_client(args) - return 0 + tx_receipt = await feeder_gateway_client.get_transaction_receipt(tx_hash=args.hash) + print(tx_receipt.dumps(indent=4, sort_keys=True)) async def get_block(args, command_args): @@ -531,7 +646,7 @@ async def get_block(args, command_args): "In case no ID is given, outputs the latest block." ) ) - add_block_identifier_argument( + add_block_identifier_arguments( parser=parser, block_role_description="display", with_block_prefix=False ) @@ -545,7 +660,7 @@ async def get_block(args, command_args): async def get_state_update(args, command_args): parser = argparse.ArgumentParser(description=("Outputs the state update of a given block")) - add_block_identifier_argument( + add_block_identifier_arguments( parser=parser, block_role_description="display", with_block_prefix=True ) @@ -568,7 +683,7 @@ async def get_code(args, command_args): parser.add_argument( "--contract_address", type=str, help="The address of the contract.", required=True ) - add_block_identifier_argument(parser=parser, block_role_description="extract information from") + add_block_identifier_arguments(parser=parser, block_role_description="extract information from") parser.parse_args(command_args, namespace=args) @@ -592,7 +707,7 @@ async def get_full_contract(args, command_args): parser.add_argument( "--contract_address", type=str, help="The address of the contract.", required=True ) - add_block_identifier_argument(parser=parser, block_role_description="extract information from") + add_block_identifier_arguments(parser=parser, block_role_description="extract information from") parser.parse_args(command_args, namespace=args) @@ -627,7 +742,7 @@ async def get_storage_at(args, command_args): parser.add_argument( "--key", type=int, help="The position in the contract's storage.", required=True ) - add_block_identifier_argument(parser=parser, block_role_description="extract information from") + add_block_identifier_arguments(parser=parser, block_role_description="extract information from") parser.parse_args(command_args, namespace=args) @@ -643,7 +758,49 @@ async def get_storage_at(args, command_args): ) -def add_block_identifier_argument( +# Add arguments. + + +def add_invoke_tx_arguments(parser: argparse.ArgumentParser, call: bool): + """ + Adds the arguments: address, abi, function, inputs, nonce, signature. + """ + parser.add_argument( + "--address", type=str, required=True, help="The address of the invoked contract." + ) + parser.add_argument( + "--abi", type=argparse.FileType("r"), required=True, help="The Cairo contract ABI." + ) + parser.add_argument( + "--function", type=str, required=True, help="The name of the invoked function." + ) + parser.add_argument( + "--inputs", type=str, nargs="*", default=[], help="The inputs to the invoked function." + ) + parser.add_argument( + "--nonce", + type=int, + help=( + "Allows to explicitly specify the transaction nonce. " + "If not specified, the current nonce of the account contract " + "(as returned from StarkNet) will be used." + ), + ) + parser.add_argument( + "--signature", + type=str, + nargs="*", + default=[], + help="The signature information for the invoked function.", + ) + + if call: + add_block_identifier_arguments( + parser=parser, block_role_description="be used as the context for the call operation" + ) + + +def add_block_identifier_arguments( parser: argparse.ArgumentParser, block_role_description: str, with_block_prefix: bool = True ): identifier_prefix = "block_" if with_block_prefix else "" @@ -670,6 +827,7 @@ async def main(): "call": functools.partial(invoke_or_call, call=True), "deploy": deploy, "deploy_account": deploy_account, + "estimate_fee": estimate_fee, "get_block": get_block, "get_state_update": get_state_update, "get_code": get_code, @@ -678,6 +836,7 @@ async def main(): "get_storage_at": get_storage_at, "get_transaction": get_transaction, "get_transaction_receipt": get_transaction_receipt, + "get_transaction_trace": get_transaction_trace, "invoke": functools.partial(invoke_or_call, call=False), "tx_status": tx_status, } @@ -689,6 +848,11 @@ async def main(): type=str, help="A textual identifier of the network. Used for account management.", ) + parser.add_argument( + "--chain_id", + type=str, + help="The chain id (either as a hex number or as a string).", + ) parser.add_argument( "--wallet", type=str, diff --git a/src/starkware/starknet/common/CMakeLists.txt b/src/starkware/starknet/common/CMakeLists.txt index 55fecbe4..5ee89f08 100644 --- a/src/starkware/starknet/common/CMakeLists.txt +++ b/src/starkware/starknet/common/CMakeLists.txt @@ -5,6 +5,7 @@ python_lib(starknet_common_lib messages.cairo storage.cairo syscalls.cairo + eth_utils.cairo LIBS cairo_common_lib @@ -17,10 +18,12 @@ full_python_test(starknet_common_lib_test FILES storage_test.py + eth_utils_test.py LIBS cairo_function_runner_lib starknet_abi_lib starknet_common_lib + starkware_python_test_utils_lib pip_pytest ) diff --git a/src/starkware/starknet/common/eth_utils.cairo b/src/starkware/starknet/common/eth_utils.cairo new file mode 100644 index 00000000..fe08ca25 --- /dev/null +++ b/src/starkware/starknet/common/eth_utils.cairo @@ -0,0 +1,14 @@ +from starkware.cairo.common.math import assert_lt_felt, assert_not_zero + +const ETH_ADDRESS_BOUND = 2 ** 160 + +func assert_valid_eth_address{range_check_ptr}(address : felt): + with_attr error_message("Invalid Ethereum address - value is more than 160 bits"): + assert_lt_felt(address, ETH_ADDRESS_BOUND) + end + + with_attr error_message("Invalid Ethereum address - value is zero"): + assert_not_zero(address) + end + return () +end diff --git a/src/starkware/starknet/common/eth_utils_test.py b/src/starkware/starknet/common/eth_utils_test.py new file mode 100644 index 00000000..b56e1f44 --- /dev/null +++ b/src/starkware/starknet/common/eth_utils_test.py @@ -0,0 +1,43 @@ +import os + +import pytest + +from starkware.cairo.common.cairo_function_runner import CairoFunctionRunner +from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME +from starkware.cairo.lang.compiler.cairo_compile import compile_cairo_files +from starkware.cairo.lang.compiler.program import Program +from starkware.cairo.lang.vm.vm_exceptions import VmException +from starkware.python.test_utils import maybe_raises + +CAIRO_FILE = os.path.join(os.path.dirname(__file__), "eth_utils.cairo") + + +@pytest.fixture +def program() -> Program: + return compile_cairo_files([CAIRO_FILE], prime=DEFAULT_PRIME) + + +@pytest.fixture +def runner(program: Program) -> CairoFunctionRunner: + return CairoFunctionRunner(program) + + +@pytest.mark.parametrize( + "address,error_message", + [ + (0, "Invalid Ethereum address - value is zero"), + (1, None), + (2 ** 160 - 1, None), + (2 ** 160, "Invalid Ethereum address - value is more than 160 bits"), + (DEFAULT_PRIME - 1, "Invalid Ethereum address - value is more than 160 bits"), + ], +) +def test_assert_valid_eth_address(runner: CairoFunctionRunner, address, error_message): + with maybe_raises(expected_exception=VmException, error_message=error_message): + runner.run( + "assert_valid_eth_address", + range_check_ptr=runner.range_check_builtin.base, + address=address, + ) + (range_check_ptr_end,) = runner.get_return_values(1) + assert range_check_ptr_end.segment_index == runner.range_check_builtin.base.segment_index diff --git a/src/starkware/starknet/common/syscalls.cairo b/src/starkware/starknet/common/syscalls.cairo index c1d7a1fe..5bb8d63b 100644 --- a/src/starkware/starknet/common/syscalls.cairo +++ b/src/starkware/starknet/common/syscalls.cairo @@ -322,6 +322,13 @@ struct TxInfo: # The signature of the transaction. member signature_len : felt member signature : felt* + + # The hash of the transaction. + member transaction_hash : felt + + # The identifier of the chain. + # This field can be used to prevent replay of testnet transactions on mainnet. + member chain_id : felt end const GET_TX_INFO_SELECTOR = 'GetTxInfo' diff --git a/src/starkware/starknet/compiler/compile.py b/src/starkware/starknet/compiler/compile.py index f927103f..618379d9 100644 --- a/src/starkware/starknet/compiler/compile.py +++ b/src/starkware/starknet/compiler/compile.py @@ -17,6 +17,7 @@ from starkware.cairo.lang.compiler.identifier_manager import IdentifierScope, MissingIdentifierError from starkware.cairo.lang.compiler.module_reader import ModuleReader from starkware.cairo.lang.compiler.preprocessor.pass_manager import PassManager +from starkware.cairo.lang.compiler.preprocessor.preprocessor import PreprocessedProgram from starkware.cairo.lang.compiler.program import Program from starkware.cairo.lang.compiler.scoped_name import ScopedName from starkware.starknet.compiler.external_wrapper import ( @@ -28,7 +29,8 @@ ) from starkware.starknet.compiler.starknet_pass_manager import starknet_pass_manager from starkware.starknet.compiler.starknet_preprocessor import StarknetPreprocessedProgram -from starkware.starknet.public.abi import get_selector_from_name +from starkware.starknet.compiler.validation_utils import verify_account_contract +from starkware.starknet.public.abi import AbiType, get_selector_from_name from starkware.starknet.services.api.contract_definition import ( ContractDefinition, ContractEntryPoint, @@ -94,6 +96,11 @@ def get_entry_points_by_decorators( ) +def get_abi(preprocessed: PreprocessedProgram) -> AbiType: + assert isinstance(preprocessed, StarknetPreprocessedProgram) + return preprocessed.abi + + def compile_starknet_files( files, debug_info: bool = False, @@ -131,11 +138,10 @@ def compile_starknet_codes( # Dump and load program, so that it is converted to the canonical form. program = Program.load(data=program.dump()) - assert isinstance(preprocessed, StarknetPreprocessedProgram) return ContractDefinition( program=program, entry_points_by_type=get_entry_points_by_type(program=program), - abi=preprocessed.abi, + abi=get_abi(preprocessed=preprocessed), ) @@ -145,7 +151,7 @@ def assemble_starknet_contract( add_debug_info: bool, file_contents_for_debug_info: Dict[str, str], ) -> ContractDefinition: - assert isinstance(preprocessed_program, StarknetPreprocessedProgram) + abi = get_abi(preprocessed=preprocessed_program) program = assemble( preprocessed_program, main_scope=main_scope, @@ -156,7 +162,7 @@ def assemble_starknet_contract( return ContractDefinition( program=program, entry_points_by_type=get_entry_points_by_type(program=program), - abi=preprocessed_program.abi, + abi=abi, ) @@ -166,6 +172,9 @@ def main(): parser.add_argument( "--disable_hint_validation", action="store_true", help="Disable the hint validation." ) + parser.add_argument( + "--account_contract", action="store_true", help="Compile as account contract." + ) def pass_manager_factory(args: argparse.Namespace, module_reader: ModuleReader) -> PassManager: return starknet_pass_manager( @@ -182,9 +191,10 @@ def pass_manager_factory(args: argparse.Namespace, module_reader: ModuleReader) pass_manager_factory=pass_manager_factory, assemble_func=assemble_starknet_contract, ) - assert isinstance(preprocessed, StarknetPreprocessedProgram) + abi = get_abi(preprocessed=preprocessed) + verify_account_contract(contract_abi=abi, is_account_contract=args.account_contract) if args.abi is not None: - json.dump(preprocessed.abi, args.abi, indent=4, sort_keys=True) + json.dump(abi, args.abi, indent=4, sort_keys=True) args.abi.write("\n") except LocationError as err: print(err, file=sys.stderr) diff --git a/src/starkware/starknet/compiler/external_wrapper.py b/src/starkware/starknet/compiler/external_wrapper.py index acd1bbfe..594d2190 100644 --- a/src/starkware/starknet/compiler/external_wrapper.py +++ b/src/starkware/starknet/compiler/external_wrapper.py @@ -355,10 +355,9 @@ def prepare_raw_input_args( arg_struct_members: Dict[str, MemberDefinition], func_location: Location, ) -> ArgList: - call_args = ArgList( args=[ - ExprAssignment(identifier=arg_name, expr=expr) + ExprAssignment(identifier=ExprIdentifier(name=arg_name), expr=expr) for arg_name, expr in safe_zip( arg_struct_members, [selector, calldata_size, calldata_ptr] ) diff --git a/src/starkware/starknet/compiler/external_wrapper_test.py b/src/starkware/starknet/compiler/external_wrapper_test.py index 3d69296d..9b2a8d43 100644 --- a/src/starkware/starknet/compiler/external_wrapper_test.py +++ b/src/starkware/starknet/compiler/external_wrapper_test.py @@ -63,7 +63,7 @@ def test_wrapper_with_implicit_args(builtins_directive: bool): assert isinstance(program.identifiers.get_by_full_name(WRAPPER_SCOPE + "f"), FunctionDefinition) - expected_result = "%builtins pedersen range_check ecdsa\n" + strip_comments_and_linebreaks( + expected_result = "%builtins pedersen range_check ecdsa\n\n" + strip_comments_and_linebreaks( """\ # Implementation of f [ap] = [fp + (-6)]; ap++ # Return ecdsa_ptr. @@ -129,7 +129,7 @@ def test_wrapper_with_return_values(builtins_directive: bool): assert isinstance(program.identifiers.get_by_full_name(WRAPPER_SCOPE + "f"), FunctionDefinition) - expected_result = "%builtins pedersen range_check ecdsa\n" + strip_comments_and_linebreaks( + expected_result = "%builtins pedersen range_check ecdsa\n\n" + strip_comments_and_linebreaks( """\ # A dummy memcpy(). ap += [ap] diff --git a/src/starkware/starknet/compiler/starknet_pass_manager.py b/src/starkware/starknet/compiler/starknet_pass_manager.py index d472faf3..5f0fbf70 100644 --- a/src/starkware/starknet/compiler/starknet_pass_manager.py +++ b/src/starkware/starknet/compiler/starknet_pass_manager.py @@ -49,6 +49,7 @@ def starknet_pass_manager( "starkware.cairo.common.cairo_builtins", "starkware.cairo.common.hash", "starkware.cairo.common.memcpy", + "starkware.cairo.lang.compiler.lib.registers", "starkware.starknet.common.storage", "starkware.starknet.common.syscalls", ], diff --git a/src/starkware/starknet/compiler/starknet_preprocessor.py b/src/starkware/starknet/compiler/starknet_preprocessor.py index 6825c21e..d7eacd83 100644 --- a/src/starkware/starknet/compiler/starknet_preprocessor.py +++ b/src/starkware/starknet/compiler/starknet_preprocessor.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional from starkware.cairo.lang.compiler.ast.code_elements import ( BuiltinsDirective, @@ -27,6 +27,7 @@ ) from starkware.starknet.compiler.validation_utils import get_function_attr from starkware.starknet.definitions import constants +from starkware.starknet.public.abi import AbiType from starkware.starknet.public.abi_structs import ( prepare_type_for_abi, struct_definition_to_abi_entry, @@ -39,7 +40,7 @@ @dataclasses.dataclass class StarknetPreprocessedProgram(PreprocessedProgram): # JSON dict that contains information on the callable functions in the contract. - abi: Any + abi: AbiType class StarknetPreprocessor(Preprocessor): @@ -54,7 +55,7 @@ def __init__(self, **kwargs): super().__init__(supported_decorators=supported_decorators, **kwargs) # JSON dict for the ABI output. - self.abi: List[dict] = [] + self.abi: AbiType = [] # A map from external struct (short) name to its ABI entry. self.abi_structs: Dict[str, dict] = {} # A map from external struct (short) name to the fully qualified name. diff --git a/src/starkware/starknet/compiler/starknet_preprocessor_test.py b/src/starkware/starknet/compiler/starknet_preprocessor_test.py index a1dd7f24..9cc42d37 100644 --- a/src/starkware/starknet/compiler/starknet_preprocessor_test.py +++ b/src/starkware/starknet/compiler/starknet_preprocessor_test.py @@ -103,7 +103,7 @@ def test_abi_basic(): namespace MyNamespace: struct ExternalStruct: - member y: (felt, felt) + member y: (x : felt, y : felt) end end @@ -124,7 +124,7 @@ def test_abi_basic(): end @external -func f(a : felt, arr_len : felt, arr : felt*) -> (b : felt, c : felt): +func f(a : (x : felt, y : felt), arr_len : felt, arr : felt*) -> (b : felt, c : felt): return (0, 1) end @@ -163,7 +163,7 @@ def test_abi_basic(): { "type": "struct", "name": "ExternalStruct", - "members": [{"name": "y", "offset": 0, "type": "(felt, felt)"}], + "members": [{"name": "y", "offset": 0, "type": "(x : felt, y : felt)"}], "size": 2, }, { @@ -180,7 +180,7 @@ def test_abi_basic(): }, { "inputs": [ - {"name": "a", "type": "felt"}, + {"name": "a", "type": "(x : felt, y : felt)"}, {"name": "arr_len", "type": "felt"}, {"name": "arr", "type": "felt*"}, ], diff --git a/src/starkware/starknet/compiler/storage_var_test.py b/src/starkware/starknet/compiler/storage_var_test.py index 1dbc748c..3fd78ac1 100644 --- a/src/starkware/starknet/compiler/storage_var_test.py +++ b/src/starkware/starknet/compiler/storage_var_test.py @@ -175,10 +175,9 @@ def test_storage_var_success(): [ap] = [ap + (-12)]; ap++ # Return (updated) range_check_ptr. ret """ - assert ( - re.sub("call rel -?[0-9]+", "call rel ???", strip_comments_and_linebreaks(program.format())) - == strip_comments_and_linebreaks(expected_result).lstrip() - ) + assert re.sub( + "call rel -?[0-9]+", "call rel ???", strip_comments_and_linebreaks(program.format()) + ) == strip_comments_and_linebreaks(expected_result) def test_storage_var_failures(): diff --git a/src/starkware/starknet/compiler/test_utils.py b/src/starkware/starknet/compiler/test_utils.py index 1b55d970..17337804 100644 --- a/src/starkware/starknet/compiler/test_utils.py +++ b/src/starkware/starknet/compiler/test_utils.py @@ -1,6 +1,9 @@ from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME from starkware.cairo.lang.compiler.preprocessor.preprocessor_error import PreprocessorError -from starkware.cairo.lang.compiler.preprocessor.preprocessor_test_utils import preprocess_str_ex +from starkware.cairo.lang.compiler.preprocessor.preprocessor_test_utils import ( + CAIRO_TEST_MODULES, + preprocess_str_ex, +) from starkware.cairo.lang.compiler.preprocessor.preprocessor_test_utils import ( verify_exception as generic_verify_exception, ) @@ -8,7 +11,8 @@ from starkware.starknet.compiler.starknet_pass_manager import starknet_pass_manager from starkware.starknet.compiler.starknet_preprocessor import StarknetPreprocessedProgram -TEST_MODULES = { +STARKNET_TEST_MODULES = { + **CAIRO_TEST_MODULES, "starkware.starknet.common.storage": """ struct Storage: end @@ -72,7 +76,7 @@ def preprocess_str(code: str) -> StarknetPreprocessedProgram: preprocessed = preprocess_str_ex( code=code, pass_manager=starknet_pass_manager( - prime=DEFAULT_PRIME, read_module=read_file_from_dict(TEST_MODULES) + prime=DEFAULT_PRIME, read_module=read_file_from_dict(STARKNET_TEST_MODULES) ), ) assert isinstance(preprocessed, StarknetPreprocessedProgram) @@ -84,7 +88,7 @@ def verify_exception(code: str, error: str, exc_type=PreprocessorError): code=code, error=error, pass_manager=starknet_pass_manager( - prime=DEFAULT_PRIME, read_module=read_file_from_dict(TEST_MODULES) + prime=DEFAULT_PRIME, read_module=read_file_from_dict(STARKNET_TEST_MODULES) ), exc_type=exc_type, ) diff --git a/src/starkware/starknet/compiler/validation_utils.py b/src/starkware/starknet/compiler/validation_utils.py index 1ff9a2ba..b4563160 100644 --- a/src/starkware/starknet/compiler/validation_utils.py +++ b/src/starkware/starknet/compiler/validation_utils.py @@ -12,6 +12,7 @@ from starkware.cairo.lang.compiler.preprocessor.preprocessor_error import PreprocessorError from starkware.starknet.compiler.data_encoder import ArgumentInfo, EncodingType, encode_data from starkware.starknet.definitions import constants +from starkware.starknet.public.abi import EXECUTE_ENTRY_POINT_NAME, AbiType TAttr = TypeVar("TAttr") @@ -71,6 +72,27 @@ def verify_no_return_values(elm: CodeElementFunction, name_in_error_message: str ) +def verify_account_contract(contract_abi: AbiType, is_account_contract: bool): + """ + Verifies that the given abi is that of a StarkNet account contract if and only if it + has an entry point named "__execute__" and raises an exception otherwise. + """ + contains_execute_entry_point = any( + entry_point["type"] == "function" and entry_point["name"] == EXECUTE_ENTRY_POINT_NAME + for entry_point in contract_abi + ) + if contains_execute_entry_point and (not is_account_contract): + raise PreprocessorError( + message=f"Only account contracts may have a function named " + f'"{EXECUTE_ENTRY_POINT_NAME}". Use --account_contract flag.' + ) + + if (not contains_execute_entry_point) and is_account_contract: + raise PreprocessorError( + message=f'Account contracts must have a function named "{EXECUTE_ENTRY_POINT_NAME}".' + ) + + # Common utils. diff --git a/src/starkware/starknet/core/CMakeLists.txt b/src/starkware/starknet/core/CMakeLists.txt index da7f9846..90f0f508 100644 --- a/src/starkware/starknet/core/CMakeLists.txt +++ b/src/starkware/starknet/core/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(os) +add_subdirectory(test_contract) diff --git a/src/starkware/starknet/core/os/CMakeLists.txt b/src/starkware/starknet/core/os/CMakeLists.txt index e2383ba5..7ed5a3d0 100644 --- a/src/starkware/starknet/core/os/CMakeLists.txt +++ b/src/starkware/starknet/core/os/CMakeLists.txt @@ -1,3 +1,5 @@ +add_subdirectory(os_config) + cairo_compile(starknet_os_program starknet_os_compiled.json os.cairo "--debug_info_with_source") @@ -45,6 +47,7 @@ python_lib(starknet_os_utils_lib starknet_business_logic_lib starknet_contract_definition_lib starknet_definitions_lib + starknet_execute_entry_point_base_lib starknet_general_config_lib starknet_internal_transaction_interface_lib starknet_storage_lib @@ -97,7 +100,10 @@ full_python_test(starknet_transaction_hash_test LIBS cairo_common_lib + cairo_function_runner_lib + starknet_contract_definition_lib starknet_definitions_lib + starknet_os_utils_lib starknet_transaction_hash_lib starkware_crypto_lib pip_pytest diff --git a/src/starkware/starknet/core/os/block_context.cairo b/src/starkware/starknet/core/os/block_context.cairo new file mode 100644 index 00000000..b51bf763 --- /dev/null +++ b/src/starkware/starknet/core/os/block_context.cairo @@ -0,0 +1,54 @@ +from starkware.cairo.common.cairo_builtins import HashBuiltin +from starkware.cairo.common.registers import get_fp_and_pc +from starkware.starknet.core.os.builtins import BuiltinParams, get_builtin_params +from starkware.starknet.core.os.contracts import ( + ContractDefinitionFact, load_contract_definition_facts) +from starkware.starknet.core.os.os_config.os_config import StarknetOsConfig + +struct BlockInfo: + # Currently, the block timestamp is not validated. + member block_timestamp : felt + member block_number : felt +end + +# Represents information that is the same throughout the block. +struct BlockContext: + # Parameters for select_builtins. + member builtin_params : BuiltinParams* + + # A list of (contract_hash, contract_definition) with the contracts that are executed + # in this block. + member n_contract_definition_facts : felt + member contract_definition_facts : ContractDefinitionFact* + # The address of the sequencer that is creating this block. + member sequencer_address : felt + # Information about the block. + member block_info : BlockInfo + # StarknetOsConfig instance. + member starknet_os_config : StarknetOsConfig +end + +# Returns a BlockContext instance. +# +# 'syscall_handler' and 'os_input' should be passed as hint variables. +func get_block_context{pedersen_ptr : HashBuiltin*, range_check_ptr}() -> ( + block_context : BlockContext*): + alloc_locals + let (n_contract_definition_facts, contract_definition_facts) = load_contract_definition_facts() + let (builtin_params) = get_builtin_params() + local block_context : BlockContext = BlockContext( + builtin_params=builtin_params, + n_contract_definition_facts=n_contract_definition_facts, + contract_definition_facts=contract_definition_facts, + sequencer_address=nondet %{ os_input.general_config.sequencer_address %}, + block_info=BlockInfo( + block_timestamp=nondet %{ syscall_handler.block_info.block_timestamp %}, + block_number=nondet %{ syscall_handler.block_info.block_number %}), + starknet_os_config=StarknetOsConfig( + chain_id=nondet %{ os_input.general_config.chain_id.value %}, + fee_token_address=nondet %{ os_input.general_config.fee_token_address %} + )) + + let (__fp__, _) = get_fp_and_pc() + return (block_context=&block_context) +end diff --git a/src/starkware/starknet/core/os/contract_hash.py b/src/starkware/starknet/core/os/contract_hash.py index c72d18b1..78d3cb82 100644 --- a/src/starkware/starknet/core/os/contract_hash.py +++ b/src/starkware/starknet/core/os/contract_hash.py @@ -110,6 +110,14 @@ def compute_hinted_contract_definition_hash(contract_definition: ContractDefinit # Remove attributes field from raw dictionary, for hash backward compatibility of # contracts deployed prior to adding this feature. del dumped_program["attributes"] + else: + # Remove accessible_scopes and flow_tracking_data fields from raw dictionary, for hash + # backward compatibility of contracts deployed prior to adding this feature. + for attr in dumped_program["attributes"]: + if len(attr["accessible_scopes"]) == 0: + del attr["accessible_scopes"] + if attr["flow_tracking_data"] is None: + del attr["flow_tracking_data"] input_to_hash = dict(program=dumped_program, abi=contract_definition.abi) return starknet_keccak(data=json.dumps(input_to_hash, sort_keys=True).encode()) diff --git a/src/starkware/starknet/core/os/contracts.cairo b/src/starkware/starknet/core/os/contracts.cairo index 2652278a..d90996c6 100644 --- a/src/starkware/starknet/core/os/contracts.cairo +++ b/src/starkware/starknet/core/os/contracts.cairo @@ -202,8 +202,12 @@ func load_contract_definition_facts_inner{pedersen_ptr : HashBuiltin*, range_che %{ from starkware.python.utils import from_bytes - assert ids.contract_definition_fact.hash == from_bytes(contract_hash), \ - 'Computed contract_hash is inconsistent with the hash in the os_input' + computed_hash = ids.contract_definition_fact.hash + expected_hash = from_bytes(contract_hash) + assert computed_hash == expected_hash, ( + "Computed contract_hash is inconsistent with the hash in the os_input" + f"Computed hash = {computed_hash}, Expected hash = {expected_hash}.") + vm_load_program(contract_definition.program, ids.contract_definition.bytecode_ptr) %} diff --git a/src/starkware/starknet/core/os/os.cairo b/src/starkware/starknet/core/os/os.cairo index 86b5b23f..e42c2040 100644 --- a/src/starkware/starknet/core/os/os.cairo +++ b/src/starkware/starknet/core/os/os.cairo @@ -2,12 +2,10 @@ from starkware.cairo.common.alloc import alloc from starkware.cairo.common.cairo_builtins import HashBuiltin -from starkware.cairo.common.dict import DictAccess from starkware.cairo.common.math import assert_not_equal -from starkware.cairo.common.segments import relocate_segment -from starkware.cairo.common.serialize import serialize_word -from starkware.starknet.core.os.output import ( - BlockInfo, OsCarriedOutputs, OsOutput, os_output_serialize) +from starkware.starknet.core.os.block_context import BlockContext, get_block_context +from starkware.starknet.core.os.os_config.os_config import get_starknet_os_config_hash +from starkware.starknet.core.os.output import OsCarriedOutputs, os_output_serialize from starkware.starknet.core.os.state import state_update from starkware.starknet.core.os.transactions import execute_transactions @@ -21,29 +19,26 @@ func main{output_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr, ecds let initial_range_check_ptr = range_check_ptr let range_check_ptr = range_check_ptr + 1 - let (local os_output : OsOutput*) = alloc() + let (initial_carried_outputs : OsCarriedOutputs*) = alloc() %{ from starkware.starknet.core.os.os_input import StarknetOsInput os_input = StarknetOsInput.load(data=program_input) - ids.os_output.initial_outputs.messages_to_l1 = segments.add_temp_segment() - ids.os_output.initial_outputs.messages_to_l2 = segments.add_temp_segment() - ids.os_output.initial_outputs.deployment_info = segments.add_temp_segment() + ids.initial_carried_outputs.messages_to_l1 = segments.add_temp_segment() + ids.initial_carried_outputs.messages_to_l2 = segments.add_temp_segment() + ids.initial_carried_outputs.deployment_info = segments.add_temp_segment() %} - assert os_output.block_info = BlockInfo( - block_timestamp=nondet %{ syscall_handler.block_info.block_timestamp %}, - block_number=nondet %{ syscall_handler.block_info.block_number %}) - - tempvar outputs : OsCarriedOutputs* = &os_output.initial_outputs + let (block_context : BlockContext*) = get_block_context() + let outputs = initial_carried_outputs with outputs: let (local reserved_range_checks_end, state_changes) = execute_transactions( - block_info=&os_output.block_info) + block_context=block_context) end + let final_carried_outputs = outputs - assert os_output.final_outputs = [outputs] local ecdsa_ptr = ecdsa_ptr local bitwise_ptr = bitwise_ptr @@ -56,19 +51,32 @@ func main{output_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr, ecds ids.initial_storage_updates_ptr = segments.add_temp_segment() %} let storage_updates_ptr = initial_storage_updates_ptr + with storage_updates_ptr: let (commitment_tree_update_output) = state_update{hash_ptr=pedersen_ptr}( state_changes_dict=state_changes.changes_start, state_changes_dict_end=state_changes.changes_end) end - assert os_output.commitment_tree_update_output = commitment_tree_update_output %{ vm_exit_scope() %} + # Compute the general config hash. + # This is done here to avoid passing pedersen_ptr to os_output_serialize. + let hash_ptr = pedersen_ptr + with hash_ptr: + let (starknet_os_config_hash) = get_starknet_os_config_hash( + starknet_os_config=&block_context.starknet_os_config) + end + let pedersen_ptr = hash_ptr + os_output_serialize( - os_output=os_output, + block_context=block_context, + commitment_tree_update_output=commitment_tree_update_output, + initial_carried_outputs=initial_carried_outputs, + final_carried_outputs=final_carried_outputs, storage_updates_ptr_start=initial_storage_updates_ptr, - storage_updates_ptr_end=storage_updates_ptr) + storage_updates_ptr_end=storage_updates_ptr, + starknet_os_config_hash=starknet_os_config_hash) # Make sure that we report using at least 1 range check to guarantee that # initial_range_check_ptr points to a valid range check instance. diff --git a/src/starkware/starknet/core/os/os_config/CMakeLists.txt b/src/starkware/starknet/core/os/os_config/CMakeLists.txt new file mode 100644 index 00000000..97af6bd7 --- /dev/null +++ b/src/starkware/starknet/core/os/os_config/CMakeLists.txt @@ -0,0 +1,35 @@ +python_lib(starknet_os_config_hash_lib + PREFIX starkware/starknet/core/os/os_config + + FILES + os_config_hash.py + + LIBS + cairo_common_lib + starknet_general_config_lib +) + +full_python_test(starknet_os_config_hash_test + PREFIX starkware/starknet/core/os/os_config + PYTHON python3.7 + TESTED_MODULES starkware/starknet/core/os + + FILES + os_config_hash_test.py + + LIBS + cairo_common_lib + cairo_function_runner_lib + starknet_os_config_hash_lib + starknet_general_config_lib + starknet_os_utils_lib + starkware_python_test_utils_lib + pip_pytest + pip_pytest_asyncio +) + +python_exe(starknet_os_config_hash_fix + VENV starknet_os_config_hash_test_venv + MODULE starkware.starknet.core.os.os_config.os_config_hash_test + ARGS "--fix" +) diff --git a/src/starkware/starknet/core/os/os_config/os_config.cairo b/src/starkware/starknet/core/os/os_config/os_config.cairo new file mode 100644 index 00000000..b6bf9931 --- /dev/null +++ b/src/starkware/starknet/core/os/os_config/os_config.cairo @@ -0,0 +1,36 @@ +from starkware.cairo.common.cairo_builtins import HashBuiltin +from starkware.cairo.common.hash_state import hash_finalize, hash_init, hash_update_single +from starkware.cairo.common.registers import get_fp_and_pc + +const STARKNET_OS_CONFIG_VERSION = 'StarknetOsConfig1' + +struct StarknetOsConfig: + # The identifier of the chain. + # This field can be used to prevent replay of testnet transactions on mainnet. + member chain_id : felt + # The (L2) address of the fee token contract. + member fee_token_address : felt +end + +# Calculates the hash of StarkNet OS config. +func get_starknet_os_config_hash{hash_ptr : HashBuiltin*}( + starknet_os_config : StarknetOsConfig*) -> (starknet_os_config_hash : felt): + let (hash_state_ptr) = hash_init() + let (hash_state_ptr) = hash_update_single( + hash_state_ptr=hash_state_ptr, item=STARKNET_OS_CONFIG_VERSION) + let (hash_state_ptr) = hash_update_single( + hash_state_ptr=hash_state_ptr, item=starknet_os_config.chain_id) + let (hash_state_ptr) = hash_update_single( + hash_state_ptr=hash_state_ptr, item=starknet_os_config.fee_token_address) + + let (starknet_os_config_hash) = hash_finalize(hash_state_ptr=hash_state_ptr) + + return (starknet_os_config_hash=starknet_os_config_hash) +end + +func starknet_os_config_new(chain_id : felt, fee_token_address : felt) -> ( + starknet_os_config : StarknetOsConfig*): + let (fp_val, pc_val) = get_fp_and_pc() + static_assert StarknetOsConfig.SIZE == Args.SIZE + return (starknet_os_config=cast(fp_val - 2 - StarknetOsConfig.SIZE, StarknetOsConfig*)) +end diff --git a/src/starkware/starknet/core/os/os_config/os_config_hash.json b/src/starkware/starknet/core/os/os_config/os_config_hash.json new file mode 100644 index 00000000..2bb8f2eb --- /dev/null +++ b/src/starkware/starknet/core/os/os_config/os_config_hash.json @@ -0,0 +1,4 @@ +{ + "mainnet": "0x17c0bc29d31e9a7d14671610a7626264ce9ce8e3ed066a4775adf9b123de9dd", + "testnet": "0x36f5e4ea4dd042801c8841e3db8e654124305da0f11824fc1db60c405dbb39f" +} diff --git a/src/starkware/starknet/core/os/os_config/os_config_hash.py b/src/starkware/starknet/core/os/os_config/os_config_hash.py new file mode 100644 index 00000000..7817aa2e --- /dev/null +++ b/src/starkware/starknet/core/os/os_config/os_config_hash.py @@ -0,0 +1,18 @@ +from starkware.cairo.common.hash_state import compute_hash_on_elements +from starkware.starknet.definitions.general_config import StarknetOsConfig + +# A constant representing the StarkNet OS config version. +STARKNET_OS_CONFIG_HASH_VERSION = int.from_bytes(b"StarknetOsConfig1", "big") + + +def calculate_starknet_config_hash(starknet_os_config: StarknetOsConfig) -> int: + """ + Calculates the hash of StarkNet config. + """ + return compute_hash_on_elements( + data=[ + STARKNET_OS_CONFIG_HASH_VERSION, + starknet_os_config.chain_id.value, + starknet_os_config.fee_token_address, + ] + ) diff --git a/src/starkware/starknet/core/os/os_config/os_config_hash_test.py b/src/starkware/starknet/core/os/os_config/os_config_hash_test.py new file mode 100644 index 00000000..78e4c5ce --- /dev/null +++ b/src/starkware/starknet/core/os/os_config/os_config_hash_test.py @@ -0,0 +1,109 @@ +import json + +from starkware.cairo.common.cairo_function_runner import CairoFunctionRunner +from starkware.cairo.common.structs import CairoStructFactory +from starkware.python.random_test import random_test +from starkware.python.utils import get_source_dir_path +from starkware.starknet.core.os.os_config.os_config_hash import ( + STARKNET_OS_CONFIG_HASH_VERSION, + calculate_starknet_config_hash, +) +from starkware.starknet.core.os.os_program import get_os_program +from starkware.starknet.definitions import fields +from starkware.starknet.definitions.general_config import StarknetChainId, StarknetOsConfig + +HASH_PATH = get_source_dir_path("src/starkware/starknet/core/os/os_config/os_config_hash.json") +FEE_TOKEN_ADDRESS = 0x49D36570D4E46F48E99674BD3FCC84644DDD6B96F7C741B1562B82F9E004DC7 +FIX_COMMAND = "fix_starknet_os_config_hash" + + +@random_test() +def test_get_starknet_config_hash(seed: int): + """ + Tests the consistency between the Cairo implementation and the python one. + """ + os_program = get_os_program() + + config_version = os_program.get_const( + name="starkware.starknet.core.os.os_config.os_config.STARKNET_OS_CONFIG_VERSION", + full_name_lookup=True, + ) + assert config_version == STARKNET_OS_CONFIG_HASH_VERSION + + runner = CairoFunctionRunner(os_program, layout="all") + starknet_os_config = StarknetOsConfig( + fee_token_address=fields.AddressField.get_random_value(), + ) + structs = CairoStructFactory( + identifiers=os_program.identifiers, + additional_imports=[ + "starkware.starknet.core.os.os_config.os_config.StarknetOsConfig", + ], + ).structs + runner.run( + "starkware.starknet.core.os.os_config.os_config.get_starknet_os_config_hash", + hash_ptr=runner.pedersen_builtin.base, + starknet_os_config=structs.StarknetOsConfig( + chain_id=starknet_os_config.chain_id.value, + fee_token_address=starknet_os_config.fee_token_address, + ), + use_full_name=True, + verify_secure=True, + ) + pedersen_ptr, starknet_config_hash = runner.get_return_values(2) + assert pedersen_ptr == runner.pedersen_builtin.base + ( + (2 + structs.StarknetOsConfig.size) * runner.pedersen_builtin.cells_per_instance + ) + assert starknet_config_hash == calculate_starknet_config_hash( + starknet_os_config=starknet_os_config + ) + + +def run_starknet_os_config_hash_test(fix: bool): + configs = { + "mainnet": StarknetOsConfig( + chain_id=StarknetChainId.MAINNET, fee_token_address=FEE_TOKEN_ADDRESS + ), + "testnet": StarknetOsConfig( + chain_id=StarknetChainId.TESTNET, fee_token_address=FEE_TOKEN_ADDRESS + ), + } + + config_hashes = { + config_name: hex(calculate_starknet_config_hash(starknet_os_config=config)) + for config_name, config in configs.items() + } + + if fix: + with open(HASH_PATH, "w") as fp: + fp.write(json.dumps(config_hashes, indent=4) + "\n") + return + + expected_hashes = json.load(open(HASH_PATH)) + for config_name, computed_hash in config_hashes.items(): + expected_hash = expected_hashes[config_name] + assert expected_hash == computed_hash, ( + f"Wrong StarkNet OS config hash in os_config_hash.json.\n" + f"Computed hash: {computed_hash}. Expected: {expected_hash}.\n" + f"Please run {FIX_COMMAND}." + ) + + +def test_reference_config_hash(): + run_starknet_os_config_hash_test(fix=False) + + +def main(): + import argparse + + parser = argparse.ArgumentParser(description="Create or test the StarkNet OS config hash.") + parser.add_argument( + "--fix", action="store_true", help="Fix the value of the StarkNet OS config hash." + ) + + args = parser.parse_args() + run_starknet_os_config_hash_test(fix=args.fix) + + +if __name__ == "__main__": + main() diff --git a/src/starkware/starknet/core/os/os_program.py b/src/starkware/starknet/core/os/os_program.py index 99230f65..9bad565d 100644 --- a/src/starkware/starknet/core/os/os_program.py +++ b/src/starkware/starknet/core/os/os_program.py @@ -2,7 +2,7 @@ import cachetools -from starkware.cairo.bootloader.hash_program import compute_program_hash_chain +from starkware.cairo.bootloaders.hash_program import compute_program_hash_chain from starkware.cairo.lang.compiler.program import Program STARKNET_OS_COMPILED_PATH = os.path.join(os.path.dirname(__file__), "starknet_os_compiled.json") diff --git a/src/starkware/starknet/core/os/output.cairo b/src/starkware/starknet/core/os/output.cairo index bd5ec4c7..5ced9ad7 100644 --- a/src/starkware/starknet/core/os/output.cairo +++ b/src/starkware/starknet/core/os/output.cairo @@ -1,6 +1,7 @@ from starkware.cairo.common.registers import get_fp_and_pc from starkware.cairo.common.segments import relocate_segment from starkware.cairo.common.serialize import serialize_word +from starkware.starknet.core.os.block_context import BlockContext, BlockInfo from starkware.starknet.core.os.state import CommitmentTreeUpdateOutput # An L2 to L1 message header, the message payload is concatenated to the end of the header. @@ -47,59 +48,49 @@ func os_carried_outputs_new( return (os_carried_outputs=cast(fp_val - 2 - OsCarriedOutputs.SIZE, OsCarriedOutputs*)) end -struct BlockInfo: - # Currently, the block timestamp is not validated. - member block_timestamp : felt - member block_number : felt -end - -struct OsOutput: - # The previous and new root of the contract's storage. - member commitment_tree_update_output : CommitmentTreeUpdateOutput* - member initial_outputs : OsCarriedOutputs - member final_outputs : OsCarriedOutputs - member block_info : BlockInfo -end - func os_output_serialize{output_ptr : felt*}( - os_output : OsOutput*, storage_updates_ptr_start : felt*, storage_updates_ptr_end : felt*): + block_context : BlockContext*, commitment_tree_update_output : CommitmentTreeUpdateOutput*, + initial_carried_outputs : OsCarriedOutputs*, final_carried_outputs : OsCarriedOutputs*, + storage_updates_ptr_start : felt*, storage_updates_ptr_end : felt*, + starknet_os_config_hash : felt): # Serialize program output. # Serialize roots. - serialize_word(os_output.commitment_tree_update_output.initial_storage_root) - serialize_word(os_output.commitment_tree_update_output.final_storage_root) + serialize_word(commitment_tree_update_output.initial_storage_root) + serialize_word(commitment_tree_update_output.final_storage_root) - serialize_word(os_output.block_info.block_number) + serialize_word(block_context.block_info.block_number) + serialize_word(starknet_os_config_hash) let messages_to_l1_segment_size = ( - os_output.final_outputs.messages_to_l1 - - os_output.initial_outputs.messages_to_l1) + final_carried_outputs.messages_to_l1 - + initial_carried_outputs.messages_to_l1) serialize_word(messages_to_l1_segment_size) # Relocate 'messages_to_l1_segment' to the correct place in the output segment. - relocate_segment(src_ptr=os_output.initial_outputs.messages_to_l1, dest_ptr=output_ptr) - let output_ptr = cast(os_output.final_outputs.messages_to_l1, felt*) + relocate_segment(src_ptr=initial_carried_outputs.messages_to_l1, dest_ptr=output_ptr) + let output_ptr = cast(final_carried_outputs.messages_to_l1, felt*) let messages_to_l2_segment_size = ( - os_output.final_outputs.messages_to_l2 - - os_output.initial_outputs.messages_to_l2) + final_carried_outputs.messages_to_l2 - + initial_carried_outputs.messages_to_l2) serialize_word(messages_to_l2_segment_size) # Relocate 'messages_to_l2_segment' to the correct place in the output segment. - relocate_segment(src_ptr=os_output.initial_outputs.messages_to_l2, dest_ptr=output_ptr) - let output_ptr = cast(os_output.final_outputs.messages_to_l2, felt*) + relocate_segment(src_ptr=initial_carried_outputs.messages_to_l2, dest_ptr=output_ptr) + let output_ptr = cast(final_carried_outputs.messages_to_l2, felt*) # Serialize data availability. let da_start = output_ptr let deployment_info_segment_size = ( - os_output.final_outputs.deployment_info - - os_output.initial_outputs.deployment_info) + final_carried_outputs.deployment_info - + initial_carried_outputs.deployment_info) serialize_word(deployment_info_segment_size) # Relocate 'deployment_info_segment' to the correct place in the output segment. - relocate_segment(src_ptr=os_output.initial_outputs.deployment_info, dest_ptr=output_ptr) - let output_ptr = cast(os_output.final_outputs.deployment_info, felt*) + relocate_segment(src_ptr=initial_carried_outputs.deployment_info, dest_ptr=output_ptr) + let output_ptr = cast(final_carried_outputs.deployment_info, felt*) # Relocate 'storage_updates_segment' to the correct place in the output segment. relocate_segment(src_ptr=storage_updates_ptr_start, dest_ptr=output_ptr) diff --git a/src/starkware/starknet/core/os/program_hash.json b/src/starkware/starknet/core/os/program_hash.json index 9bd7437f..d18cef90 100644 --- a/src/starkware/starknet/core/os/program_hash.json +++ b/src/starkware/starknet/core/os/program_hash.json @@ -1,3 +1,3 @@ { - "program_hash": "0x26b17b932ce47266d0e6ae3d6bb17c9189a755b41e9e48b3899abc2aae1a298" + "program_hash": "0x2e5e61af182148e7cdd99ab288435125f730a8f0811f446f3a085153d806e" } diff --git a/src/starkware/starknet/core/os/program_hash_test.py b/src/starkware/starknet/core/os/program_hash_test.py index c8e7a9f9..fd70ac54 100644 --- a/src/starkware/starknet/core/os/program_hash_test.py +++ b/src/starkware/starknet/core/os/program_hash_test.py @@ -1,6 +1,6 @@ import os -from starkware.cairo.bootloader.program_hash_test_utils import ( +from starkware.cairo.bootloaders.program_hash_test_utils import ( program_hash_test_main, run_generate_hash_test, ) diff --git a/src/starkware/starknet/core/os/state.cairo b/src/starkware/starknet/core/os/state.cairo index bdea801a..f0a64a01 100644 --- a/src/starkware/starknet/core/os/state.cairo +++ b/src/starkware/starknet/core/os/state.cairo @@ -2,7 +2,9 @@ from starkware.cairo.common.alloc import alloc from starkware.cairo.common.cairo_builtins import HashBuiltin from starkware.cairo.common.dict import DictAccess, squash_dict from starkware.cairo.common.hash import hash2 -from starkware.cairo.common.patricia import patricia_update +from starkware.cairo.common.patricia import ( + ParticiaGlobals, PatriciaUpdateConstants, patricia_update_constants_new, + patricia_update_using_update_constants) from starkware.cairo.common.segments import relocate_segment const MERKLE_HEIGHT = 251 # PRIME.bit_length() - 1. @@ -60,11 +62,17 @@ func state_update{hash_ptr : HashBuiltin*, range_check_ptr, storage_updates_ptr let output_n_updates = [storage_updates_ptr] let storage_updates_ptr = storage_updates_ptr + 1 let n_actual_state_changes = 0 + # Creates PatriciaUpdateConstants struct for patricia update. + let ( + local patricia_update_constants : PatriciaUpdateConstants*) = patricia_update_constants_new( + ) + with n_actual_state_changes: hash_state_changes( n_state_changes=n_state_changes, state_changes=squashed_dict, - hashed_state_changes=hashed_state_changes) + hashed_state_changes=hashed_state_changes, + patricia_update_constants=patricia_update_constants) end # Write number of state updates. assert output_n_updates = n_actual_state_changes @@ -86,7 +94,10 @@ func state_update{hash_ptr : HashBuiltin*, range_check_ptr, storage_updates_ptr assert global_state_storage.commitment_tree.height == ids.MERKLE_HEIGHT %} - patricia_update( + # Call patricia_update_using_update_constants() instead of patricia_update() + # in order not to repeat globals_pow2 calculation. + patricia_update_using_update_constants( + patricia_update_constants=patricia_update_constants, update_ptr=hashed_state_changes, n_updates=n_state_changes, height=MERKLE_HEIGHT, @@ -125,7 +136,8 @@ end func hash_state_changes{ hash_ptr : HashBuiltin*, range_check_ptr, storage_updates_ptr : felt*, n_actual_state_changes}( - n_state_changes, state_changes : DictAccess*, hashed_state_changes : DictAccess*): + n_state_changes, state_changes : DictAccess*, hashed_state_changes : DictAccess*, + patricia_update_constants : PatriciaUpdateConstants*): if n_state_changes == 0: return () end @@ -156,7 +168,10 @@ func hash_state_changes{ squashed_dict=squashed_storage_dict) local n_updates = (squashed_storage_dict_end - squashed_storage_dict) / DictAccess.SIZE - let vault_merkle_multi_update_ret = patricia_update( + # Call patricia_update_using_update_constants() instead of patricia_update() + # in order not to repeat globals_pow2 calculation. + patricia_update_using_update_constants( + patricia_update_constants=patricia_update_constants, update_ptr=squashed_storage_dict, n_updates=n_updates, height=MERKLE_HEIGHT, @@ -197,7 +212,8 @@ func hash_state_changes{ return hash_state_changes( n_state_changes=n_state_changes - 1, state_changes=state_changes + DictAccess.SIZE, - hashed_state_changes=hashed_state_changes) + hashed_state_changes=hashed_state_changes, + patricia_update_constants=patricia_update_constants) end # Write contract address and number of updates. @@ -214,5 +230,6 @@ func hash_state_changes{ return hash_state_changes( n_state_changes=n_state_changes - 1, state_changes=state_changes + DictAccess.SIZE, - hashed_state_changes=hashed_state_changes) + hashed_state_changes=hashed_state_changes, + patricia_update_constants=patricia_update_constants) end diff --git a/src/starkware/starknet/core/os/syscall_utils.py b/src/starkware/starknet/core/os/syscall_utils.py index 9408de67..8a5e3ca3 100644 --- a/src/starkware/starknet/core/os/syscall_utils.py +++ b/src/starkware/starknet/core/os/syscall_utils.py @@ -10,19 +10,18 @@ from starkware.cairo.lang.compiler.identifier_definition import StructDefinition from starkware.cairo.lang.vm.memory_segments import MemorySegmentManager from starkware.cairo.lang.vm.relocatable import MaybeRelocatable, RelocatableValue -from starkware.python.utils import camel_to_snake_case, safe_zip -from starkware.starknet.business_logic.internal_transaction_interface import ( - InternalStateTransaction, -) +from starkware.python.utils import assert_exhausted, camel_to_snake_case, safe_zip +from starkware.starknet.business_logic.execute_entry_point_base import ExecuteEntryPointBase from starkware.starknet.business_logic.state import BlockInfo, CarriedState from starkware.starknet.business_logic.transaction_execution_objects import ( - ContractCall, - ContractCallResponse, - L2ToL1MessageInfo, - OrderedEventContent, + CallInfo, + OrderedEvent, + OrderedL2ToL1Message, TransactionExecutionContext, + TransactionExecutionInfo, ) from starkware.starknet.core.os.os_program import get_os_program +from starkware.starknet.definitions import constants from starkware.starknet.definitions.error_codes import StarknetErrorCode from starkware.starknet.definitions.general_config import StarknetGeneralConfig from starkware.starknet.services.api.contract_definition import EntryPointType @@ -499,33 +498,31 @@ class BusinessLogicSysCallHandler(SysCallHandlerBase): def __init__( self, + execute_entry_point_cls: Type[ExecuteEntryPointBase], tx_execution_context: TransactionExecutionContext, state: CarriedState, caller_address: int, contract_address: int, - signature: List[int], starknet_storage: BusinessLogicStarknetStorage, general_config: StarknetGeneralConfig, initial_syscall_ptr: RelocatableValue, ): super().__init__(general_config=general_config) + self.execute_entry_point_cls = execute_entry_point_cls self.tx_execution_context = tx_execution_context self.state = state self.caller_address = caller_address self.contract_address = contract_address - self.signature = signature self.starknet_storage = starknet_storage self.loop = starknet_storage.loop - # Accumulated execution info. - self.internal_call_responses: List[ContractCallResponse] = [] - self.internal_calls: List[ContractCall] = [] + # Internal calls executed by the current contract call. + self.internal_calls: List[CallInfo] = [] # Events emitted by the current contract call. - self.events: List[OrderedEventContent] = [] - - # Messages from L2 to L1 including ones sent from internal calls. - self.l2_to_l1_messages: List[L2ToL1MessageInfo] = [] + self.events: List[OrderedEvent] = [] + # Messages sent by the current contract call to L1. + self.l2_to_l1_messages: List[OrderedL2ToL1Message] = [] # Kept for validations during the run. self.expected_syscall_ptr = initial_syscall_ptr @@ -607,25 +604,20 @@ def _call_contract( else: raise NotImplementedError(f"Unsupported call type {syscall_name}.") - from starkware.starknet.business_logic.internal_transaction import InternalInvokeFunction - - tx = InternalInvokeFunction( + call = self.execute_entry_point_cls( contract_address=contract_address, code_address=code_address, entry_point_selector=cast(int, request.function_selector), entry_point_type=entry_point_type, calldata=calldata, - signature=self.signature, - hash_value=0, caller_address=caller_address, - nonce=None, ) with self.contract_call_execution_context( - tx=tx, called_contract_address=tx.contract_address + call=call, called_contract_address=contract_address ): # Execute contract call. - execution_info = tx.execute_contract_function( + call_info = call.sync_execute( state=self.state, general_config=self.general_config, loop=self.loop, @@ -633,21 +625,16 @@ def _call_contract( ) # Update execution info. - self.l2_to_l1_messages.extend(execution_info.l2_to_l1_messages) - call_response = ContractCallResponse( - retdata=execution_info.retdata, - ) - self.internal_call_responses.append(call_response) - self.internal_calls.extend(execution_info.contract_calls) + self.internal_calls.append(call_info) - return call_response.retdata + return call_info.retdata @contextlib.contextmanager def contract_call_execution_context( - self, tx: InternalStateTransaction, called_contract_address: int + self, call: ExecuteEntryPointBase, called_contract_address: int ): # Pre-execution preperation and validations. - self._enrich_state(tx=tx) + self._enrich_state(call=call) try: yield @@ -669,19 +656,19 @@ def contract_call_execution_context( # Post-execution updates. self._update_starknet_storage() - def _enrich_state(self, tx: InternalStateTransaction): + def _enrich_state(self, call: ExecuteEntryPointBase): """ - Prepares the state for the execution of the given transaction. + Prepares the state for the execution of the given call. """ # Apply current modifications to the origin contract storage, in case there will be - # future nested call to this contract. + # future nested calls to this contract. self.state.update_contract_storage( contract_address=self.contract_address, modifications=self.starknet_storage.get_modifications(), ) - # Fetch required information for the transaction (that is not already cached in the state). - state_selector = tx.get_state_selector(general_config=self.general_config) + # Fetch required information for the call (that is not already cached in the state). + state_selector = call.get_call_state_selector() state_selector -= self.state.state_selector future_extra_state = asyncio.run_coroutine_threadsafe( coro=self.state.shared_state.get_filled_carried_state( @@ -705,12 +692,10 @@ def emit_event(self, segments: MemorySegmentManager, syscall_ptr: RelocatableVal request = self._read_and_validate_syscall_request( syscall_name="emit_event", segments=segments, syscall_ptr=syscall_ptr ) - # Update events count. - self.tx_execution_context.n_emitted_events += 1 self.events.append( - OrderedEventContent( - order=self.tx_execution_context.n_emitted_events - 1, + OrderedEvent( + order=self.tx_execution_context.n_emitted_events, keys=segments.memory.get_range_as_ints( addr=cast(RelocatableValue, request.keys), size=cast(int, request.keys_len) ), @@ -720,14 +705,19 @@ def emit_event(self, segments: MemorySegmentManager, syscall_ptr: RelocatableVal ) ) + # Update events count. + self.tx_execution_context.n_emitted_events += 1 + def _get_tx_info_ptr(self, segments: MemorySegmentManager) -> RelocatableValue: if self.tx_info_ptr is None: tx_info = self.structs.TxInfo( - version=0, + version=constants.TRANSACTION_VERSION, account_contract_address=self.tx_execution_context.account_contract_address, - max_fee=0, - signature_len=len(self.signature), - signature=segments.gen_arg(self.signature), + max_fee=self.tx_execution_context.max_fee, + transaction_hash=self.tx_execution_context.transaction_hash, + signature_len=len(self.tx_execution_context.signature), + signature=segments.gen_arg(self.tx_execution_context.signature), + chain_id=self.general_config.chain_id.value, ) self.tx_info_ptr = cast(RelocatableValue, segments.gen_arg(arg=tx_info)) @@ -740,16 +730,20 @@ def send_message_to_l1(self, segments: MemorySegmentManager, syscall_ptr: Reloca payload = segments.memory.get_range_as_ints( addr=cast(RelocatableValue, request.payload_ptr), size=cast(int, request.payload_size) ) + self.l2_to_l1_messages.append( - # Note that the constructor of L2ToL1MessageInfo might fail as it is + # Note that the constructor of OrderedL2ToL1Message might fail as it is # more restrictive than the Cairo code. - L2ToL1MessageInfo( - from_address=self.contract_address, + OrderedL2ToL1Message( + order=self.tx_execution_context.n_sent_messages, to_address=cast(int, request.to_address), payload=payload, ) ) + # Update messages count. + self.tx_execution_context.n_sent_messages += 1 + def _get_block_number(self) -> int: return self.state.block_info.block_number @@ -824,12 +818,14 @@ def post_run_tx_info_related_logic(self, runner: CairoFunctionRunner): signature_ptr = tx_info.signature stark_assert( segments.get_segment_used_size(segment_index=signature_ptr.segment_index) - == len(self.signature), + == len(self.tx_execution_context.signature), code=StarknetErrorCode.SECURITY_ERROR, message=f"Out of bounds write to signature segment.", ) - runner.mark_as_accessed(address=signature_ptr, size=len(self.signature)) + runner.mark_as_accessed( + address=signature_ptr, size=len(self.tx_execution_context.signature) + ) def post_run(self, runner: CairoFunctionRunner, syscall_stop_ptr: MaybeRelocatable): """ @@ -852,19 +848,23 @@ class OsSysCallHandler(SysCallHandlerBase): def __init__( self, - contract_calls: List[ContractCall], + tx_execution_infos: List[TransactionExecutionInfo], general_config: StarknetGeneralConfig, starknet_storage_by_address: Mapping[int, StarknetStorageInterface], block_info: BlockInfo, ): super().__init__(general_config=general_config) - self._call_response_iterator: Iterator[ContractCallResponse] = iter([]) - self._contract_calls_iterator = iter(contract_calls) + self.tx_execution_info_iterator: Iterator[TransactionExecutionInfo] = iter( + tx_execution_infos + ) + self.call_iterator: Iterator[CallInfo] = iter([]) # The following members are stacks that represent the calls being executed now (the last # item is the current execution; the one before it, is the caller function; and so on). - self.call_stack: List[ContractCall] = [] + self.call_stack: List[CallInfo] = [] + # For each call an iterator to the retdata of its internal calls. + self.retdata_iterators: List[Iterator[List[int]]] = [] # For each call an iterator to the read_values array which is consumed when the transaction # code is executed. self.execute_code_read_iterators: List[Iterator[int]] = [] @@ -885,6 +885,9 @@ def __init__( # Set during enter_tx. self.tx_info_ptr: Optional[RelocatableValue] = None + # The TransactionExecutionInfo for the transaction currently being executed. + self.tx_execution_info: Optional[TransactionExecutionInfo] = None + def _read_and_validate_syscall_request( self, syscall_name: str, segments: MemorySegmentManager, syscall_ptr: RelocatableValue ) -> CairoStructProxy: @@ -911,8 +914,7 @@ def _call_contract( request = self.structs.CallContractRequest.from_ptr( memory=segments.memory, addr=syscall_ptr ) - call_response = next(self._call_response_iterator) - return call_response.retdata + return next(self.retdata_iterators[-1]) def _get_block_number(self) -> int: return self.block_info.block_number @@ -923,12 +925,12 @@ def _get_block_timestamp(self) -> int: def _get_caller_address( self, segments: MemorySegmentManager, syscall_ptr: RelocatableValue ) -> int: - return self.call_stack[-1].from_address + return self.call_stack[-1].caller_address def _get_contract_address( self, segments: MemorySegmentManager, syscall_ptr: RelocatableValue ) -> int: - return self.call_stack[-1].to_address + return self.call_stack[-1].contract_address def _get_tx_info_ptr(self, segments: MemorySegmentManager) -> RelocatableValue: assert self.tx_info_ptr is not None @@ -963,27 +965,31 @@ def start_tx(self, tx_info_ptr: RelocatableValue): assert self.tx_info_ptr is None self.tx_info_ptr = tx_info_ptr + assert self.tx_execution_info is None + self.tx_execution_info = next(self.tx_execution_info_iterator) + self.call_iterator = self.tx_execution_info.gen_call_iterator() + def end_tx(self): """ Called after the execution of the current transaction complete. """ + assert_exhausted(iterator=self.call_iterator) assert self.tx_info_ptr is not None self.tx_info_ptr = None + assert self.tx_execution_info is not None + self.tx_execution_info = None def enter_call(self): - call_info = next(self._contract_calls_iterator) - self._call_response_iterator = iter(call_info.internal_call_responses) + call_info = next(self.call_iterator) self.call_stack.append(call_info) + self.retdata_iterators.append(call.retdata for call in call_info.internal_calls) # Create two iterators for call_info.storage_read_values. self.execute_code_read_iterators.append(iter(call_info.storage_read_values)) self.execute_syscall_read_iterators.append(iter(call_info.storage_read_values)) def exit_call(self): - assert ( - next(self._call_response_iterator, None) is None - ), "internal_call_responses should be consumed before calling exit_call." self.call_stack.pop() - # Remove the top iterators in execute_code_read_iterators and execute_syscall_read_iterators - # and make sure it is empty. - assert all(False for x in self.execute_code_read_iterators.pop()) - assert all(False for x in self.execute_syscall_read_iterators.pop()) + # Remove the top iterators and make sure they are empty. + assert_exhausted(iterator=self.retdata_iterators.pop()) + assert_exhausted(iterator=self.execute_code_read_iterators.pop()) + assert_exhausted(iterator=self.execute_syscall_read_iterators.pop()) diff --git a/src/starkware/starknet/core/os/transaction_hash.cairo b/src/starkware/starknet/core/os/transaction_hash.cairo new file mode 100644 index 00000000..d40a91d1 --- /dev/null +++ b/src/starkware/starknet/core/os/transaction_hash.cairo @@ -0,0 +1,38 @@ +from starkware.cairo.common.cairo_builtins import HashBuiltin +from starkware.cairo.common.hash_state import ( + hash_finalize, hash_init, hash_update, hash_update_single) + +func get_transaction_hash{hash_ptr : HashBuiltin*}( + tx_hash_prefix : felt, version : felt, contract_address : felt, entry_point_selector : felt, + calldata_size : felt, calldata : felt*, max_fee : felt, chain_id : felt, + additional_data_size : felt, additional_data : felt*) -> (tx_hash : felt): + let (calldata_hash) = get_calldata_hash(calldata_size=calldata_size, calldata=calldata) + + let (hash_state_ptr) = hash_init() + let (hash_state_ptr) = hash_update_single(hash_state_ptr=hash_state_ptr, item=tx_hash_prefix) + let (hash_state_ptr) = hash_update_single(hash_state_ptr=hash_state_ptr, item=version) + let (hash_state_ptr) = hash_update_single(hash_state_ptr=hash_state_ptr, item=contract_address) + let (hash_state_ptr) = hash_update_single( + hash_state_ptr=hash_state_ptr, item=entry_point_selector) + let (hash_state_ptr) = hash_update_single(hash_state_ptr=hash_state_ptr, item=calldata_hash) + let (hash_state_ptr) = hash_update_single(hash_state_ptr=hash_state_ptr, item=max_fee) + let (hash_state_ptr) = hash_update_single(hash_state_ptr=hash_state_ptr, item=chain_id) + + let (hash_state_ptr) = hash_update( + hash_state_ptr=hash_state_ptr, data_ptr=additional_data, data_length=additional_data_size) + + let (tx_hash) = hash_finalize(hash_state_ptr=hash_state_ptr) + + return (tx_hash=tx_hash) +end + +func get_calldata_hash{hash_ptr : HashBuiltin*}(calldata_size : felt, calldata : felt*) -> ( + calldata_hash : felt): + let (hash_state_ptr) = hash_init() + let (hash_state_ptr) = hash_update( + hash_state_ptr=hash_state_ptr, data_ptr=calldata, data_length=calldata_size) + + let (calldata_hash) = hash_finalize(hash_state_ptr=hash_state_ptr) + + return (calldata_hash=calldata_hash) +end diff --git a/src/starkware/starknet/core/os/transaction_hash.py b/src/starkware/starknet/core/os/transaction_hash.py index 563dc6c1..f2d754e6 100644 --- a/src/starkware/starknet/core/os/transaction_hash.py +++ b/src/starkware/starknet/core/os/transaction_hash.py @@ -4,6 +4,7 @@ from starkware.cairo.common.hash_state import compute_hash_on_elements from starkware.cairo.lang.vm.crypto import pedersen_hash from starkware.python.utils import from_bytes +from starkware.starknet.definitions import constants from starkware.starknet.services.api.contract_definition import CONSTRUCTOR_SELECTOR @@ -15,9 +16,11 @@ class TransactionHashPrefix(Enum): def calculate_transaction_hash_common( tx_hash_prefix: TransactionHashPrefix, + version: int, contract_address: int, entry_point_selector: int, calldata: Sequence[int], + max_fee: int, chain_id: int, additional_data: Sequence[int], hash_function: Callable[[int, int], int] = pedersen_hash, @@ -27,10 +30,12 @@ def calculate_transaction_hash_common( transaction. The transaction hash is a hash chain of the following information: 1. A prefix that depends on the transaction type. - 2. Contract address. - 3. Entry point selector. - 4. A hash chain of the calldata. - 5. The network's chain ID. + 2. The transaction's version. + 3. Contract address. + 4. Entry point selector. + 5. A hash chain of the calldata. + 6. The transaction's maximum fee. + 7. The network's chain ID. Each hash chain computation begins with 0 as initialization and ends with its length appended. The length is appended in order to avoid collisions of the following kind: H([x,y,z]) = h(h(x,y),z) = H([w, z]) where w = h(x,y). @@ -38,9 +43,11 @@ def calculate_transaction_hash_common( calldata_hash = compute_hash_on_elements(data=calldata, hash_func=hash_function) data_to_hash = [ tx_hash_prefix.value, + version, contract_address, entry_point_selector, calldata_hash, + max_fee, chain_id, *additional_data, ] @@ -59,9 +66,12 @@ def calculate_deploy_transaction_hash( ) -> int: return calculate_transaction_hash_common( tx_hash_prefix=TransactionHashPrefix.DEPLOY, + version=constants.TRANSACTION_VERSION, contract_address=contract_address, entry_point_selector=CONSTRUCTOR_SELECTOR, calldata=constructor_calldata, + # Field max_fee is considered 0 for Deploy transaction hash calculation purposes. + max_fee=0, chain_id=chain_id, additional_data=[], hash_function=hash_function, diff --git a/src/starkware/starknet/core/os/transaction_hash_test.py b/src/starkware/starknet/core/os/transaction_hash_test.py index 0030b430..ec84edfa 100644 --- a/src/starkware/starknet/core/os/transaction_hash_test.py +++ b/src/starkware/starknet/core/os/transaction_hash_test.py @@ -2,54 +2,123 @@ import pytest -from starkware.cairo.common.hash_state import compute_hash_on_elements +from starkware.cairo.common.cairo_function_runner import CairoFunctionRunner from starkware.crypto.signature.fast_pedersen_hash import pedersen_hash +from starkware.starknet.core.os.os_program import get_os_program from starkware.starknet.core.os.transaction_hash import ( TransactionHashPrefix, calculate_deploy_transaction_hash, calculate_transaction_hash_common, + compute_hash_on_elements, ) +from starkware.starknet.definitions import constants +from starkware.starknet.services.api.contract_definition import CONSTRUCTOR_SELECTOR + + +def run_cairo_transaction_hash( + tx_hash_prefix: TransactionHashPrefix, + version: int, + contract_address: int, + entry_point_selector: int, + calldata: List[int], + max_fee: int, + chain_id: int, + additional_data: List[int], +) -> int: + program = get_os_program() + runner = CairoFunctionRunner(program, layout="all") + + runner.run( + "starkware.starknet.core.os.transaction_hash.get_transaction_hash", + hash_ptr=runner.pedersen_builtin.base, + tx_hash_prefix=tx_hash_prefix.value, + version=version, + contract_address=contract_address, + entry_point_selector=entry_point_selector, + calldata_size=len(calldata), + calldata=calldata, + max_fee=max_fee, + chain_id=chain_id, + additional_data_size=len(additional_data), + additional_data=additional_data, + use_full_name=True, + verify_secure=False, + ) + pedersen_ptr, contract_hash = runner.get_return_values(2) + + assert pedersen_ptr == runner.pedersen_builtin.base + ( + runner.pedersen_builtin.cells_per_instance * (9 + len(calldata) + len(additional_data)) + ) + return contract_hash @pytest.mark.parametrize("tx_hash_prefix", set(TransactionHashPrefix)) @pytest.mark.parametrize("calldata", [[], [659], [540, 338], [73, 443, 234, 350, 841]]) +@pytest.mark.parametrize("max_fee", [0, 10, 299]) +@pytest.mark.parametrize("version", [0]) @pytest.mark.parametrize("additional_data", [[], [17]]) def test_transaction_hash_common_flow( - tx_hash_prefix: TransactionHashPrefix, calldata: List[int], additional_data: List[int] + tx_hash_prefix: TransactionHashPrefix, + version: int, + calldata: List[int], + max_fee: int, + additional_data: List[int], ): + """ + Tests that the Python and Cairo tx_hash implementations return the same value. + """ contract_address = 42 entry_point_selector = 100 chain_id = 1 - expected_tx_hash = compute_hash_on_elements( - data=[ - tx_hash_prefix.value, - contract_address, - entry_point_selector, - compute_hash_on_elements(data=calldata, hash_func=pedersen_hash), - chain_id, - *additional_data, - ], - hash_func=pedersen_hash, - ) - assert expected_tx_hash == calculate_transaction_hash_common( + tx_hash = calculate_transaction_hash_common( tx_hash_prefix=tx_hash_prefix, + version=version, contract_address=contract_address, entry_point_selector=entry_point_selector, calldata=calldata, + max_fee=max_fee, chain_id=chain_id, hash_function=pedersen_hash, additional_data=additional_data, ) + assert tx_hash == run_cairo_transaction_hash( + tx_hash_prefix=tx_hash_prefix, + contract_address=contract_address, + entry_point_selector=entry_point_selector, + calldata=calldata, + max_fee=max_fee, + chain_id=chain_id, + version=version, + additional_data=additional_data, + ) + -def test_deploy_transaction_hash(): - expected_hash = 0x334E744938EE65F038037AD1CC85D949A3554D5CF6508471BB00B0AD91B483 +@pytest.mark.parametrize("constructor_calldata", [[], [658], [539, 337], [72, 442, 233, 349, 840]]) +def test_deploy_transaction_hash(constructor_calldata: List[int]): + version = constants.TRANSACTION_VERSION + contract_address = 1 + chain_id = 1 + max_fee = 0 + + expected_hash = compute_hash_on_elements( + data=[ + TransactionHashPrefix.DEPLOY.value, + version, + contract_address, + CONSTRUCTOR_SELECTOR, + compute_hash_on_elements(data=constructor_calldata, hash_func=pedersen_hash), + max_fee, + chain_id, + ], + hash_func=pedersen_hash, + ) assert ( calculate_deploy_transaction_hash( - contract_address=1, - constructor_calldata=[1, 2], - chain_id=1, + contract_address=contract_address, + constructor_calldata=constructor_calldata, + chain_id=chain_id, ) == expected_hash ) diff --git a/src/starkware/starknet/core/os/transactions.cairo b/src/starkware/starknet/core/os/transactions.cairo index 479c9cca..2c591054 100644 --- a/src/starkware/starknet/core/os/transactions.cairo +++ b/src/starkware/starknet/core/os/transactions.cairo @@ -4,10 +4,11 @@ from starkware.cairo.common.cairo_builtins import HashBuiltin from starkware.cairo.common.dict import dict_new, dict_read, dict_update, dict_write from starkware.cairo.common.dict_access import DictAccess from starkware.cairo.common.find_element import find_element, search_sorted -from starkware.cairo.common.math import assert_nn, assert_not_zero +from starkware.cairo.common.math import assert_nn, assert_nn_le, assert_not_zero from starkware.cairo.common.memcpy import memcpy from starkware.cairo.common.registers import get_ap, get_fp_and_pc from starkware.cairo.common.segments import relocate_segment +from starkware.cairo.common.uint256 import Uint256 from starkware.starknet.common.syscalls import ( CALL_CONTRACT_SELECTOR, DELEGATE_CALL_SELECTOR, DELEGATE_L1_HANDLER_SELECTOR, EMIT_EVENT_SELECTOR, GET_BLOCK_NUMBER_SELECTOR, GET_BLOCK_TIMESTAMP_SELECTOR, @@ -18,14 +19,16 @@ from starkware.starknet.common.syscalls import ( GetCallerAddress, GetCallerAddressResponse, GetContractAddress, GetContractAddressResponse, GetSequencerAddress, GetSequencerAddressResponse, GetTxInfo, GetTxInfoResponse, GetTxSignature, GetTxSignatureResponse, SendMessageToL1SysCall, StorageRead, StorageWrite, TxInfo) -from starkware.starknet.core.os.builtins import ( - BuiltinEncodings, BuiltinParams, BuiltinPointers, get_builtin_params) +from starkware.starknet.core.os.block_context import BlockContext +from starkware.starknet.core.os.builtins import BuiltinEncodings, BuiltinParams, BuiltinPointers from starkware.starknet.core.os.contracts import ( - ContractDefinition, ContractDefinitionFact, ContractEntryPoint, load_contract_definition_facts) + ContractDefinition, ContractDefinitionFact, ContractEntryPoint) +from starkware.starknet.core.os.os_config.os_config import StarknetOsConfig from starkware.starknet.core.os.output import ( BlockInfo, DeploymentInfoHeader, MessageToL1Header, MessageToL2Header, OsCarriedOutputs, os_carried_outputs_new) from starkware.starknet.core.os.state import StateEntry +from starkware.starknet.core.os.transaction_hash import get_transaction_hash const UNINITIALIZED_CONTRACT_HASH = 0 @@ -40,10 +43,20 @@ const ENTRY_POINT_TYPE_EXTERNAL = 0 const ENTRY_POINT_TYPE_L1_HANDLER = 1 const ENTRY_POINT_TYPE_CONSTRUCTOR = 2 +const TRANSACTION_VERSION = 0 + # get_selector_from_name('constructor'). const CONSTRUCTOR_SELECTOR = ( 0x28ffe4ff0f226a9107253e17a904099aa4f63a02a5621de0576e5aa71bc5194) +# get_selector_from_name('__execute__'). +const EXECUTE_ENTRY_POINT_SELECTOR = ( + 0x15d40a3d6ca2ac30f4031e42be28da9b056fef9bb7357ac5e85627ee876e5ad) + +# get_selector_from_name('transfer'). +const TRANSFER_SELECTOR = ( + 0x83afd3f4caedc6eebf44246fe54e38c95e3179a5ec9ea81740eca5b482d12e) + const DEFAULT_ENTRY_POINT_SELECTOR = 0 # Represents the execution context during the execution of contract code. @@ -69,15 +82,6 @@ struct StateChanges: member changes_end : DictAccess* end -# Context that remains fixed throughout the block. -struct BlockContext: - member builtin_params : BuiltinParams* - member n_contract_definition_facts : felt - member contract_definition_facts : ContractDefinitionFact* - member sequencer_address : felt - member block_info : BlockInfo* -end - # Executes the transactions in the hint variable os_input.transactions. # # Returns: @@ -92,7 +96,7 @@ end # the returned range_check_ptr is smaller then reserved_range_checks_end. func execute_transactions{ pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr, bitwise_ptr, - outputs : OsCarriedOutputs*}(block_info : BlockInfo*) -> ( + outputs : OsCarriedOutputs*}(block_context : BlockContext*) -> ( reserved_range_checks_end, state_changes : StateChanges): alloc_locals local n_txs @@ -110,24 +114,13 @@ func execute_transactions{ # A dict from contract address to a dict of storage changes. let (local global_state_changes : DictAccess*) = dict_new() - let (n_contract_definition_facts, contract_definition_facts) = load_contract_definition_facts() - - let (local __fp__, _) = get_fp_and_pc() + let (__fp__, _) = get_fp_and_pc() local local_builtin_ptrs : BuiltinPointers = BuiltinPointers( pedersen=pedersen_ptr, range_check=nondet %{ segments.add_temp_segment() %}, ecdsa=ecdsa_ptr, bitwise=bitwise_ptr) - let (builtin_params) = get_builtin_params() - - local block_context : BlockContext = BlockContext( - builtin_params=builtin_params, - n_contract_definition_facts=n_contract_definition_facts, - contract_definition_facts=contract_definition_facts, - sequencer_address=nondet %{ os_input.sequencer_address %}, - block_info=block_info) - let builtin_ptrs = &local_builtin_ptrs %{ vm_enter_scope({ @@ -141,7 +134,7 @@ func execute_transactions{ let global_state_changes_start = global_state_changes execute_transactions_inner{ builtin_ptrs=builtin_ptrs, global_state_changes=global_state_changes}( - block_context=&block_context, n_txs=n_txs) + block_context=block_context, n_txs=n_txs) %{ vm_exit_scope() %} let reserved_range_checks_end = range_check_ptr @@ -199,7 +192,59 @@ func execute_transactions_inner{ return execute_transactions_inner(block_context=block_context, n_txs=n_txs - 1) end -# Executes an externally called invoke transaction. +# Represents the calldata of an ERC20 transfer. +struct TransferCallData: + member recipient : felt + member amount : Uint256 +end + +# Charges a fee from the user. +# If max_fee is not 0, validates that the selector matches the entry point of an account contract +# and executes an ERC20 transfer on the behalf of that account contract. +# +# Arguments: +# block_context - a global context that is fixed throughout the block. +# tx_execution_context - The execution context of the transaction that pays the fee. +func charge_fee{ + range_check_ptr, builtin_ptrs : BuiltinPointers*, global_state_changes : DictAccess*, + outputs : OsCarriedOutputs*}( + block_context : BlockContext*, tx_execution_context : ExecutionContext*): + alloc_locals + if tx_execution_context.original_tx_info.max_fee == 0: + return () + end + + # Transactions with fee should go through the EXECUTE_ENTRY_POINT_SELECTOR. + assert tx_execution_context.selector = EXECUTE_ENTRY_POINT_SELECTOR + + local calldata : TransferCallData = TransferCallData( + recipient=block_context.sequencer_address, + amount=Uint256(low=nondet %{ syscall_handler.tx_execution_info.actual_fee %}, high=0)) + + tempvar original_tx_info = tx_execution_context.original_tx_info + + # Verify that the charged amount is not larger than the transaction's max_fee field. + assert_nn_le(calldata.amount.low, original_tx_info.max_fee) + + let (__fp__, _) = get_fp_and_pc() + tempvar fee_token_address = block_context.starknet_os_config.fee_token_address + local execution_context : ExecutionContext = ExecutionContext( + entry_point_type=ENTRY_POINT_TYPE_EXTERNAL, + caller_address=original_tx_info.account_contract_address, + contract_address=fee_token_address, + code_address=fee_token_address, + selector=TRANSFER_SELECTOR, + calldata_size=TransferCallData.SIZE, + calldata=&calldata, + original_tx_info=original_tx_info, + ) + + execute_entry_point(block_context=block_context, execution_context=&execution_context) + + return () +end + +# Executes an externally called transaction (external invoke or l1 handler). # # The transaction should be passed in the hint variable 'tx'. # If the transaction is an L1 handler, it is appended to the list of consumed L1->L2 messages. @@ -244,34 +289,65 @@ func execute_externally_called_invoke_transaction{ ) %} + local nonce + local max_fee = nondet %{ tx.max_fee %} + %{ assert tx.version == ids.TRANSACTION_VERSION, 'Wrong transaction version.' %} + if execution_context.entry_point_type == ENTRY_POINT_TYPE_L1_HANDLER: + %{ ids.nonce = tx.nonce %} + let (__fp__, _) = get_fp_and_pc() + tempvar tx_hash_prefix = 'l1_handler' + tempvar additional_data_size = 1 + tempvar additional_data = &nonce + with_attr error_message("An L1 handler transaction must have max_fee==0."): + assert max_fee = 0 + end + else: + # If execution_context.entry_point_type is not ENTRY_POINT_TYPE_L1_HANDLER, + # it must be ENTRY_POINT_TYPE_EXTERNAL. + assert execution_context.entry_point_type = ENTRY_POINT_TYPE_EXTERNAL + tempvar tx_hash_prefix = 'invoke' + tempvar additional_data_size = 0 + tempvar additional_data = cast(0, felt*) + end + + local chain_id = block_context.starknet_os_config.chain_id + let (transaction_hash) = compute_transaction_hash( + tx_hash_prefix=tx_hash_prefix, + execution_context=execution_context, + max_fee=max_fee, + chain_id=chain_id, + additional_data_size=additional_data_size, + additional_data=additional_data) + assert [execution_context.original_tx_info] = TxInfo( - version=0, + version=TRANSACTION_VERSION, account_contract_address=execution_context.contract_address, - max_fee=0, + max_fee=max_fee, signature_len=nondet %{ len(tx.signature) %}, signature=cast(nondet %{ segments.gen_arg(arg=tx.signature) %}, felt*), + transaction_hash=transaction_hash, + chain_id=chain_id, ) # External calls originate from ORIGIN_ADDRESS. assert execution_context.caller_address = ORIGIN_ADDRESS + # In external calls and l1 handlers, the code_address must match the contract_address. + assert execution_context.code_address = execution_context.contract_address + if execution_context.entry_point_type == ENTRY_POINT_TYPE_L1_HANDLER: # Consume L1-to-L2 message. - consume_l1_to_l2_message(execution_context=execution_context, nonce=nondet %{ tx.nonce %}) + consume_l1_to_l2_message(execution_context=execution_context, nonce=nonce) else: - # If execution_context.entry_point_type is not ENTRY_POINT_TYPE_L1_HANDLER, - # it must be ENTRY_POINT_TYPE_EXTERNAL. - assert execution_context.entry_point_type = ENTRY_POINT_TYPE_EXTERNAL tempvar outputs = outputs end - # In external calls and l1 handlers, the code_address must match the contract_address. - assert execution_context.code_address = execution_context.contract_address - %{ syscall_handler.start_tx(tx_info_ptr=ids.execution_context.original_tx_info.address_) %} execute_entry_point(block_context=block_context, execution_context=execution_context) + charge_fee(block_context=block_context, tx_execution_context=execution_context) %{ syscall_handler.end_tx() %} + return () end @@ -841,12 +917,24 @@ func execute_deploy_transaction{ original_tx_info=cast(nondet %{ segments.add() %}, TxInfo*), ) + let nullptr = cast(0, felt*) + local chain_id = block_context.starknet_os_config.chain_id + let (transaction_hash) = compute_transaction_hash( + tx_hash_prefix='deploy', + execution_context=execution_context, + max_fee=0, + chain_id=chain_id, + additional_data_size=0, + additional_data=nullptr) + assert [execution_context.original_tx_info] = TxInfo( - version=0, + version=TRANSACTION_VERSION, account_contract_address=ORIGIN_ADDRESS, max_fee=0, signature_len=0, - signature=cast(0, felt*), + signature=nullptr, + transaction_hash=transaction_hash, + chain_id=chain_id, ) %{ syscall_handler.start_tx(tx_info_ptr=ids.execution_context.original_tx_info.address_) %} @@ -855,3 +943,41 @@ func execute_deploy_transaction{ return () end + +# Computes the hash of the transaction. +# +# Note that execution_context.original_tx_info is uninitialized when this function is called. +# In particular, this field is not used in this function. +func compute_transaction_hash{builtin_ptrs : BuiltinPointers*}( + tx_hash_prefix : felt, execution_context : ExecutionContext*, max_fee : felt, + chain_id : felt, additional_data_size : felt, additional_data : felt*) -> ( + transaction_hash : felt): + let hash_ptr = builtin_ptrs.pedersen + with hash_ptr: + let (transaction_hash) = get_transaction_hash( + tx_hash_prefix=tx_hash_prefix, + version=TRANSACTION_VERSION, + contract_address=execution_context.contract_address, + entry_point_selector=execution_context.selector, + calldata_size=execution_context.calldata_size, + calldata=execution_context.calldata, + max_fee=max_fee, + chain_id=chain_id, + additional_data_size=additional_data_size, + additional_data=additional_data) + end + + %{ + assert ids.transaction_hash == tx.hash_value, ( + "Computed transaction_hash is inconsistent with the hash in the transaction. " + f"Computed hash = {ids.transaction_hash}, Expected hash = {tx.hash_value}.") + %} + + tempvar builtin_ptrs = new BuiltinPointers( + pedersen=hash_ptr, + range_check=builtin_ptrs.range_check, + ecdsa=builtin_ptrs.ecdsa, + bitwise=builtin_ptrs.bitwise) + + return (transaction_hash=transaction_hash) +end diff --git a/src/starkware/starknet/core/test_contract/CMakeLists.txt b/src/starkware/starknet/core/test_contract/CMakeLists.txt new file mode 100644 index 00000000..71a098cf --- /dev/null +++ b/src/starkware/starknet/core/test_contract/CMakeLists.txt @@ -0,0 +1,26 @@ +starknet_compile(compile_delegate_proxy delegate_proxy.json delegate_proxy.cairo "") +starknet_compile(compile_dummy_account_contract + dummy_account.json dummy_account.cairo "--account_contract") + +python_lib(starknet_external_compiled_contracts_lib + PREFIX starkware/starknet/core/test_contract + + ARTIFACTS + "${CMAKE_CURRENT_BINARY_DIR}/delegate_proxy.json delegate_proxy.json" + "${CMAKE_CURRENT_BINARY_DIR}/dummy_account.json dummy_account.json" +) +add_dependencies(starknet_external_compiled_contracts_lib + compile_delegate_proxy + compile_dummy_account_contract +) + +python_lib(starknet_test_external_contract_test_utils_lib + PREFIX starkware/starknet/core/test_contract + + FILES + test_utils.py + + LIBS + starknet_contract_definition_lib + starknet_external_compiled_contracts_lib +) diff --git a/src/starkware/starknet/core/test_contract/delegate_proxy.cairo b/src/starkware/starknet/core/test_contract/delegate_proxy.cairo new file mode 100644 index 00000000..f4bad437 --- /dev/null +++ b/src/starkware/starknet/core/test_contract/delegate_proxy.cairo @@ -0,0 +1,49 @@ +# Note that this is a dummy contract to be used in tests. + +%lang starknet +%builtins pedersen range_check bitwise + +from starkware.cairo.common.cairo_builtins import HashBuiltin +from starkware.starknet.common.syscalls import delegate_call, delegate_l1_handler + +# The address of the implementation contract. +@storage_var +func impl_address() -> (address : felt): +end + +@external +func set_implementation_address{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr}( + impl_address_ : felt): + impl_address.write(value=impl_address_) + return () +end + +@external +@raw_input +@raw_output +func __default__{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr}( + selector : felt, calldata_size : felt, calldata : felt*) -> ( + retdata_size : felt, retdata : felt*): + let (address) = impl_address.read() + + let (retdata_size : felt, retdata : felt*) = delegate_call( + contract_address=address, + function_selector=selector, + calldata_size=calldata_size, + calldata=calldata) + return (retdata_size=retdata_size, retdata=retdata) +end + +@l1_handler +@raw_input +func __l1_default__{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr}( + selector : felt, calldata_size : felt, calldata : felt*): + let (address) = impl_address.read() + + delegate_l1_handler( + contract_address=address, + function_selector=selector, + calldata_size=calldata_size, + calldata=calldata) + return () +end diff --git a/src/starkware/starknet/core/test_contract/dummy_account.cairo b/src/starkware/starknet/core/test_contract/dummy_account.cairo new file mode 100644 index 00000000..ccd94558 --- /dev/null +++ b/src/starkware/starknet/core/test_contract/dummy_account.cairo @@ -0,0 +1,19 @@ +# A dummy account contract without any validations. + +%lang starknet + +from starkware.cairo.common.cairo_builtins import HashBuiltin +from starkware.starknet.common.syscalls import call_contract + +@external +@raw_output +func __execute__{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr}( + contract_address, selector : felt, calldata_len : felt, calldata : felt*) -> ( + retdata_size : felt, retdata : felt*): + let (retdata_size : felt, retdata : felt*) = call_contract( + contract_address=contract_address, + function_selector=selector, + calldata_size=calldata_len, + calldata=calldata) + return (retdata_size=retdata_size, retdata=retdata) +end diff --git a/src/starkware/starknet/core/test_contract/test_utils.py b/src/starkware/starknet/core/test_contract/test_utils.py new file mode 100644 index 00000000..24a8c719 --- /dev/null +++ b/src/starkware/starknet/core/test_contract/test_utils.py @@ -0,0 +1,15 @@ +import os + +from starkware.starknet.services.api.contract_definition import ContractDefinition + + +def get_contract_definition(contract_name: str) -> ContractDefinition: + main_dir_path = os.path.dirname(__file__) + file_path = os.path.join(main_dir_path, contract_name + ".json") + + with open(file_path, "r") as fp: + return ContractDefinition.loads(fp.read()) + + +def get_test_contract_definition() -> ContractDefinition: + return get_contract_definition("test_contract") diff --git a/src/starkware/starknet/definitions/CMakeLists.txt b/src/starkware/starknet/definitions/CMakeLists.txt index fc40ccd6..32f408aa 100644 --- a/src/starkware/starknet/definitions/CMakeLists.txt +++ b/src/starkware/starknet/definitions/CMakeLists.txt @@ -26,6 +26,8 @@ python_lib(starknet_general_config_lib general_config.yml LIBS + cairo_instances_lib + everest_general_config_lib starknet_definitions_lib starkware_config_utils_lib starkware_dataclasses_utils_lib diff --git a/src/starkware/starknet/definitions/constants.py b/src/starkware/starknet/definitions/constants.py index da97b5a5..cce604d2 100644 --- a/src/starkware/starknet/definitions/constants.py +++ b/src/starkware/starknet/definitions/constants.py @@ -7,6 +7,7 @@ FIELD_SIZE_BITS = 251 ADDRESS_BITS = FIELD_SIZE_BITS CONTRACT_ADDRESS_BITS = ADDRESS_BITS +NONCE_BITS = FIELD_SIZE_BITS FELT_LOWER_BOUND = 0 FELT_UPPER_BOUND = FIELD_SIZE @@ -15,8 +16,6 @@ # Address 0 is reserved to distinguish an external transaction from an inner (L2<>L2) one. CONTRACT_ADDRESS_LOWER_BOUND = 1 CONTRACT_ADDRESS_UPPER_BOUND = 2 ** CONTRACT_ADDRESS_BITS -CONTRACT_ADDRESS_SALT_LOWER_BOUND = FELT_LOWER_BOUND -CONTRACT_ADDRESS_SALT_UPPER_BOUND = FELT_UPPER_BOUND CONTRACT_HASH_BYTES = HASH_BYTES CONTRACT_HASH_UPPER_BOUND = FIELD_SIZE CONTRACT_STATES_COMMITMENT_TREE_HEIGHT = FIELD_SIZE_BITS @@ -26,18 +25,26 @@ ENTRY_POINT_SELECTOR_UPPER_BOUND = FIELD_SIZE EVENT_COMMITMENT_TREE_HEIGHT = 64 FEE_LOWER_BOUND = 0 -FEE_UPPER_BOUND = 2 ** 256 # Fee is a uint-256. +FEE_UPPER_BOUND = 2 ** 128 # Default hash to fill the parent_hash field of the first block in the sequence. GENESIS_PARENT_BLOCK_HASH = 0 MAX_MESSAGE_TO_L1_LENGTH = 100 MAX_CALLDATA_LENGTH = 2 ** 30 +NONCE_LOWER_BOUND = 0 +NONCE_UPPER_BOUND = 2 ** NONCE_BITS SYSCALL_SELECTOR_UPPER_BOUND = FIELD_SIZE TRANSACTION_COMMITMENT_TREE_HEIGHT = 64 TRANSACTION_HASH_LOWER_BOUND = 0 TRANSACTION_HASH_UPPER_BOUND = FIELD_SIZE +TRANSACTION_VERSION_LOWER_BOUND = 0 +TRANSACTION_VERSION_UPPER_BOUND = 2 ** 32 ADDRESS_LOWER_BOUND = 0 ADDRESS_UPPER_BOUND = 2 ** ADDRESS_BITS + +# In order to identify transactions from unsupported versions. +TRANSACTION_VERSION = 0 + # OS-related constants. L1_TO_L2_MSG_HEADER_SIZE = 5 L2_TO_L1_MSG_HEADER_SIZE = 3 @@ -51,3 +58,6 @@ # additional parameters (offset and length) in solidity. LOG_MSG_TO_L1_ENCODED_DATA_SIZE = (L2_TO_L1_MSG_HEADER_SIZE + 1) - LOG_MSG_TO_L1_N_TOPICS CONSUMED_MSG_TO_L2_ENCODED_DATA_SIZE = (L1_TO_L2_MSG_HEADER_SIZE + 1) - CONSUMED_MSG_TO_L2_N_TOPICS + +# The (empirical) L1 gas cost of each Cairo step. +N_STEPS_FEE_WEIGHT = 0.05 diff --git a/src/starkware/starknet/definitions/error_codes.py b/src/starkware/starknet/definitions/error_codes.py index 56d53e70..25e8ceb3 100644 --- a/src/starkware/starknet/definitions/error_codes.py +++ b/src/starkware/starknet/definitions/error_codes.py @@ -10,6 +10,7 @@ class StarknetErrorCode(ErrorCode): CONTRACT_BYTECODE_SIZE_TOO_LARGE = auto() CONTRACT_DEFINITION_OBJECT_SIZE_TOO_LARGE = auto() ENTRY_POINT_NOT_FOUND_IN_CONTRACT = auto() + FEE_TRANSFER_FAILURE = auto() INVALID_BLOCK_NUMBER = auto() INVALID_BLOCK_TIMESTAMP = auto() INVALID_CONTRACT_DEFINITION = auto() @@ -17,23 +18,28 @@ class StarknetErrorCode(ErrorCode): INVALID_RETURN_DATA = auto() INVALID_STATUS_MODE = auto() INVALID_TRANSACTION_ID = auto() + INVALID_TRANSACTION_HASH = auto() + INVALID_TRANSACTION_VERSION = auto() + L1_TO_L2_MESSAGE_CANCELLED = auto() L1_TO_L2_MESSAGE_ZEROED_COUNTER = auto() MULTIPLE_ENTRY_POINTS_MATCH_SELECTOR = auto() NON_PERMITTED_CONTRACT = auto() + NO_TRACE = auto() OUT_OF_RANGE_ADDRESS = auto() OUT_OF_RANGE_BLOCK_HASH = auto() OUT_OF_RANGE_BLOCK_ID = auto() OUT_OF_RANGE_CALLER_ADDRESS = auto() OUT_OF_RANGE_CONTRACT_ADDRESS = auto() OUT_OF_RANGE_CONTRACT_HASH = auto() - OUT_OF_RANGE_CONTRACT_ADDRESS_SALT = auto() OUT_OF_RANGE_CONTRACT_STORAGE_KEY = auto() OUT_OF_RANGE_ENTRY_POINT_OFFSET = auto() OUT_OF_RANGE_ENTRY_POINT_SELECTOR = auto() OUT_OF_RANGE_FEE = auto() + OUT_OF_RANGE_NONCE = auto() OUT_OF_RANGE_SEQUENCER_ADDRESS = auto() OUT_OF_RANGE_TRANSACTION_HASH = auto() OUT_OF_RANGE_TRANSACTION_ID = auto() + OUT_OF_RANGE_TRANSACTION_VERSION = auto() OUT_OF_RESOURCES = auto() SECURITY_ERROR = auto() TRANSACTION_FAILED = auto() @@ -41,6 +47,7 @@ class StarknetErrorCode(ErrorCode): TRANSACTION_NOT_FOUND = auto() UNEXPECTED_FAILURE = auto() UNINITIALIZED_CONTRACT = auto() + UNSUPPORTED_SELECTOR_FOR_FEE = auto() # Errors that are raised by the gateways and caused by wrong usage of the user. @@ -52,6 +59,9 @@ class StarknetErrorCode(ErrorCode): StarkErrorCode.SCHEMA_VALIDATION_ERROR, StarknetErrorCode.OUT_OF_RANGE_ENTRY_POINT_OFFSET, StarknetErrorCode.OUT_OF_RANGE_ENTRY_POINT_SELECTOR, + # External-to-internal conversion errors. + StarknetErrorCode.INVALID_TRANSACTION_VERSION, + StarknetErrorCode.UNSUPPORTED_SELECTOR_FOR_FEE, ] main_gateway_error_code_whitelist: FrozenSet[ErrorCode] = frozenset( @@ -66,7 +76,6 @@ class StarknetErrorCode(ErrorCode): StarknetErrorCode.INVALID_PROGRAM, StarknetErrorCode.MULTIPLE_ENTRY_POINTS_MATCH_SELECTOR, StarknetErrorCode.NON_PERMITTED_CONTRACT, - StarknetErrorCode.OUT_OF_RANGE_CONTRACT_ADDRESS_SALT, # Reaching traffic limits. StarknetErrorCode.TRANSACTION_LIMIT_EXCEEDED, ] @@ -77,13 +86,18 @@ class StarknetErrorCode(ErrorCode): *external_txs_loading_common_error_codes, # Requests that fail after quering the DB. StarknetErrorCode.BLOCK_NOT_FOUND, + StarknetErrorCode.INVALID_TRANSACTION_HASH, + StarknetErrorCode.NO_TRACE, StarknetErrorCode.TRANSACTION_NOT_FOUND, StarknetErrorCode.UNINITIALIZED_CONTRACT, # Function call errors. StarknetErrorCode.ENTRY_POINT_NOT_FOUND_IN_CONTRACT, + StarknetErrorCode.FEE_TRANSFER_FAILURE, StarknetErrorCode.INVALID_RETURN_DATA, + StarknetErrorCode.OUT_OF_RESOURCES, StarknetErrorCode.SECURITY_ERROR, StarknetErrorCode.TRANSACTION_FAILED, + StarknetErrorCode.UNEXPECTED_FAILURE, # Request parsing errors. StarkErrorCode.MALFORMED_REQUEST, StarknetErrorCode.INVALID_STATUS_MODE, diff --git a/src/starkware/starknet/definitions/fields.py b/src/starkware/starknet/definitions/fields.py index 785554ea..5db60259 100644 --- a/src/starkware/starknet/definitions/fields.py +++ b/src/starkware/starknet/definitions/fields.py @@ -27,20 +27,117 @@ # Fields data: validation data, dataclass metadata. + +# Common. + +felt_as_hex_list_metadata = dict( + marshmallow_field=mfields.List( + everest_fields.FeltField.get_marshmallow_field( + required=True, load_default=marshmallow.utils.missing + ) + ) +) + +felt_list_metadata = dict( + marshmallow_field=mfields.List(IntAsStr(validate=everest_fields.FeltField.validate)) +) + + +def bytes_as_hex_dict_keys_metadata( + values_schema: Type[marshmallow.Schema], +) -> Dict[str, mfields.Dict]: + width_validator = validate_length( + field_name="contract hash", length=constants.CONTRACT_HASH_BYTES + ) + return dict( + marshmallow_field=mfields.Dict( + keys=BytesAsHex(required=True, validate=width_validator), + values=mfields.Nested(values_schema), + ) + ) + + +timestamp_metadata = dict( + marshmallow_field=StrictRequiredInteger(validate=validate_non_negative("timestamp")) +) + + +# Address. + +AddressField = RangeValidatedField( + lower_bound=constants.ADDRESS_LOWER_BOUND, + upper_bound=constants.ADDRESS_UPPER_BOUND, + name="Address", + error_code=StarknetErrorCode.OUT_OF_RANGE_ADDRESS, + formatter=hex, +) + + +def address_metadata(name: str, error_code: StarknetErrorCode) -> Dict[str, Any]: + return dataclasses.replace(AddressField, name=name, error_code=error_code).metadata() + + +sequencer_address_metadata = address_metadata( + name="Sequencer address", error_code=StarknetErrorCode.OUT_OF_RANGE_SEQUENCER_ADDRESS +) + +caller_address_metadata = address_metadata( + name="Caller address", error_code=StarknetErrorCode.OUT_OF_RANGE_CALLER_ADDRESS +) + +fee_token_address_metadata = address_metadata( + name="Fee token address", error_code=StarknetErrorCode.OUT_OF_RANGE_CONTRACT_ADDRESS +) + + +# Nonce. + +NonceField = RangeValidatedField( + lower_bound=constants.NONCE_LOWER_BOUND, + upper_bound=constants.NONCE_UPPER_BOUND, + name="Nonce", + error_code=StarknetErrorCode.OUT_OF_RANGE_NONCE, + formatter=hex, +) +nonce_metadata = NonceField.metadata() + +OptionalNonceField = OptionalField(field=NonceField, none_probability=0) +optional_nonce_metadata = OptionalNonceField.metadata() + + +# Block. + block_number_metadata = sequential_id_metadata(field_name="Block number", allow_previous_id=True) default_optional_block_number_metadata = sequential_id_metadata( field_name="Block number", required=False, load_default=None ) + +BlockHashField = RangeValidatedField( + lower_bound=0, + upper_bound=constants.BLOCK_HASH_UPPER_BOUND, + name="Block hash", + error_code=StarknetErrorCode.OUT_OF_RANGE_BLOCK_HASH, + formatter=hex, +) +block_hash_metadata = BlockHashField.metadata() + +OptionalBlockHashField = OptionalField(field=BlockHashField, none_probability=0) +optional_block_hash_metadata = OptionalBlockHashField.metadata() + default_optional_transaction_index_metadata = sequential_id_metadata( field_name="Transaction index", required=False, load_default=None ) -felt_list_metadata = dict( - marshmallow_field=mfields.List(IntAsStr(validate=everest_fields.FeltField.validate)) -) + +# InvokeFunction. call_data_metadata = felt_list_metadata +call_data_as_hex_metadata = felt_as_hex_list_metadata signature_metadata = felt_list_metadata +retdata_as_hex_metadata = felt_as_hex_list_metadata + + +# Contract address. ContractAddressField = RangeValidatedField( lower_bound=constants.CONTRACT_ADDRESS_LOWER_BOUND, @@ -49,38 +146,18 @@ error_code=StarknetErrorCode.OUT_OF_RANGE_CONTRACT_ADDRESS, formatter=hex, ) - contract_address_metadata = ContractAddressField.metadata() -OptionalContractAddressField = OptionalField(field=ContractAddressField, none_probability=0) -optional_contract_address_metadata = OptionalContractAddressField.metadata() - -ContractAddressSalt = RangeValidatedField( - lower_bound=constants.CONTRACT_ADDRESS_SALT_LOWER_BOUND, - upper_bound=constants.CONTRACT_ADDRESS_SALT_UPPER_BOUND, - name="Contract salt", - error_code=StarknetErrorCode.OUT_OF_RANGE_CONTRACT_ADDRESS_SALT, - formatter=hex, +OptionalCodeAddressField = OptionalField( + field=dataclasses.replace(ContractAddressField, name="Code address"), none_probability=0 ) +optional_code_address_metadata = OptionalCodeAddressField.metadata() +ContractAddressSalt = everest_fields.felt(name_in_error_message="Contract salt") contract_address_salt_metadata = ContractAddressSalt.metadata() -def bytes_as_hex_dict_keys_metadata( - values_schema: Type[marshmallow.Schema], -) -> Dict[str, mfields.Dict]: - width_validator = validate_length( - field_name="contract hash", length=constants.CONTRACT_HASH_BYTES - ) - return dict( - marshmallow_field=mfields.Dict( - keys=BytesAsHex(required=True, validate=width_validator), - values=mfields.Nested(values_schema), - ) - ) - - -contract_definitions_metadata = dict(marshmallow_field=mfields.Dict(keys=BytesAsHex)) +# Contract hash. def validate_contract_hash(contract_hash: bytes): @@ -98,11 +175,8 @@ def validate_contract_hash(contract_hash: bytes): marshmallow_field=BytesAsHex(required=False, validate=validate_contract_hash), ) -contract_storage_commitment_tree_height_metadata = dict( - marshmallow_field=StrictRequiredInteger( - validate=validate_positive("contract_storage_commitment_tree_height") - ) -) + +# Entry point. EntryPointSelectorField = RangeValidatedField( lower_bound=constants.ENTRY_POINT_SELECTOR_LOWER_BOUND, @@ -111,9 +185,11 @@ def validate_contract_hash(contract_hash: bytes): error_code=StarknetErrorCode.OUT_OF_RANGE_ENTRY_POINT_SELECTOR, formatter=hex, ) - entry_point_selector_metadata = EntryPointSelectorField.metadata() +OptionalEntryPointSelectorField = OptionalField(field=EntryPointSelectorField, none_probability=0) +optional_entry_point_selector_metadata = OptionalEntryPointSelectorField.metadata() + EntryPointOffsetField = RangeValidatedField( lower_bound=constants.ENTRY_POINT_OFFSET_LOWER_BOUND, upper_bound=constants.ENTRY_POINT_OFFSET_UPPER_BOUND, @@ -121,19 +197,42 @@ def validate_contract_hash(contract_hash: bytes): error_code=StarknetErrorCode.OUT_OF_RANGE_ENTRY_POINT_OFFSET, formatter=hex, ) - entry_point_offset_metadata = EntryPointOffsetField.metadata() -global_state_commitment_tree_height_metadata = dict( - marshmallow_field=StrictRequiredInteger( - validate=validate_non_negative("global_state_commitment_tree_height"), - ) + +# Fee. + +FeeField = RangeValidatedField( + lower_bound=constants.FEE_LOWER_BOUND, + upper_bound=constants.FEE_UPPER_BOUND, + name="Fee", + error_code=StarknetErrorCode.OUT_OF_RANGE_FEE, + formatter=hex, ) +fee_metadata = FeeField.metadata(required=False, load_default=0) +# Transaction version. + +TransactionVersionField = RangeValidatedField( + lower_bound=constants.TRANSACTION_VERSION_LOWER_BOUND, + upper_bound=constants.TRANSACTION_VERSION_UPPER_BOUND, + name="Transaction version", + error_code=StarknetErrorCode.OUT_OF_RANGE_TRANSACTION_VERSION, + formatter=hex, +) +tx_version_metadata = TransactionVersionField.metadata( + required=False, load_default=constants.TRANSACTION_VERSION +) + + +# State root. state_root_metadata = dict(marshmallow_field=BytesAsHex(required=True)) optional_state_root_metadata = dict(marshmallow_field=BytesAsHex(required=False, allow_none=True)) + +# Transaction hash. + TransactionHashField = RangeValidatedField( lower_bound=constants.TRANSACTION_HASH_LOWER_BOUND, upper_bound=constants.TRANSACTION_HASH_UPPER_BOUND, @@ -141,56 +240,30 @@ def validate_contract_hash(contract_hash: bytes): error_code=StarknetErrorCode.OUT_OF_RANGE_TRANSACTION_HASH, formatter=hex, ) - transaction_hash_metadata = TransactionHashField.metadata() OptionalTransactionHashField = OptionalField(field=TransactionHashField, none_probability=0) - optional_transaction_hash_metadata = OptionalTransactionHashField.metadata() -BlockHashField = RangeValidatedField( - lower_bound=0, - upper_bound=constants.BLOCK_HASH_UPPER_BOUND, - name="Block hash", - error_code=StarknetErrorCode.OUT_OF_RANGE_BLOCK_HASH, - formatter=hex, -) - -block_hash_metadata = BlockHashField.metadata() -OptionalBlockHashField = OptionalField(field=BlockHashField, none_probability=0) +# General config. -optional_block_hash_metadata = OptionalBlockHashField.metadata() +contract_storage_commitment_tree_height_metadata = dict( + marshmallow_field=StrictRequiredInteger( + validate=validate_positive("contract_storage_commitment_tree_height") + ) +) -timestamp_metadata = dict( - marshmallow_field=StrictRequiredInteger(validate=validate_non_negative("timestamp")) +global_state_commitment_tree_height_metadata = dict( + marshmallow_field=StrictRequiredInteger( + validate=validate_non_negative("global_state_commitment_tree_height"), + ) ) invoke_tx_n_steps_metadata = dict( marshmallow_field=StrictRequiredInteger(validate=validate_non_negative("invoke_tx_n_steps")) ) - -AddressField = RangeValidatedField( - lower_bound=constants.ADDRESS_LOWER_BOUND, - upper_bound=constants.ADDRESS_UPPER_BOUND, - name="Address", - error_code=StarknetErrorCode.OUT_OF_RANGE_ADDRESS, - formatter=hex, -) - - -def address_metadata(name: str, error_code: StarknetErrorCode) -> Dict[str, Any]: - return dataclasses.replace(AddressField, name=name, error_code=error_code).metadata() - - -sequencer_address_metadata = address_metadata( - name="Sequencer address", error_code=StarknetErrorCode.OUT_OF_RANGE_SEQUENCER_ADDRESS -) - -caller_address_metadata = address_metadata( - name="Caller address", error_code=StarknetErrorCode.OUT_OF_RANGE_CALLER_ADDRESS +gas_price = dict( + marshmallow_field=StrictRequiredInteger(validate=validate_non_negative("gas_price")) ) - -OptionalNonceField = OptionalField(field=everest_fields.FeltField, none_probability=0) -optional_nonce_metadata = OptionalNonceField.metadata() diff --git a/src/starkware/starknet/definitions/general_config.py b/src/starkware/starknet/definitions/general_config.py index d9fe8663..545bc6ec 100644 --- a/src/starkware/starknet/definitions/general_config.py +++ b/src/starkware/starknet/definitions/general_config.py @@ -1,17 +1,31 @@ +import copy +import os from dataclasses import field from enum import Enum -from typing import Dict +from typing import Any, Dict import marshmallow.fields as mfields import marshmallow_dataclass +from services.everest.definitions.general_config import EverestGeneralConfig +from starkware.cairo.lang.instances import all_instance from starkware.python.utils import from_bytes from starkware.starknet.definitions import constants, fields -from starkware.starkware_utils.config_base import Config +from starkware.starkware_utils.config_base import Config, load_config from starkware.starkware_utils.field_validators import validate_dict, validate_non_negative -from starkware.starkware_utils.marshmallow_dataclass_fields import StrictRequiredInteger +from starkware.starkware_utils.marshmallow_dataclass_fields import ( + StrictRequiredInteger, + load_int_value, +) -DOCKER_GENERAL_CONFIG_PATH = "/general_config.yml" +GENERAL_CONFIG_FILE_NAME = "general_config.yml" +DOCKER_GENERAL_CONFIG_PATH = os.path.join("/", GENERAL_CONFIG_FILE_NAME) +GENERAL_CONFIG_PATH = os.path.join(os.path.dirname(__file__), GENERAL_CONFIG_FILE_NAME) + +# Reference to the default general config. +default_general_config = load_config( + config_file_path=GENERAL_CONFIG_PATH, load_logging_config=False +) class StarknetChainId(Enum): @@ -19,18 +33,23 @@ class StarknetChainId(Enum): TESTNET = from_bytes(b"SN_GOERLI") -# Default configuration values. +# Fee token account constants. +TOKEN_NAME = from_bytes(b"Wrapped Ether") +TOKEN_SYMBOL = from_bytes(b"WETH") +TOKEN_DECIMALS = 18 -# Note: tokens sent to this default address will be burned. -DEFAULT_SEQUENCER_ADDRESS = 0 + +# Default configuration values. # In order to be able to use Keccak builtin, which uses bitwise, which is sparse. DEFAULT_MAX_STEPS = 10 ** 6 DEFAULT_CHAIN_ID = StarknetChainId.TESTNET +# Given in units of wei. +DEFAULT_GAS_PRICE = 100 * 10 ** 9 + class CairoResource(Enum): N_STEPS = "n_steps" - GAS_WEIGHT = "gas_weight" PEDERSEN_BUILTIN = "pedersen_builtin" RANGE_CHECK_BUILTIN = "range_check_builtin" ECDSA_BUILTIN = "ecdsa_builtin" @@ -39,13 +58,14 @@ class CairoResource(Enum): EC_OP_BUILTIN = "ec_op_builtin" -DEFAULT_CAIRO_USAGE_RESOURCE_FEE_WEIGHTS = { - CairoResource.N_STEPS.value: 0.0, - CairoResource.GAS_WEIGHT.value: 0.0, +DEFAULT_CAIRO_RESOURCE_FEE_WEIGHTS = { + CairoResource.N_STEPS.value: 1.0, CairoResource.PEDERSEN_BUILTIN.value: 0.0, CairoResource.RANGE_CHECK_BUILTIN.value: 0.0, CairoResource.ECDSA_BUILTIN.value: 0.0, CairoResource.BITWISE_BUILTIN.value: 0.0, + CairoResource.OUTPUT_BUILTIN.value: 0.0, + CairoResource.EC_OP_BUILTIN.value: 0.0, } @@ -53,9 +73,24 @@ class CairoResource(Enum): @marshmallow_dataclass.dataclass(frozen=True) -class StarknetGeneralConfig(Config): +class StarknetOsConfig(Config): chain_id: StarknetChainId = field(default=DEFAULT_CHAIN_ID) + fee_token_address: int = field( + metadata=dict( + **fields.fee_token_address_metadata, description="StarkNet fee token L2 address." + ), + default=load_int_value( + field_metadata=fields.fee_token_address_metadata, + value=default_general_config["starknet_os_config"]["fee_token_address"], + ), + ) + + +@marshmallow_dataclass.dataclass(frozen=True) +class StarknetGeneralConfig(EverestGeneralConfig): + starknet_os_config: StarknetOsConfig = field(default_factory=StarknetOsConfig) + contract_storage_commitment_tree_height: int = field( metadata=fields.contract_storage_commitment_tree_height_metadata, default=constants.CONTRACT_STATES_COMMITMENT_TREE_HEIGHT, @@ -70,11 +105,16 @@ class StarknetGeneralConfig(Config): metadata=fields.invoke_tx_n_steps_metadata, default=DEFAULT_MAX_STEPS ) + gas_price: int = field(metadata=fields.gas_price, default=DEFAULT_GAS_PRICE) + sequencer_address: int = field( metadata=dict( **fields.sequencer_address_metadata, description="StarkNet sequencer address." ), - default=DEFAULT_SEQUENCER_ADDRESS, + default=load_int_value( + field_metadata=fields.fee_token_address_metadata, + value=default_general_config["sequencer_address"], + ), ) tx_commitment_tree_height: int = field( @@ -87,6 +127,19 @@ class StarknetGeneralConfig(Config): default=constants.TRANSACTION_COMMITMENT_TREE_HEIGHT, ) + tx_version: int = field( + metadata=dict( + marshmallow_field=StrictRequiredInteger( + validate=validate_non_negative("Trasaction version."), + ), + description=( + "Current transaction version - " + "in order to identify transactions from unsupported versions." + ), + ), + default=constants.TRANSACTION_VERSION, + ) + event_commitment_tree_height: int = field( metadata=dict( marshmallow_field=StrictRequiredInteger( @@ -97,19 +150,61 @@ class StarknetGeneralConfig(Config): default=constants.EVENT_COMMITMENT_TREE_HEIGHT, ) - cairo_usage_resource_fee_weights: Dict[str, float] = field( + cairo_resource_fee_weights: Dict[str, float] = field( metadata=dict( marshmallow_field=mfields.Dict( keys=mfields.String, values=mfields.Float, validate=validate_dict( - "Cairo usage resource fee weights", value_validator=validate_non_negative + "Cairo resource fee weights", value_validator=validate_non_negative ), ), description=( - "A mapping from a Cairo usage resource to its coefficient in this transaction " + "A mapping from a Cairo resource to its coefficient in this transaction " "fee calculation." ), ), - default_factory=lambda: DEFAULT_CAIRO_USAGE_RESOURCE_FEE_WEIGHTS.copy(), + default_factory=lambda: DEFAULT_CAIRO_RESOURCE_FEE_WEIGHTS.copy(), ) + + @property + def chain_id(self) -> StarknetChainId: + return self.starknet_os_config.chain_id + + @property + def fee_token_address(self) -> int: + return self.starknet_os_config.fee_token_address + + +def build_general_config(raw_general_config: Dict[str, Any]) -> StarknetGeneralConfig: + """ + Updates the fee weights and builds the general config. + """ + raw_general_config = copy.deepcopy(raw_general_config) + cairo_resource_fee_weights: Dict[str, float] = raw_general_config["cairo_resource_fee_weights"] + n_steps_key = CairoResource.N_STEPS.value + assert cairo_resource_fee_weights.keys() == { + n_steps_key + }, f"Only {n_steps_key} weight should be given." + + n_steps_weight = cairo_resource_fee_weights[n_steps_key] + + # Zero all entries. + cairo_resource_fee_weights.update({resource.value: 0.0 for resource in CairoResource}) + # Update relevant entries. + cairo_resource_fee_weights.update( + { + n_steps_key: n_steps_weight, + # All other weights are deduced from n_steps. + CairoResource.PEDERSEN_BUILTIN.value: n_steps_weight + * all_instance.builtins["pedersen"].ratio, + CairoResource.RANGE_CHECK_BUILTIN.value: n_steps_weight + * all_instance.builtins["range_check"].ratio, + CairoResource.ECDSA_BUILTIN.value: n_steps_weight + * all_instance.builtins["ecdsa"].ratio, + CairoResource.BITWISE_BUILTIN.value: n_steps_weight + * all_instance.builtins["bitwise"].ratio, + } + ) + + return StarknetGeneralConfig.load(data=raw_general_config) diff --git a/src/starkware/starknet/definitions/general_config.yml b/src/starkware/starknet/definitions/general_config.yml index d1a12221..7947ddcc 100644 --- a/src/starkware/starknet/definitions/general_config.yml +++ b/src/starkware/starknet/definitions/general_config.yml @@ -1,14 +1,12 @@ -sequencer_address: '0x0' +cairo_resource_fee_weights: + n_steps: 1.0 contract_storage_commitment_tree_height: 251 +event_commitment_tree_height: 64 +gas_price: 100000000000 global_state_commitment_tree_height: 251 invoke_tx_max_n_steps: 1000000 -cairo_usage_resource_fee_weights: - bitwise_builtin: 0.0 - ecdsa_builtin: 0.0 - gas_weight: 0.0 - n_steps: 0.0 - pedersen_builtin: 0.0 - range_check_builtin: 0.0 +sequencer_address: '0x37b2cd6baaa515f520383bee7b7094f892f4c770695fc329a8973e841a971ae' +starknet_os_config: + fee_token_address: '0x20abcf49dad3e9813d65bf1b8d54c5a0c9e6049a3027bd8c2ab315475c0a5c1' tx_commitment_tree_height: 64 -event_commitment_tree_height: 64 - +tx_version: 0 diff --git a/src/starkware/starknet/eth/CMakeLists.txt b/src/starkware/starknet/eth/CMakeLists.txt index a134deae..08500473 100644 --- a/src/starkware/starknet/eth/CMakeLists.txt +++ b/src/starkware/starknet/eth/CMakeLists.txt @@ -2,6 +2,7 @@ python_lib(starknet_messaging_sol PREFIX starkware/starknet/eth FILES IStarknetMessaging.sol + IStarknetMessagingEvents.sol StarknetMessaging.sol LIBS diff --git a/src/starkware/starknet/eth/IStarknetMessaging.sol b/src/starkware/starknet/eth/IStarknetMessaging.sol index 5ff06e20..2f9de4bd 100644 --- a/src/starkware/starknet/eth/IStarknetMessaging.sol +++ b/src/starkware/starknet/eth/IStarknetMessaging.sol @@ -1,46 +1,16 @@ // SPDX-License-Identifier: Apache-2.0. pragma solidity ^0.6.12; -interface IStarknetMessaging { - // This event needs to be compatible with the one defined in Output.sol. - event LogMessageToL1( - uint256 indexed from_address, - address indexed to_address, - uint256[] payload - ); - - // An event that is raised when a message is sent from L1 to L2. - event LogMessageToL2( - address indexed from_address, - uint256 indexed to_address, - uint256 indexed selector, - uint256[] payload, - uint256 nonce - ); - - // An event that is raised when a message from L2 to L1 is consumed. - event ConsumedMessageToL1( - uint256 indexed from_address, - address indexed to_address, - uint256[] payload - ); - - // An event that is raised when a message from L1 to L2 is consumed. - event ConsumedMessageToL2( - address indexed from_address, - uint256 indexed to_address, - uint256 indexed selector, - uint256[] payload, - uint256 nonce - ); +import "./IStarknetMessagingEvents.sol"; +interface IStarknetMessaging is IStarknetMessagingEvents { /** Sends a message to an L2 contract. Returns the hash of the message. */ function sendMessageToL2( - uint256 to_address, + uint256 toAddress, uint256 selector, uint256[] calldata payload ) external returns (bytes32); @@ -53,4 +23,29 @@ interface IStarknetMessaging { function consumeMessageFromL2(uint256 fromAddress, uint256[] calldata payload) external returns (bytes32); + + /** + Starts the cancellation of an L1 to L2 message. + A message can be canceled messageCancellationDelay() seconds after this function is called. + + Note: This function may only be called for a message that is currently pending and the caller + must be the sender of the that message. + */ + function startL1ToL2MessageCancellation( + uint256 toAddress, + uint256 selector, + uint256[] calldata payload, + uint256 nonce + ) external; + + /** + Cancels an L1 to L2 message, this function should be called messageCancellationDelay() seconds + after the call to startL1ToL2MessageCancellation(). + */ + function cancelL1ToL2Message( + uint256 toAddress, + uint256 selector, + uint256[] calldata payload, + uint256 nonce + ) external; } diff --git a/src/starkware/starknet/eth/IStarknetMessagingEvents.sol b/src/starkware/starknet/eth/IStarknetMessagingEvents.sol new file mode 100644 index 00000000..2648460c --- /dev/null +++ b/src/starkware/starknet/eth/IStarknetMessagingEvents.sol @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: Apache-2.0. +pragma solidity ^0.6.12; + +interface IStarknetMessagingEvents { + // This event needs to be compatible with the one defined in Output.sol. + event LogMessageToL1(uint256 indexed fromAddress, address indexed toAddress, uint256[] payload); + + // An event that is raised when a message is sent from L1 to L2. + event LogMessageToL2( + address indexed fromAddress, + uint256 indexed toAddress, + uint256 indexed selector, + uint256[] payload, + uint256 nonce + ); + + // An event that is raised when a message from L2 to L1 is consumed. + event ConsumedMessageToL1( + uint256 indexed fromAddress, + address indexed toAddress, + uint256[] payload + ); + + // An event that is raised when a message from L1 to L2 is consumed. + event ConsumedMessageToL2( + address indexed fromAddress, + uint256 indexed toAddress, + uint256 indexed selector, + uint256[] payload, + uint256 nonce + ); + + // An event that is raised when a message from L1 to L2 Cancellation is started. + event MessageToL2CancellationStarted( + address indexed fromAddress, + uint256 indexed toAddress, + uint256 indexed selector, + uint256[] payload, + uint256 nonce + ); + + // An event that is raised when a message from L1 to L2 is canceled. + event MessageToL2Canceled( + address indexed fromAddress, + uint256 indexed toAddress, + uint256 indexed selector, + uint256[] payload, + uint256 nonce + ); +} diff --git a/src/starkware/starknet/eth/StarknetMessaging.sol b/src/starkware/starknet/eth/StarknetMessaging.sol index 4472e529..35974720 100644 --- a/src/starkware/starknet/eth/StarknetMessaging.sol +++ b/src/starkware/starknet/eth/StarknetMessaging.sol @@ -18,6 +18,14 @@ contract StarknetMessaging is IStarknetMessaging { string constant L1L2_MESSAGE_NONCE_TAG = "STARKNET_1.0_MSGING_L1TOL2_NONCE"; + string constant L1L2_MESSAGE_CANCELLATION_MAP_TAG = ( + "STARKNET_1.0_MSGING_L1TOL2_CANCELLATION_MAPPPING" + ); + + string constant L1L2_MESSAGE_CANCELLATION_DELAY_TAG = ( + "STARKNET_1.0_MSGING_L1TOL2_CANCELLATION_DELAY" + ); + function l1ToL2Messages(bytes32 msgHash) external view returns (uint256) { return l1ToL2Messages()[msgHash]; } @@ -38,27 +46,66 @@ contract StarknetMessaging is IStarknetMessaging { return NamedStorage.getUintValue(L1L2_MESSAGE_NONCE_TAG); } + function messageCancellationDelay() public view returns (uint256) { + return NamedStorage.getUintValue(L1L2_MESSAGE_CANCELLATION_DELAY_TAG); + } + + function messageCancellationDelay(uint256 delayInSeconds) internal { + NamedStorage.setUintValue(L1L2_MESSAGE_CANCELLATION_DELAY_TAG, delayInSeconds); + } + + /** + Returns the timestamp at the time cancelL1ToL2Message was called with a message + matching 'msgHash'. + + The function returns 0 if cancelL1ToL2Message was never called. + */ + function l1ToL2MessageCancellations(bytes32 msgHash) external view returns (uint256) { + return l1ToL2MessageCancellations()[msgHash]; + } + + function l1ToL2MessageCancellations() + internal + pure + returns (mapping(bytes32 => uint256) storage) + { + return NamedStorage.bytes32ToUint256Mapping(L1L2_MESSAGE_CANCELLATION_MAP_TAG); + } + + /** + Returns the hash of an L1 -> L2 message from msg.sender. + */ + function getL1ToL2MsgHash( + uint256 toAddress, + uint256 selector, + uint256[] calldata payload, + uint256 nonce + ) internal returns (bytes32) { + return + keccak256( + abi.encodePacked( + uint256(msg.sender), + toAddress, + nonce, + selector, + payload.length, + payload + ) + ); + } + /** Sends a message to an L2 contract. */ function sendMessageToL2( - uint256 to_address, + uint256 toAddress, uint256 selector, uint256[] calldata payload ) external override returns (bytes32) { uint256 nonce = l1ToL2MessageNonce(); NamedStorage.setUintValue(L1L2_MESSAGE_NONCE_TAG, nonce + 1); - emit LogMessageToL2(msg.sender, to_address, selector, payload, nonce); - bytes32 msgHash = keccak256( - abi.encodePacked( - uint256(msg.sender), - to_address, - nonce, - selector, - payload.length, - payload - ) - ); + emit LogMessageToL2(msg.sender, toAddress, selector, payload, nonce); + bytes32 msgHash = getL1ToL2MsgHash(toAddress, selector, payload, nonce); l1ToL2Messages()[msgHash] += 1; return msgHash; @@ -69,18 +116,52 @@ contract StarknetMessaging is IStarknetMessaging { Returns the hash of the message. */ - function consumeMessageFromL2(uint256 from_address, uint256[] calldata payload) + function consumeMessageFromL2(uint256 fromAddress, uint256[] calldata payload) external override returns (bytes32) { bytes32 msgHash = keccak256( - abi.encodePacked(from_address, uint256(msg.sender), payload.length, payload) + abi.encodePacked(fromAddress, uint256(msg.sender), payload.length, payload) ); require(l2ToL1Messages()[msgHash] > 0, "INVALID_MESSAGE_TO_CONSUME"); - emit ConsumedMessageToL1(from_address, msg.sender, payload); + emit ConsumedMessageToL1(fromAddress, msg.sender, payload); l2ToL1Messages()[msgHash] -= 1; return msgHash; } + + function startL1ToL2MessageCancellation( + uint256 toAddress, + uint256 selector, + uint256[] calldata payload, + uint256 nonce + ) external override { + emit MessageToL2CancellationStarted(msg.sender, toAddress, selector, payload, nonce); + bytes32 msgHash = getL1ToL2MsgHash(toAddress, selector, payload, nonce); + uint256 msgCount = l1ToL2Messages()[msgHash]; + require(msgCount > 0, "NO_MESSAGE_TO_CANCEL"); + l1ToL2MessageCancellations()[msgHash] = block.timestamp; + } + + function cancelL1ToL2Message( + uint256 toAddress, + uint256 selector, + uint256[] calldata payload, + uint256 nonce + ) external override { + emit MessageToL2Canceled(msg.sender, toAddress, selector, payload, nonce); + bytes32 msgHash = getL1ToL2MsgHash(toAddress, selector, payload, nonce); + uint256 msgCount = l1ToL2Messages()[msgHash]; + require(msgCount > 0, "NO_MESSAGE_TO_CANCEL"); + + uint256 requestTime = l1ToL2MessageCancellations()[msgHash]; + require(requestTime != 0, "MESSAGE_CANCELLATION_NOT_REQUESTED"); + + uint256 cancelAllowedTime = requestTime + messageCancellationDelay(); + require(cancelAllowedTime >= requestTime, "CANCEL_ALLOWED_TIME_OVERFLOW"); + require(block.timestamp >= cancelAllowedTime, "MESSAGE_CANCELLATION_NOT_ALLOWED_YET"); + + l1ToL2Messages()[msgHash] = msgCount - 1; + } } diff --git a/src/starkware/starknet/public/abi.py b/src/starkware/starknet/public/abi.py index 5c5f93e0..608a6333 100644 --- a/src/starkware/starknet/public/abi.py +++ b/src/starkware/starknet/public/abi.py @@ -1,3 +1,5 @@ +from typing import Any, Dict, List + from eth_hash.auto import keccak from starkware.cairo.lang.vm.crypto import pedersen_hash @@ -16,6 +18,10 @@ DEFAULT_ENTRY_POINT_NAME = "__default__" DEFAULT_L1_ENTRY_POINT_NAME = "__l1_default__" DEFAULT_ENTRY_POINT_SELECTOR = 0 +EXECUTE_ENTRY_POINT_NAME = "__execute__" +TRANSFER_ENTRY_POINT_NAME = "transfer" + +AbiType = List[Dict[str, Any]] def starknet_keccak(data: bytes) -> int: @@ -33,6 +39,10 @@ def get_selector_from_name(func_name: str) -> int: return starknet_keccak(data=func_name.encode("ascii")) +EXECUTE_ENTRY_POINT_SELECTOR = get_selector_from_name(func_name=EXECUTE_ENTRY_POINT_NAME) +TRANSFER_ENTRY_POINT_SELECTOR = get_selector_from_name(func_name=TRANSFER_ENTRY_POINT_NAME) + + def get_storage_var_address(var_name: str, *args) -> int: """ Returns the storage address of a StarkNet storage variable given its name and arguments. diff --git a/src/starkware/starknet/public/abi_structs.py b/src/starkware/starknet/public/abi_structs.py index a186edf7..e0263f8e 100644 --- a/src/starkware/starknet/public/abi_structs.py +++ b/src/starkware/starknet/public/abi_structs.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Any, List, Set, Tuple +from typing import Set, Tuple from starkware.cairo.lang.compiler.ast.cairo_types import ( CairoType, @@ -13,6 +13,7 @@ from starkware.cairo.lang.compiler.parser import parse_type from starkware.cairo.lang.compiler.scoped_name import ScopedName from starkware.cairo.lang.compiler.type_system import mark_type_resolved +from starkware.starknet.public.abi import AbiType @dataclasses.dataclass @@ -31,9 +32,9 @@ def prepare_type_for_abi(cairo_type: CairoType) -> AbiTypeInfo: new_members = [] structs = set() for inner_type in cairo_type.members: - res = prepare_type_for_abi(inner_type) + res = prepare_type_for_abi(inner_type.typ) structs |= res.structs - new_members.append(res.modified_type) + new_members.append(dataclasses.replace(inner_type, typ=res.modified_type)) return AbiTypeInfo( modified_type=dataclasses.replace(cairo_type, members=new_members), @@ -110,7 +111,7 @@ def struct_definition_from_abi_entry(abi_entry: dict) -> StructDefinition: ) -def identifier_manager_from_abi(abi: List[Any]) -> IdentifierManager: +def identifier_manager_from_abi(abi: AbiType) -> IdentifierManager: """ Returns an IdentifierManager object which contains all struct definitions found in the ABI. """ diff --git a/src/starkware/starknet/security/CMakeLists.txt b/src/starkware/starknet/security/CMakeLists.txt index 84e9eb0e..aa6321e7 100644 --- a/src/starkware/starknet/security/CMakeLists.txt +++ b/src/starkware/starknet/security/CMakeLists.txt @@ -6,7 +6,6 @@ python_lib(starknet_security_lib LIBS cairo_compile_lib - cairo_run_lib starkware_dataclasses_utils_lib pip_marshmallow pip_marshmallow_dataclass @@ -47,6 +46,7 @@ python_lib(starknet_hints_whitelist_lib whitelists/cairo_keccak.json whitelists/cairo_secp.json whitelists/cairo_sha256.json + whitelists/ec_recover.json whitelists/latest.json LIBS diff --git a/src/starkware/starknet/security/secure_hints.py b/src/starkware/starknet/security/secure_hints.py index 9a873d07..59353db7 100644 --- a/src/starkware/starknet/security/secure_hints.py +++ b/src/starkware/starknet/security/secure_hints.py @@ -22,6 +22,7 @@ def _serialize(self, value, attr, obj, **kwargs): if value is None: return None res = super()._serialize(value, attr, obj, **kwargs) + assert res is not None return sorted(res, key=lambda x: (x["name"], x["expr"])) def _deserialize(self, *args, **kwargs): diff --git a/src/starkware/starknet/security/simple_references_test.py b/src/starkware/starknet/security/simple_references_test.py index 75b7cc69..fb19423f 100644 --- a/src/starkware/starknet/security/simple_references_test.py +++ b/src/starkware/starknet/security/simple_references_test.py @@ -21,6 +21,7 @@ ["&ap", None], ["3 ** 2", None], ["3 / 2", None], + ["nondet %{ 5 %}", None], ], ) def test_is_simple_reference(expr_str: str, simplicity: Optional[int]): diff --git a/src/starkware/starknet/security/starknet_common.cairo b/src/starkware/starknet/security/starknet_common.cairo index 4c5753a9..d09f73b8 100644 --- a/src/starkware/starknet/security/starknet_common.cairo +++ b/src/starkware/starknet/security/starknet_common.cairo @@ -6,7 +6,7 @@ from starkware.cairo.common.find_element import find_element, search_sorted, sea from starkware.cairo.common.keccak import unsafe_keccak from starkware.cairo.common.math import ( abs_value, assert_250_bit, assert_in_range, assert_le, assert_le_felt, assert_lt, - assert_lt_felt, assert_nn, assert_nn_le, assert_not_equal, assert_not_zero, sign, + assert_lt_felt, assert_nn, assert_nn_le, assert_not_equal, assert_not_zero, horner_eval, sign, signed_div_rem, split_felt, split_int, sqrt, unsigned_div_rem) from starkware.cairo.common.math_cmp import ( is_in_range, is_le, is_le_felt, is_nn, is_nn_le, is_not_zero) @@ -19,6 +19,7 @@ from starkware.cairo.common.uint256 import ( uint256_neg, uint256_not, uint256_or, uint256_shl, uint256_shr, uint256_signed_div_rem, uint256_signed_le, uint256_signed_lt, uint256_signed_nn, uint256_signed_nn_le, uint256_sqrt, uint256_sub, uint256_unsigned_div_rem, uint256_xor) +from starkware.cairo.common.usort import usort from starkware.starknet.common.messages import send_message_to_l1 from starkware.starknet.common.storage import normalize_address from starkware.starknet.common.syscalls import ( diff --git a/src/starkware/starknet/security/whitelists/cairo_secp.json b/src/starkware/starknet/security/whitelists/cairo_secp.json index 6fa8dc08..ac779421 100644 --- a/src/starkware/starknet/security/whitelists/cairo_secp.json +++ b/src/starkware/starknet/security/whitelists/cairo_secp.json @@ -128,6 +128,12 @@ "memory[ap] = int(x == 0)" ] }, + { + "allowed_expressions": [], + "hint_lines": [ + "memory[ap] = to_felt_or_relocatable(x == 0)" + ] + }, { "allowed_expressions": [ { diff --git a/src/starkware/starknet/security/whitelists/ec_recover.json b/src/starkware/starknet/security/whitelists/ec_recover.json new file mode 100644 index 00000000..def3b579 --- /dev/null +++ b/src/starkware/starknet/security/whitelists/ec_recover.json @@ -0,0 +1,48 @@ +{ + "allowed_reference_expressions_for_hint": [ + { + "allowed_expressions": [], + "hint_lines": [ + "from starkware.cairo.common.cairo_secp.secp_utils import pack", + "from starkware.python.math_utils import div_mod, safe_div", + "", + "N = pack(ids.n, PRIME)", + "x = pack(ids.x, PRIME) % N", + "s = pack(ids.s, PRIME) % N", + "value = res = div_mod(x, s, N)" + ] + }, + { + "allowed_expressions": [], + "hint_lines": [ + "from starkware.cairo.common.cairo_secp.secp_utils import pack", + "from starkware.python.math_utils import div_mod, safe_div", + "", + "a = pack(ids.a, PRIME)", + "b = pack(ids.b, PRIME)", + "", + "value = res = a - b" + ] + }, + { + "allowed_expressions": [], + "hint_lines": [ + "from starkware.cairo.common.cairo_secp.secp_utils import pack", + "from starkware.python.math_utils import div_mod, safe_div", + "", + "a = pack(ids.a, PRIME)", + "b = pack(ids.b, PRIME)", + "product = a * b", + "m = pack(ids.m, PRIME)", + "", + "value = res = product % m" + ] + }, + { + "allowed_expressions": [], + "hint_lines": [ + "value = k = product // m" + ] + } + ] +} diff --git a/src/starkware/starknet/security/whitelists/latest.json b/src/starkware/starknet/security/whitelists/latest.json index f3c5ee3d..a90f2f60 100644 --- a/src/starkware/starknet/security/whitelists/latest.json +++ b/src/starkware/starknet/security/whitelists/latest.json @@ -148,6 +148,12 @@ "ids.next_key = key = keys.pop()" ] }, + { + "allowed_expressions": [], + "hint_lines": [ + "assert len(positions) == 0" + ] + }, { "allowed_expressions": [], "hint_lines": [ @@ -156,6 +162,14 @@ "memory[ids.range_check_ptr] = current_access_index" ] }, + { + "allowed_expressions": [], + "hint_lines": [ + "current_pos = positions.pop()", + "ids.next_item_index = current_pos - last_pos", + "last_pos = current_pos + 1" + ] + }, { "allowed_expressions": [], "hint_lines": [ @@ -203,6 +217,30 @@ "ecdsa_builtin.add_signature(ids.ecdsa_ptr.address_, (ids.signature_r, ids.signature_s))" ] }, + { + "allowed_expressions": [], + "hint_lines": [ + "from collections import defaultdict", + "", + "input_ptr = ids.input", + "input_len = int(ids.input_len)", + "if __usort_max_size is not None:", + " assert input_len <= __usort_max_size, (", + " f\"usort() can only be used with input_len<={__usort_max_size}. \"", + " f\"Got: input_len={input_len}.\"", + " )", + "", + "positions_dict = defaultdict(list)", + "for i in range(input_len):", + " val = memory[input_ptr + i]", + " positions_dict[val].append(i)", + "", + "output = sorted(positions_dict.keys())", + "ids.output_len = len(output)", + "ids.output = segments.gen_arg(output)", + "ids.multiplicities = segments.gen_arg([len(positions_dict[k]) for k in output])" + ] + }, { "allowed_expressions": [], "hint_lines": [ @@ -417,6 +455,13 @@ "del initial_dict" ] }, + { + "allowed_expressions": [], + "hint_lines": [ + "last_pos = 0", + "positions = positions_dict[ids.value][::-1]" + ] + }, { "allowed_expressions": [], "hint_lines": [ @@ -575,6 +620,12 @@ "vm_enter_scope()" ] }, + { + "allowed_expressions": [], + "hint_lines": [ + "vm_enter_scope(dict(__usort_max_size = globals().get('__usort_max_size')))" + ] + }, { "allowed_expressions": [], "hint_lines": [ diff --git a/src/starkware/starknet/services/api/contract_definition.py b/src/starkware/starknet/services/api/contract_definition.py index a4be70ad..ff709010 100644 --- a/src/starkware/starknet/services/api/contract_definition.py +++ b/src/starkware/starknet/services/api/contract_definition.py @@ -1,7 +1,7 @@ import dataclasses from dataclasses import field from enum import Enum, auto -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional import marshmallow_dataclass @@ -9,7 +9,7 @@ from starkware.cairo.lang.compiler.program import Program from starkware.starknet.definitions import fields from starkware.starknet.definitions.error_codes import StarknetErrorCode -from starkware.starknet.public.abi import get_selector_from_name +from starkware.starknet.public.abi import AbiType, get_selector_from_name from starkware.starkware_utils.error_handling import stark_assert from starkware.starkware_utils.subsequence import is_subsequence from starkware.starkware_utils.validated_dataclass import ( @@ -44,7 +44,7 @@ class ContractDefinition(ValidatedMarshmallowDataclass): program: Program entry_points_by_type: Dict[EntryPointType, List[ContractEntryPoint]] - abi: Optional[List[Any]] = None + abi: Optional[AbiType] = None def __post_init__(self): super().__post_init__() @@ -67,7 +67,7 @@ def __post_init__(self): ) stark_assert( - len(constructor_eps) <= 1, + len(constructor_eps) <= 1, # type: ignore code=StarknetErrorCode.INVALID_CONTRACT_DEFINITION, message="A contract may have at most 1 constructor.", ) diff --git a/src/starkware/starknet/services/api/feeder_gateway/CMakeLists.txt b/src/starkware/starknet/services/api/feeder_gateway/CMakeLists.txt index a88aa987..dfa9eeca 100644 --- a/src/starkware/starknet/services/api/feeder_gateway/CMakeLists.txt +++ b/src/starkware/starknet/services/api/feeder_gateway/CMakeLists.txt @@ -45,9 +45,11 @@ python_lib(starknet_feeder_gateway_response_objects_lib starknet_internal_transaction_lib starknet_transaction_execution_objects_lib starkware_dataclasses_utils_lib + starkware_python_utils_lib pip_marshmallow pip_marshmallow_dataclass pip_marshmallow_enum pip_marshmallow_oneofschema pip_typing_extensions + pip_web3 ) diff --git a/src/starkware/starknet/services/api/feeder_gateway/block_hash.py b/src/starkware/starknet/services/api/feeder_gateway/block_hash.py index 5db536b8..6d96a33c 100644 --- a/src/starkware/starknet/services/api/feeder_gateway/block_hash.py +++ b/src/starkware/starknet/services/api/feeder_gateway/block_hash.py @@ -9,7 +9,7 @@ from starkware.starkware_utils.commitment_tree.patricia_tree.patricia_tree import PatriciaTree from starkware.storage.dict_storage import DictStorage from starkware.storage.storage import FactFetchingContext -from starkware.storage.storage_utils import LeafFact +from starkware.storage.storage_utils import SimpleLeafFact async def calculate_block_hash( @@ -119,8 +119,10 @@ async def calculate_patricia_root( """ Calculates and returns the patricia root whose (leftmost) leaves are given. """ - empty_tree = await PatriciaTree.empty_tree(ffc=ffc, height=height, leaf_fact=LeafFact.empty()) - modifications = [(index, LeafFact(value=value)) for index, value in enumerate(leaves)] + empty_tree = await PatriciaTree.empty_tree( + ffc=ffc, height=height, leaf_fact=SimpleLeafFact.empty() + ) + modifications = [(index, SimpleLeafFact(value=value)) for index, value in enumerate(leaves)] final_tree = await empty_tree.update(ffc=ffc, modifications=modifications) return from_bytes(final_tree.root) diff --git a/src/starkware/starknet/services/api/feeder_gateway/feeder_gateway_client.py b/src/starkware/starknet/services/api/feeder_gateway/feeder_gateway_client.py index 1c46effe..64a6b6b2 100644 --- a/src/starkware/starknet/services/api/feeder_gateway/feeder_gateway_client.py +++ b/src/starkware/starknet/services/api/feeder_gateway/feeder_gateway_client.py @@ -9,6 +9,7 @@ StarknetBlock, TransactionInfo, TransactionReceipt, + TransactionTrace, ) from starkware.starknet.services.api.gateway.transaction import InvokeFunction from starkware.starkware_utils.validated_fields import RangeValidatedField @@ -43,6 +44,22 @@ async def call_contract( ) return json.loads(raw_response) + async def estimate_fee( + self, + invoke_tx: InvokeFunction, + block_hash: Optional[CastableToHash] = None, + block_number: Optional[BlockIdentifier] = None, + ) -> JsonObject: + formatted_block_identifier = get_formatted_block_identifier( + block_hash=block_hash, block_number=block_number + ) + raw_response = await self._send_request( + send_method="POST", + uri=(f"/estimate_fee?{formatted_block_identifier}"), + data=invoke_tx.dumps(), + ) + return json.loads(raw_response) + async def get_block( self, block_hash: Optional[CastableToHash] = None, @@ -123,32 +140,33 @@ async def get_storage_at( raw_response = await self._send_request(send_method="GET", uri=uri) return json.loads(raw_response) - async def get_transaction_status( - self, tx_hash: Optional[CastableToHash], tx_id: Optional[int] = None - ) -> JsonObject: + async def get_transaction_status(self, tx_hash: CastableToHash) -> JsonObject: raw_response = await self._send_request( send_method="GET", - uri=f"/get_transaction_status?{tx_identifier(tx_hash=tx_hash, tx_id=tx_id)}", + uri=f"/get_transaction_status?{tx_identifier(tx_hash=tx_hash)}", ) return json.loads(raw_response) - async def get_transaction( - self, tx_hash: Optional[CastableToHash], tx_id: Optional[int] = None - ) -> TransactionInfo: + async def get_transaction(self, tx_hash: CastableToHash) -> TransactionInfo: raw_response = await self._send_request( - send_method="GET", uri=f"/get_transaction?{tx_identifier(tx_hash=tx_hash, tx_id=tx_id)}" + send_method="GET", uri=f"/get_transaction?{tx_identifier(tx_hash=tx_hash)}" ) return TransactionInfo.loads(raw_response) - async def get_transaction_receipt( - self, tx_hash: Optional[CastableToHash], tx_id: Optional[int] = None - ) -> TransactionReceipt: + async def get_transaction_receipt(self, tx_hash: CastableToHash) -> TransactionReceipt: raw_response = await self._send_request( send_method="GET", - uri=f"/get_transaction_receipt?{tx_identifier(tx_hash=tx_hash, tx_id=tx_id)}", + uri=f"/get_transaction_receipt?{tx_identifier(tx_hash=tx_hash)}", ) return TransactionReceipt.loads(raw_response) + async def get_transaction_trace(self, tx_hash: CastableToHash) -> TransactionTrace: + raw_response = await self._send_request( + send_method="GET", + uri=f"/get_transaction_trace?{tx_identifier(tx_hash=tx_hash)}", + ) + return TransactionTrace.loads(raw_response) + async def get_block_hash_by_id(self, block_id: int) -> str: raw_response = await self._send_request( send_method="GET", @@ -192,12 +210,9 @@ def format_hash(hash_value: CastableToHash, hash_field: RangeValidatedField) -> return hash_value -def tx_identifier(tx_hash: Optional[CastableToHash], tx_id: Optional[int]) -> str: - if tx_hash is None: - return f"transactionId={json.dumps(tx_id)}" - else: - hash_str = format_hash(hash_value=tx_hash, hash_field=fields.TransactionHashField) - return f"transactionHash={hash_str}" +def tx_identifier(tx_hash: CastableToHash) -> str: + hash_str = format_hash(hash_value=tx_hash, hash_field=fields.TransactionHashField) + return f"transactionHash={hash_str}" def get_formatted_block_identifier( diff --git a/src/starkware/starknet/services/api/feeder_gateway/response_objects.py b/src/starkware/starknet/services/api/feeder_gateway/response_objects.py index f22684e0..8e96bbc7 100644 --- a/src/starkware/starknet/services/api/feeder_gateway/response_objects.py +++ b/src/starkware/starknet/services/api/feeder_gateway/response_objects.py @@ -10,22 +10,34 @@ import marshmallow_dataclass from marshmallow_oneofschema import OneOfSchema from typing_extensions import Literal +from web3 import Web3 from services.everest.api.feeder_gateway.response_objects import BaseResponseObject from services.everest.business_logic.transaction_execution_objects import TransactionFailureReason from services.everest.definitions import fields as everest_fields from starkware.cairo.lang.vm.cairo_pie import ExecutionResources +from starkware.python.utils import to_bytes from starkware.starknet.business_logic.internal_transaction import ( InternalDeploy, InternalInvokeFunction, InternalTransaction, ) -from starkware.starknet.business_logic.transaction_execution_objects import Event +from starkware.starknet.business_logic.transaction_execution_objects import ( + CallInfo, + Event, + OrderedEvent, + OrderedL2ToL1Message, +) from starkware.starknet.definitions import fields from starkware.starknet.definitions.transaction_type import TransactionType from starkware.starknet.services.api.contract_definition import EntryPointType from starkware.starkware_utils.marshmallow_dataclass_fields import VariadicLengthTupleField -from starkware.starkware_utils.validated_dataclass import ValidatedMarshmallowDataclass +from starkware.starkware_utils.serializable_dataclass import SerializableMarshmallowDataclass +from starkware.starkware_utils.validated_dataclass import ( + ValidatedDataclass, + ValidatedMarshmallowDataclass, +) +from starkware.starkware_utils.validated_fields import sequential_id_metadata BlockIdentifier = Union[int, Literal["pending"]] OptionalBlockIdentifier = Optional[BlockIdentifier] @@ -60,10 +72,9 @@ class TransactionStatus(Enum): ACCEPTED_ON_L1 = auto() @property - def has_receipt(self) -> bool: + def was_executed(self) -> bool: """ - Returns whether a transaction with that status has a receipt (i.e., has been executed - successfully). + Returns whether a transaction with that status has been executed successfully. """ return self in ( TransactionStatus.PENDING, @@ -215,7 +226,7 @@ def from_internal(cls, internal_tx: InternalTransaction) -> "TransactionSpecific class DeploySpecificInfo(TransactionSpecificInfo): contract_address: int = field(metadata=fields.contract_address_metadata) contract_address_salt: int = field(metadata=fields.contract_address_salt_metadata) - constructor_calldata: List[int] = field(metadata=fields.call_data_metadata) + constructor_calldata: List[int] = field(metadata=fields.call_data_as_hex_metadata) transaction_hash: int = field(metadata=fields.transaction_hash_metadata) tx_type: ClassVar[TransactionType] = TransactionType.DEPLOY @@ -234,9 +245,10 @@ class InvokeSpecificInfo(TransactionSpecificInfo): contract_address: int = field(metadata=fields.contract_address_metadata) entry_point_selector: int = field(metadata=fields.entry_point_selector_metadata) entry_point_type: EntryPointType - calldata: List[int] = field(metadata=fields.call_data_metadata) + calldata: List[int] = field(metadata=fields.call_data_as_hex_metadata) signature: List[int] = field(metadata=fields.signature_metadata) transaction_hash: int = field(metadata=fields.transaction_hash_metadata) + max_fee: int = field(metadata=fields.fee_metadata) tx_type: ClassVar[TransactionType] = TransactionType.INVOKE_FUNCTION @classmethod @@ -248,6 +260,7 @@ def from_internal_invoke(cls, internal_tx: InternalInvokeFunction) -> "InvokeSpe calldata=internal_tx.calldata, signature=internal_tx.signature, transaction_hash=internal_tx.hash_value, + max_fee=internal_tx.max_fee, ) @@ -311,7 +324,7 @@ class L1ToL2Message(BaseResponseObject): from_address: str = field(metadata=everest_fields.EthAddressField.metadata("from_address")) to_address: int = field(metadata=fields.contract_address_metadata) selector: int = field(metadata=fields.entry_point_selector_metadata) - payload: List[int] = field(metadata=fields.felt_list_metadata) + payload: List[int] = field(metadata=fields.felt_as_hex_list_metadata) nonce: Optional[int] = field(metadata=fields.optional_nonce_metadata) @@ -323,7 +336,7 @@ class L2ToL1Message(BaseResponseObject): from_address: int = field(metadata=fields.contract_address_metadata) to_address: str = field(metadata=everest_fields.EthAddressField.metadata("to_address")) - payload: List[int] = field(metadata=fields.felt_list_metadata) + payload: List[int] = field(metadata=fields.felt_as_hex_list_metadata) @marshmallow_dataclass.dataclass(frozen=True) @@ -452,6 +465,105 @@ class BlockStateUpdate(BaseResponseObject): state_diff: StateDiff +@dataclasses.dataclass(frozen=True) +class OrderedL2ToL1MessageResponse(ValidatedDataclass): + """ + See datails in OrderedL2ToL1Message's documentation. + """ + + order: int = field(metadata=sequential_id_metadata("L2-to-L1 message order")) + to_address: str = field(metadata=everest_fields.EthAddressField.metadata("to_address")) + payload: List[int] = field(metadata=fields.felt_as_hex_list_metadata) + + @classmethod + def from_internal( + cls, messages: List[OrderedL2ToL1Message] + ) -> List["OrderedL2ToL1MessageResponse"]: + return [ + cls( + order=message.order, + to_address=Web3.toChecksumAddress(to_bytes(message.to_address, 20)), + payload=message.payload, + ) + for message in messages + ] + + +@dataclasses.dataclass(frozen=True) +class OrderedEventResponse(ValidatedDataclass): + """ + See datails in OrderedEvent's documentation. + """ + + order: int = field(metadata=sequential_id_metadata("Event order")) + keys: List[int] = field(metadata=fields.felt_as_hex_list_metadata) + data: List[int] = field(metadata=fields.felt_as_hex_list_metadata) + + @classmethod + def from_internal(cls, events: List[OrderedEvent]) -> List["OrderedEventResponse"]: + return [cls(order=event.order, keys=event.keys, data=event.data) for event in events] + + +# NOTE: This dataclass isn't validated due to a forward-declaration issue. +@marshmallow_dataclass.dataclass(frozen=True) +class FunctionInvocation(SerializableMarshmallowDataclass): + """ + A lean version of CallInfo class, containing merely the information relevant for the user. + """ + + # Static info. + caller_address: int = field(metadata=fields.contract_address_metadata) + contract_address: int = field(metadata=fields.contract_address_metadata) + code_address: Optional[int] = field(metadata=fields.optional_code_address_metadata) + selector: Optional[int] = field(metadata=fields.optional_entry_point_selector_metadata) + entry_point_type: Optional[EntryPointType] + calldata: List[int] = field(metadata=fields.call_data_as_hex_metadata) + + # Execution info. + result: List[int] = field(metadata=fields.retdata_as_hex_metadata) + execution_resources: ExecutionResources + internal_calls: List["FunctionInvocation"] = field( + metadata=dict( + marshmallow_field=mfields.List(mfields.Nested(lambda: FunctionInvocation.Schema())) + ) + ) + events: List[OrderedEventResponse] + messages: List[OrderedL2ToL1MessageResponse] + + @classmethod + def from_internal_version(cls, call_info: CallInfo) -> "FunctionInvocation": + return cls( + caller_address=call_info.caller_address, + contract_address=call_info.contract_address, + code_address=call_info.code_address, + selector=call_info.entry_point_selector, + entry_point_type=call_info.entry_point_type, + calldata=call_info.calldata, + result=call_info.retdata, + execution_resources=call_info.execution_resources, + internal_calls=[ + cls.from_internal_version(call_info=internal_call) + for internal_call in call_info.internal_calls + ], + events=OrderedEventResponse.from_internal(events=call_info.events), + messages=OrderedL2ToL1MessageResponse.from_internal( + messages=call_info.l2_to_l1_messages + ), + ) + + +@marshmallow_dataclass.dataclass(frozen=True) +class TransactionTrace(BaseResponseObject): + """ + Represents the trace of a StarkNet transaction execution, + including internal calls. + """ + + # An object describing the invocation of a specific function. + function_invocation: FunctionInvocation + signature: List[int] = field(metadata=fields.signature_metadata) + + @marshmallow_dataclass.dataclass(frozen=True) class StarknetBlock(BaseResponseObject): """ @@ -471,9 +583,11 @@ class StarknetBlock(BaseResponseObject): ) ) timestamp: int = field(metadata=fields.timestamp_metadata) - transaction_receipts: Tuple[TransactionExecution, ...] = field( + transaction_receipts: Optional[Tuple[TransactionExecution, ...]] = field( metadata=dict( - marshmallow_field=VariadicLengthTupleField(mfields.Nested(TransactionExecution.Schema)) + marshmallow_field=VariadicLengthTupleField( + mfields.Nested(TransactionExecution.Schema), allow_none=True + ) ) ) @@ -486,7 +600,7 @@ def create( state_root: Optional[bytes], transactions: Iterable[InternalTransaction], timestamp: int, - transaction_receipts: Tuple[TransactionExecution, ...], + transaction_receipts: Optional[Tuple[TransactionExecution, ...]], status: Optional[BlockStatus], ) -> "StarknetBlock": return cls( @@ -505,13 +619,10 @@ def create( def __post_init__(self): super().__post_init__() - tx_status_error_message = ( - "Transactions' status in block must match the status of the block." - ) - if self.status is None: - assert all( - tx_receipt.status is None for tx_receipt in self.transaction_receipts - ), tx_status_error_message + if self.status in (BlockStatus.ABORTED, BlockStatus.REVERTED): + assert ( + self.transaction_receipts is None + ), "Aborted and reverted blocks must not have transaction receipts." return diff --git a/src/starkware/starknet/services/api/gateway/transaction.py b/src/starkware/starknet/services/api/gateway/transaction.py index af409de3..80de22c5 100644 --- a/src/starkware/starknet/services/api/gateway/transaction.py +++ b/src/starkware/starknet/services/api/gateway/transaction.py @@ -118,9 +118,14 @@ class InvokeFunction(Transaction): """ contract_address: int = field(metadata=fields.contract_address_metadata) - # A field element that encodes the signature of the called function. + # A field element that encodes the signature of the invoked function. entry_point_selector: int = field(metadata=fields.entry_point_selector_metadata) calldata: List[int] = field(metadata=fields.call_data_metadata) + # The maximal fee to be paid in Wei for executing invoked function. + max_fee: int = field(metadata=fields.fee_metadata) + # The transaction is not valid if its version is lower than current version, + # defined by the SN OS. + version: int = field(metadata=fields.tx_version_metadata) # Additional information given by the caller that represents the signature of the transaction. # The exact way this field is handled is defined by the called contract's function, like # calldata. @@ -135,9 +140,11 @@ def calculate_hash(self, general_config: StarknetGeneralConfig) -> int: """ return calculate_transaction_hash_common( tx_hash_prefix=TransactionHashPrefix.INVOKE, + version=self.version, contract_address=self.contract_address, entry_point_selector=self.entry_point_selector, calldata=self.calldata, + max_fee=self.max_fee, chain_id=general_config.chain_id.value, additional_data=[], ) diff --git a/src/starkware/starknet/services/api/messages.py b/src/starkware/starknet/services/api/messages.py index 1215a631..a955c864 100644 --- a/src/starkware/starknet/services/api/messages.py +++ b/src/starkware/starknet/services/api/messages.py @@ -4,7 +4,7 @@ from typing import List from services.everest.definitions import fields as everest_fields -from starkware.cairo.bootloader.compute_fact import keccak_ints +from starkware.cairo.bootloaders.compute_fact import keccak_ints from starkware.starknet.business_logic.internal_transaction import InternalInvokeFunction from starkware.starknet.definitions import fields from starkware.starknet.services.api.contract_definition import EntryPointType @@ -56,7 +56,7 @@ class StarknetMessageToL2(StarknetMessage): to_address: int = field(metadata=fields.ContractAddressField.metadata(field_name="to_address")) l1_handler_selector: int payload: List[int] = field(metadata=fields.felt_list_metadata) - nonce: int = field(metadata=everest_fields.felt_metadata(name_in_error_message="nonce")) + nonce: int = field(metadata=fields.nonce_metadata) def encode(self) -> List[int]: return [ diff --git a/src/starkware/starknet/storage/CMakeLists.txt b/src/starkware/starknet/storage/CMakeLists.txt index 94233cc5..4dd05efa 100644 --- a/src/starkware/starknet/storage/CMakeLists.txt +++ b/src/starkware/starknet/storage/CMakeLists.txt @@ -6,6 +6,7 @@ python_lib(starknet_storage_lib LIBS cairo_constants_lib + starkware_commitment_tree_facts_lib starkware_dataclasses_utils_lib starkware_python_utils_lib starkware_storage_lib diff --git a/src/starkware/starknet/storage/starknet_storage.py b/src/starkware/starknet/storage/starknet_storage.py index 6f72efb9..5f2d2eac 100644 --- a/src/starkware/starknet/storage/starknet_storage.py +++ b/src/starkware/starknet/storage/starknet_storage.py @@ -8,16 +8,18 @@ from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME from starkware.python.utils import from_bytes, to_bytes from starkware.starkware_utils.commitment_tree.binary_fact_tree import BinaryFactDict +from starkware.starkware_utils.commitment_tree.leaf_fact import LeafFact +from starkware.starkware_utils.commitment_tree.patricia_tree.nodes import EmptyNodeFact from starkware.starkware_utils.commitment_tree.patricia_tree.patricia_tree import PatriciaTree from starkware.starkware_utils.validated_dataclass import ValidatedDataclass -from starkware.storage.storage import HASH_BYTES, Fact, FactFetchingContext, HashFunctionType +from starkware.storage.storage import HASH_BYTES, FactFetchingContext, HashFunctionType TStorageLeaf = TypeVar("TStorageLeaf", bound="StorageLeaf") @dataclasses.dataclass(frozen=True) -class StorageLeaf(Fact, ValidatedDataclass): +class StorageLeaf(LeafFact, ValidatedDataclass): """ A class representing a commitment tree leaf in a Cairo contract storage. The content of the leaf is a single integer. @@ -37,7 +39,13 @@ def serialize(self) -> bytes: return to_bytes(self.value) def _hash(self, hash_func: HashFunctionType) -> bytes: - # Note that the return value size needs to be HASH_BYTES. + """ + Calculates and returns the leaf hash. + Note that the return value size needs to be HASH_BYTES. + """ + if self.is_empty: + return EmptyNodeFact.EMPTY_NODE_HASH + return self.serialize() @classmethod @@ -48,6 +56,10 @@ def deserialize(cls: Type[TStorageLeaf], data: bytes) -> TStorageLeaf: def empty(cls) -> "StorageLeaf": return cls(value=0) + @property + def is_empty(self) -> bool: + return self.value == 0 + class StarknetStorageInterface(ABC): """ diff --git a/src/starkware/starknet/testing/CMakeLists.txt b/src/starkware/starknet/testing/CMakeLists.txt index 080a9ea9..47198564 100644 --- a/src/starkware/starknet/testing/CMakeLists.txt +++ b/src/starkware/starknet/testing/CMakeLists.txt @@ -49,6 +49,7 @@ python_lib(starknet_testing_lib starknet_compile_lib starknet_contract_definition_lib starknet_definitions_lib + starknet_feeder_gateway_response_objects_lib starknet_general_config_lib starknet_internal_transaction_interface_lib starknet_internal_transaction_lib @@ -79,6 +80,7 @@ full_python_test(starknet_testing_test LIBS cairo_common_lib starknet_mock_messaging_contracts_lib + starknet_test_external_contract_test_utils_lib starknet_testing_lib starkware_eth_test_utils_lib pip_pytest diff --git a/src/starkware/starknet/testing/MockStarknetMessaging.sol b/src/starkware/starknet/testing/MockStarknetMessaging.sol index b13d353d..a1435e17 100644 --- a/src/starkware/starknet/testing/MockStarknetMessaging.sol +++ b/src/starkware/starknet/testing/MockStarknetMessaging.sol @@ -4,16 +4,20 @@ pragma solidity ^0.6.12; import "contracts/starkware/starknet/eth/StarknetMessaging.sol"; contract MockStarknetMessaging is StarknetMessaging { + constructor(uint256 MessageCancellationDelay) public { + messageCancellationDelay(MessageCancellationDelay); + } + /** Mocks a message from L2 to L1. */ function mockSendMessageFromL2( - uint256 from_address, - uint256 to_address, + uint256 fromAddress, + uint256 toAddress, uint256[] calldata payload ) external { bytes32 msgHash = keccak256( - abi.encodePacked(from_address, to_address, payload.length, payload) + abi.encodePacked(fromAddress, toAddress, payload.length, payload) ); l2ToL1Messages()[msgHash] += 1; } @@ -22,14 +26,14 @@ contract MockStarknetMessaging is StarknetMessaging { Mocks consumption of a message from L1 to L2. */ function mockConsumeMessageToL2( - uint256 from_address, - uint256 to_address, + uint256 fromAddress, + uint256 toAddress, uint256 selector, uint256[] calldata payload, uint256 nonce ) external { bytes32 msgHash = keccak256( - abi.encodePacked(from_address, to_address, nonce, selector, payload.length, payload) + abi.encodePacked(fromAddress, toAddress, nonce, selector, payload.length, payload) ); require(l1ToL2Messages()[msgHash] > 0, "INVALID_MESSAGE_TO_CONSUME"); diff --git a/src/starkware/starknet/testing/contract.py b/src/starkware/starknet/testing/contract.py index b2251e11..6c528379 100644 --- a/src/starkware/starknet/testing/contract.py +++ b/src/starkware/starknet/testing/contract.py @@ -15,15 +15,17 @@ ) from starkware.cairo.lang.compiler.parser import parse_type from starkware.cairo.lang.compiler.type_system import mark_type_resolved -from starkware.python.utils import safe_zip -from starkware.starknet.business_logic.transaction_execution_objects import OrderedEventContent +from starkware.python.utils import assert_exhausted, safe_zip +from starkware.starknet.business_logic.transaction_execution_objects import OrderedEvent +from starkware.starknet.public.abi import AbiType from starkware.starknet.testing.contract_utils import ( + RAW_OUTPUT_ARG_LIST, EventManager, StructManager, flatten, parse_arguments, ) -from starkware.starknet.testing.objects import StarknetTransactionExecutionInfo +from starkware.starknet.testing.objects import Dataclass, StarknetTransactionExecutionInfo from starkware.starknet.testing.state import CastableToAddress, StarknetState from starkware.starknet.utils.api_utils import cast_to_felts @@ -48,11 +50,12 @@ class StarknetContract: def __init__( self, state: StarknetState, - abi: List[Any], + abi: AbiType, contract_address: CastableToAddress, deploy_execution_info: StarknetTransactionExecutionInfo, ): self.state = state + self.abi = abi self.deploy_execution_info = deploy_execution_info self.struct_manager = StructManager(abi=abi) @@ -71,7 +74,7 @@ def __init__( self.contract_address = contract_address def __dir__(self): - return object.__dir__(self) + list(self._abi_function_mapping.keys()) + return list(object.__dir__(self)) + list(self._abi_function_mapping.keys()) def __getattr__(self, name: str): if name in self._abi_function_mapping: @@ -170,8 +173,8 @@ def _get_annotation(self, arg_type: CairoType, is_nested: bool = False) -> Pytho if isinstance(arg_type, TypeTuple): return Tuple[ tuple( - self._get_annotation(arg_type=member, is_nested=True) - for member in arg_type.members + self._get_annotation(arg_type=cairo_type, is_nested=True) + for cairo_type in arg_type.types ) ] if isinstance(arg_type, TypeStruct): @@ -227,6 +230,23 @@ def _build_function_call( calldata=cast_to_felts(values=calldata), retdata_arg_types=retdata_arg_types, retdata_tuple=namedtuple(f"{function_name}_return_type", retdata_arg_names), + has_raw_output=(retdata_arg_names == RAW_OUTPUT_ARG_LIST), + ) + + def replace_abi( + self, + impl_contract_abi: AbiType, + ) -> "StarknetContract": + """ + Replaces the contract's ABI. + Typically used to replace the ABI of a proxy contract with the ABI of the + implementation contract. + """ + return StarknetContract( + state=self.state, + abi=impl_contract_abi, + contract_address=self.contract_address, + deploy_execution_info=self.deploy_execution_info, ) @@ -248,6 +268,7 @@ class StarknetContractFunctionInvocation: calldata: List[int] retdata_arg_types: List[CairoType] retdata_tuple: type + has_raw_output: bool async def call( self, caller_address: int = 0, signature: List[int] = None @@ -260,17 +281,21 @@ async def call( ) async def invoke( - self, caller_address: int = 0, signature: List[int] = None + self, caller_address: int = 0, max_fee: int = 0, signature: List[int] = None ) -> StarknetTransactionExecutionInfo: """ Executes the function call and apply changes on the state. """ return await self._invoke_on_given_state( - state=self.state, caller_address=caller_address, signature=signature + state=self.state, caller_address=caller_address, max_fee=max_fee, signature=signature ) async def _invoke_on_given_state( - self, state: StarknetState, caller_address: int = 0, signature: List[int] = None + self, + state: StarknetState, + caller_address: int = 0, + max_fee: int = 0, + signature: List[int] = None, ) -> StarknetTransactionExecutionInfo: """ Executes the function call and apply changes on the given state. @@ -280,26 +305,34 @@ async def _invoke_on_given_state( selector=self.name, calldata=self.calldata, caller_address=caller_address, + max_fee=max_fee, signature=None if signature is None else cast_to_felts(values=signature), ) + # Check if function has @raw_output. + if self.has_raw_output: + # Return the result as a raw tuple. + result = tuple(execution_info.call_info.retdata) + else: + args = self._build_arguments( + arg_values=execution_info.call_info.retdata, + arg_types=self.retdata_arg_types, + ) + result = self.retdata_tuple(*args) + main_call_raw_events = execution_info.call_info.events return StarknetTransactionExecutionInfo.from_internal( tx_execution_info=execution_info, - result=self._build_arguments( - arg_values=execution_info.retdata, - arg_types=self.retdata_arg_types, - args_tuple=self.retdata_tuple, - ), + result=result, main_call_events=self._build_events(raw_events=main_call_raw_events), ) - def _build_events(self, raw_events: List[OrderedEventContent]) -> List[tuple]: + def _build_events(self, raw_events: List[OrderedEvent]) -> List[Dataclass]: """ - Given a list of low-level events, builds contract events (i.e., a named tuple) from those - corresponding to high-level ones. + Given a list of low-level events, builds contract events (i.e., a dynamic dataclass) from + those corresponding to high-level ones. """ - events: List[tuple] = [] + events: List[Dataclass] = [] for raw_event in raw_events: if len(raw_event.keys) == 0 or raw_event.keys[0] not in self.event_manager: # It is a low-level event emitted using directly the emit_event syscall. @@ -312,21 +345,18 @@ def _build_events(self, raw_events: List[OrderedEventContent]) -> List[tuple]: # low-level event to contain a valid selector in its keys without being a valid high # level event - i.e., without the exact amount of data). try: - events.append( - self._build_arguments( - arg_values=arg_values, - arg_types=self.event_manager.get_event_argument_types(identifier=selector), - args_tuple=self.event_manager.get_contract_event(identifier=selector), - ) + args = self._build_arguments( + arg_values=arg_values, + arg_types=self.event_manager.get_event_argument_types(identifier=selector), ) + args_dataclass = self.event_manager.get_contract_event(identifier=selector) + events.append(args_dataclass(*args)) except ArgumentParsingFailed: pass return events - def _build_arguments( - self, arg_values: List[int], arg_types: List[CairoType], args_tuple: type - ) -> tuple: + def _build_arguments(self, arg_values: List[int], arg_types: List[CairoType]) -> List[Any]: """ Reconstructs a Pythonic variant of the original Cairo structure of the arguments, deduced by their Cairo types, and fills it with the given (flat list of) values. @@ -342,8 +372,8 @@ def build_arg( return next(arg_value_iterator) if isinstance(arg_type, TypeTuple): return tuple( - build_arg(arg_type=member, arg_value_iterator=arg_value_iterator) - for member in arg_type.members + build_arg(arg_type=cairo_type, arg_value_iterator=arg_value_iterator) + for cairo_type in arg_type.types ) if isinstance(arg_type, TypeStruct): struct_name = arg_type.scope.path[-1] @@ -375,7 +405,9 @@ def build_arg( raise ArgumentParsingFailed("Too few argument values.") # Make sure the iterator is empty. - if next(arg_value_iterator, None) is not None: + try: + assert_exhausted(iterator=arg_value_iterator) + except AssertionError: raise ArgumentParsingFailed("Too many argument values.") - return args_tuple(*res) + return res diff --git a/src/starkware/starknet/testing/contract_test.py b/src/starkware/starknet/testing/contract_test.py index 614fed80..58ee2e98 100644 --- a/src/starkware/starknet/testing/contract_test.py +++ b/src/starkware/starknet/testing/contract_test.py @@ -5,72 +5,87 @@ import pytest from starkware.starknet.business_logic.transaction_execution_objects import Event -from starkware.starknet.compiler.compile import compile_starknet_files +from starkware.starknet.core.test_contract.test_utils import get_contract_definition from starkware.starknet.public.abi import get_selector_from_name from starkware.starknet.testing.contract import StarknetContract -from starkware.starknet.testing.objects import StarknetTransactionExecutionInfo -from starkware.starknet.testing.state import StarknetState +from starkware.starknet.testing.starknet import Starknet from starkware.starknet.utils.api_utils import cast_to_felts CONTRACT_FILE = os.path.join(os.path.dirname(__file__), "test.cairo") +# Fixtures. + + @pytest.fixture -async def contract() -> StarknetContract: - contract_definition = compile_starknet_files([CONTRACT_FILE], debug_info=True) - state = await StarknetState.empty() - contract_address, execution_info = await state.deploy( +async def starknet() -> Starknet: + return await Starknet.empty() + + +@pytest.fixture +async def test_contract(starknet: Starknet) -> StarknetContract: + return await starknet.deploy( + source=CONTRACT_FILE, constructor_calldata=[], - contract_definition=contract_definition, ) - deploy_execution_info = StarknetTransactionExecutionInfo.from_internal( - tx_execution_info=execution_info, result=(), main_call_events=[] + + +@pytest.fixture +async def proxy_contract(starknet: Starknet) -> StarknetContract: + contract_definition = get_contract_definition("delegate_proxy") + return await starknet.deploy( + constructor_calldata=[], + contract_def=contract_definition, ) - assert contract_definition.abi is not None - return StarknetContract( - state=state, - abi=contract_definition.abi, - contract_address=contract_address, - deploy_execution_info=deploy_execution_info, + +@pytest.fixture +async def account_contract(starknet: Starknet) -> StarknetContract: + contract_definition = get_contract_definition("dummy_account") + return await starknet.deploy( + constructor_calldata=[], + contract_def=contract_definition, ) +# Tests. + + @pytest.mark.asyncio -async def test_function_call(contract: StarknetContract): - await contract.increase_value(address=132, value=3).invoke() - await contract.increase_value(132, 5).invoke() - await contract.increase_value(132, 10).call() +async def test_function_call(test_contract: StarknetContract): + await test_contract.increase_value(address=132, value=3).invoke() + await test_contract.increase_value(132, 5).invoke() + await test_contract.increase_value(132, 10).call() # Since the return type is a named tuple, the result can be checked in multiple ways. - execution_info = await contract.get_value(address=132).invoke() + execution_info = await test_contract.get_value(address=132).invoke() assert execution_info.result == (8,) - execution_info = await contract.get_value(address=132).call() + execution_info = await test_contract.get_value(address=132).call() assert execution_info.result.res == 8 # Access by the name of the return value, `res`. - execution_info = await contract.takes_array(a=[1, 2, 4]).invoke() + execution_info = await test_contract.takes_array(a=[1, 2, 4]).invoke() assert execution_info.result[0] == 6 # Access by location. # Pass signature values using invoke's signature argument. - execution_info = await contract.get_signature().invoke(signature=[1, 2, 4, 10]) + execution_info = await test_contract.get_signature().invoke(signature=[1, 2, 4, 10]) assert execution_info.result == ([1, 2, 4, 10],) # Check structs. - point_1 = contract.Point(x=1, y=2) - point_2 = contract.Point(x=3, y=4) - execution_info = await contract.sum_points(points=(point_1, point_2)).invoke() + point_1 = test_contract.Point(x=1, y=2) + point_2 = test_contract.Point(x=3, y=4) + execution_info = await test_contract.sum_points(points=(point_1, point_2)).invoke() assert execution_info.result == ((4, 6),) - execution_info = await contract.sum_points(((-1, 2), (-3, 4))).invoke() + execution_info = await test_contract.sum_points(((-1, 2), (-3, 4))).invoke() assert execution_info.result.res == tuple(cast_to_felts(values=[-4, 6])) # Check multiple return values. - execution_info = await contract.sum_and_mult_points(points=(point_1, point_2)).invoke() - assert execution_info.result == (contract.Point(x=4, y=6), 11) + execution_info = await test_contract.sum_and_mult_points(points=(point_1, point_2)).invoke() + assert execution_info.result == (test_contract.Point(x=4, y=6), 11) # Check struct type consistency. - assert isinstance(execution_info.result.sum_res, contract.Point) + assert isinstance(execution_info.result.sum_res, test_contract.Point) # Check type annotatins. - func_annotations = contract.sum_and_mult_points.__annotations__ + func_annotations = test_contract.sum_and_mult_points.__annotations__ expected_annotations = { "points": Tuple[Tuple[int, int], Tuple[int, int]], "return": (Tuple[int, int], int), @@ -81,43 +96,95 @@ async def test_function_call(contract: StarknetContract): with pytest.raises( TypeError, match=re.escape("argument points[1] has wrong number of elements") ): - contract.sum_points(points=((1, 2), (3, 4, 5))) + test_contract.sum_points(points=((1, 2), (3, 4, 5))) with pytest.raises(TypeError, match=re.escape("type of argument points[0][1] must be int")): - contract.sum_points(points=((1, 2.5), (3, 4))) + test_contract.sum_points(points=((1, 2.5), (3, 4))) - point = contract.Point(x="1", y=2) + point = test_contract.Point(x="1", y=2) with pytest.raises(TypeError, match=re.escape("type of argument points[0][0] must be int")): - contract.sum_points(points=(point, (1, 2))) + test_contract.sum_points(points=(point, (1, 2))) with pytest.raises(TypeError, match=re.escape("sum_points() takes 1 positional argument")): - contract.sum_points(1, 2, 3, 4) + test_contract.sum_points(1, 2, 3, 4) @pytest.mark.asyncio -async def test_event(contract: StarknetContract): - p1 = contract.Point(x=1, y=2) - p2 = contract.Point(x=3, y=4) - point_sum = contract.Point(x=p1.x + p2.x, y=p1.y + p2.y) +async def test_proxy_call(test_contract: StarknetContract, proxy_contract: StarknetContract): + wrapped_contract = await wrap_with_proxy( + proxy_contract=proxy_contract, + impl_contract=test_contract, + ) + + await wrapped_contract.increase_value(address=132, value=7).invoke() + + execution_info = await wrapped_contract.get_value(address=132).invoke() + assert execution_info.result == (7,) + + +@pytest.mark.asyncio +async def test_raw_decorators( + test_contract: StarknetContract, + account_contract: StarknetContract, + proxy_contract: StarknetContract, +): + selector = get_selector_from_name("increase_value") + await account_contract.__execute__( + contract_address=test_contract.contract_address, selector=selector, calldata=[132, 41] + ).invoke() + + selector = get_selector_from_name("get_value") + execution_info = await account_contract.__execute__( + contract_address=test_contract.contract_address, selector=selector, calldata=[132] + ).invoke() + assert execution_info.result == (41,) + + with pytest.raises(AssertionError, match="Direct raw_input function calls are not supported."): + proxy_contract.__default__(selector=selector, calldata=[]) + - log_sum_points_tuple = contract.event_manager.get_contract_event(identifier="log_sum_points") +@pytest.mark.asyncio +async def test_event(test_contract: StarknetContract): + p1 = test_contract.Point(x=1, y=2) + p2 = test_contract.Point(x=3, y=4) + point_sum = test_contract.Point(x=p1.x + p2.x, y=p1.y + p2.y) + + log_sum_points_tuple = test_contract.event_manager.get_contract_event( + identifier="log_sum_points" + ) expected_event = log_sum_points_tuple(points=[p1, p2], sum=point_sum) - execution_info = await contract.sum_points(points=(p1, p2)).invoke() + execution_info = await test_contract.sum_points(points=(p1, p2)).invoke() (actual_event,) = execution_info.main_call_events # Check high-level form. assert isinstance(actual_event, log_sum_points_tuple) assert actual_event == expected_event - assert actual_event == ([(p1.x, p1.y), (p2.x, p2.y)], (point_sum.x, point_sum.y)) + assert (actual_event.points, actual_event.sum) == ([p1, p2], point_sum) # Check low-level flat form (which includes the array length). (actual_raw_event,) = execution_info.raw_events assert actual_raw_event == Event( - from_address=contract.contract_address, + from_address=test_contract.contract_address, keys=[get_selector_from_name("log_sum_points")], data=[2, p1.x, p1.y, p2.x, p2.y, point_sum.x, point_sum.y], ) # Check that the state's event list was updated. - assert contract.state.events == [actual_raw_event] + assert test_contract.state.events == [actual_raw_event] + + +# Utilities. + + +async def wrap_with_proxy( + proxy_contract: StarknetContract, + impl_contract: StarknetContract, +) -> StarknetContract: + """ + Wraps an implementation contract's ABI with a proxy contract. + """ + await proxy_contract.set_implementation_address( + impl_address_=impl_contract.contract_address + ).invoke() + return proxy_contract.replace_abi(impl_contract_abi=impl_contract.abi) diff --git a/src/starkware/starknet/testing/contract_utils.py b/src/starkware/starknet/testing/contract_utils.py index f0969789..1f4eac73 100644 --- a/src/starkware/starknet/testing/contract_utils.py +++ b/src/starkware/starknet/testing/contract_utils.py @@ -1,18 +1,22 @@ from collections import namedtuple +from dataclasses import make_dataclass from typing import Any, Dict, Iterable, List, Tuple, Union from starkware.cairo.lang.compiler.ast.cairo_types import CairoType, TypeFelt, TypePointer from starkware.cairo.lang.compiler.identifier_definition import StructDefinition from starkware.cairo.lang.compiler.parser import parse_type from starkware.cairo.lang.compiler.type_system import mark_type_resolved -from starkware.starknet.public.abi import get_selector_from_name +from starkware.starknet.public.abi import AbiType, get_selector_from_name from starkware.starknet.public.abi_structs import struct_definition_from_abi_entry +from starkware.starknet.services.api.contract_definition import ContractDefinition +from starkware.starknet.testing.objects import Dataclass EventIdentifier = Union[str, int] +RAW_OUTPUT_ARG_LIST = ["retdata_size", "retdata"] class StructManager: - def __init__(self, abi: List[Any]): + def __init__(self, abi: AbiType): self._struct_definition_mapping = { abi_entry["name"]: struct_definition_from_abi_entry(abi_entry=abi_entry) for abi_entry in abi @@ -47,7 +51,7 @@ def _build_contract_struct(self, name: str) -> type: class EventManager: - def __init__(self, abi: List[Any]): + def __init__(self, abi: AbiType): self._abi_event_mapping = { abi_entry["name"]: abi_entry for abi_entry in abi if abi_entry["type"] == "event" } @@ -58,7 +62,7 @@ def __init__(self, abi: List[Any]): } # Cached contract events and argument types. - self._contract_events: Dict[str, type] = {} + self._contract_events: Dict[str, Dataclass] = {} self._event_name_to_argument_types: Dict[str, List[CairoType]] = {} def __contains__(self, identifier: EventIdentifier) -> bool: @@ -67,7 +71,7 @@ def __contains__(self, identifier: EventIdentifier) -> bool: return identifier in self._selector_to_name - def get_contract_event(self, identifier: EventIdentifier) -> type: + def get_contract_event(self, identifier: EventIdentifier) -> Dataclass: """ Returns a named tuple representing the event whose name is given. """ @@ -97,13 +101,13 @@ def _process_event(self, name: str): names, types = parse_arguments(arguments_abi=event_abi["keys"] + event_abi["data"]) self._event_name_to_argument_types[name] = types - self._contract_events[name] = namedtuple(typename=name, field_names=names) + self._contract_events[name] = make_dataclass(cls_name=name, fields=names) def _get_event_name(self, identifier: EventIdentifier) -> str: return identifier if isinstance(identifier, str) else self._selector_to_name[identifier] -def parse_arguments(arguments_abi: dict) -> Tuple[List[str], List[CairoType]]: +def parse_arguments(arguments_abi: List) -> Tuple[List[str], List[CairoType]]: """ Given the input or output field of a StarkNet contract function ABI, computes the arguments that the python proxy function should accept. @@ -118,16 +122,26 @@ def parse_arguments(arguments_abi: dict) -> Tuple[List[str], List[CairoType]]: name = arg_entry["name"] arg_type = mark_type_resolved(parse_type(code=arg_entry["type"])) if isinstance(arg_type, TypePointer): + # Remove last argument. size_arg_actual_name = arg_names.pop() actual_type = arg_types.pop() - # Make sure the last argument was {name}_len, and remove it. - size_arg_name = f"{name}_len" - assert ( - size_arg_actual_name == size_arg_name - ), f"Array size argument {size_arg_name} must appear right before {name}." + # Allow _size suffix (instead of _len) for @raw_output functions. + if size_arg_actual_name == f"{name}_size": + assert name != "calldata", "Direct raw_input function calls are not supported." + assert [size_arg_actual_name, name] == RAW_OUTPUT_ARG_LIST + assert len(arguments_abi) == 2 + # In case of @raw_output keep retdata_size argument. + arg_names.append(size_arg_actual_name) + arg_types.append(actual_type) + else: + # Make sure the removed last argument was {name}_len. + size_arg_expected_name = f"{name}_len" + assert ( + size_arg_actual_name == size_arg_expected_name + ), f"Array size argument {size_arg_expected_name} must appear right before {name}." assert isinstance(actual_type, TypeFelt), ( - f"Array size entry {size_arg_name} expected to be type felt. Got: " + f"Array size entry {size_arg_actual_name} expected to be type felt. Got: " f"{actual_type.format()}." ) @@ -148,3 +162,8 @@ def flatten(name: str, value: Union[Any, Iterable], max_depth: int = 30) -> List res.extend(flatten(name=name, value=elm, max_depth=max_depth - 1)) return res + + +def get_abi(contract_definition: ContractDefinition) -> AbiType: + assert contract_definition.abi is not None, "Missing ABI." + return contract_definition.abi diff --git a/src/starkware/starknet/testing/mock_starknet_messaging_test.py b/src/starkware/starknet/testing/mock_starknet_messaging_test.py index 92676376..1abafdbc 100644 --- a/src/starkware/starknet/testing/mock_starknet_messaging_test.py +++ b/src/starkware/starknet/testing/mock_starknet_messaging_test.py @@ -6,7 +6,7 @@ @pytest.fixture def mock_starknet_contract(eth_test_utils): - return eth_test_utils.accounts[0].deploy(MockStarknetMessaging) + return eth_test_utils.accounts[0].deploy(MockStarknetMessaging, 0) def test_mock_send_message_from_l2(eth_test_utils, mock_starknet_contract): diff --git a/src/starkware/starknet/testing/objects.py b/src/starkware/starknet/testing/objects.py index 144edd54..1a5ffb20 100644 --- a/src/starkware/starknet/testing/objects.py +++ b/src/starkware/starknet/testing/objects.py @@ -1,36 +1,15 @@ import dataclasses -from typing import List +from typing import Any, List -from starkware.cairo.lang.vm.cairo_pie import ExecutionResources from starkware.starknet.business_logic.transaction_execution_objects import ( - ContractCall, Event, L2ToL1MessageInfo, TransactionExecutionInfo, ) +from starkware.starknet.services.api.feeder_gateway.response_objects import FunctionInvocation from starkware.starkware_utils.validated_dataclass import ValidatedDataclass - -@dataclasses.dataclass(frozen=True) -class StarknetContractCall(ValidatedDataclass): - """ - A lean version of ContractCall class, containing merely the information relevant for the user. - """ - - from_address: int # The caller contract address. - to_address: int # The called contract address. - calldata: List[int] - signature: List[int] - cairo_usage: ExecutionResources - @classmethod - def from_internal_version(cls, contract_call: ContractCall) -> "StarknetContractCall": - return cls( - from_address=contract_call.from_address, - to_address=contract_call.to_address, - calldata=contract_call.calldata, - signature=contract_call.signature, - cairo_usage=contract_call.cairo_usage, - ) +Dataclass = Any @dataclasses.dataclass(frozen=True) @@ -41,32 +20,27 @@ class StarknetTransactionExecutionInfo(ValidatedDataclass): """ result: tuple + call_info: FunctionInvocation # High-level events emitted by the main call through an @event decorated function. - main_call_events: List[tuple] + main_call_events: List[Dataclass] # All low-level events (emitted through emit_event syscall, including those corresponding to # high-level ones). raw_events: List[Event] l2_to_l1_messages: List[L2ToL1MessageInfo] - call_info: StarknetContractCall - internal_calls: List[StarknetContractCall] @classmethod def from_internal( cls, tx_execution_info: TransactionExecutionInfo, result: tuple, - main_call_events: List[tuple], + main_call_events: List[Dataclass], ) -> "StarknetTransactionExecutionInfo": return cls( result=result, main_call_events=main_call_events, raw_events=tx_execution_info.get_sorted_events(), - l2_to_l1_messages=tx_execution_info.l2_to_l1_messages, - call_info=StarknetContractCall.from_internal_version( - contract_call=tx_execution_info.call_info + l2_to_l1_messages=tx_execution_info.get_sorted_l2_to_l1_messages(), + call_info=FunctionInvocation.from_internal_version( + call_info=tx_execution_info.call_info ), - internal_calls=[ - StarknetContractCall.from_internal_version(contract_call=contract_call) - for contract_call in tx_execution_info.internal_calls - ], ) diff --git a/src/starkware/starknet/testing/postman.py b/src/starkware/starknet/testing/postman.py index 1c51396d..24b77151 100644 --- a/src/starkware/starknet/testing/postman.py +++ b/src/starkware/starknet/testing/postman.py @@ -21,7 +21,9 @@ def __init__( @classmethod async def create(cls, eth_test_utils: EthTestUtils): - mock_starknet_messaging_contract = eth_test_utils.accounts[0].deploy(MockStarknetMessaging) + mock_starknet_messaging_contract = eth_test_utils.accounts[0].deploy( + MockStarknetMessaging, 0 + ) starknet = await Starknet.empty() return cls( mock_starknet_messaging_contract=mock_starknet_messaging_contract, starknet=starknet @@ -32,16 +34,16 @@ async def _handle_l1_to_l2_messages(self): args = event.args await self.starknet.send_message_to_l2( - from_address=int(args["from_address"], 16), - to_address=args["to_address"], + from_address=int(args["fromAddress"], 16), + to_address=args["toAddress"], selector=args["selector"], payload=args["payload"], nonce=args["nonce"], ) self.mock_starknet_messaging_contract.mockConsumeMessageToL2.transact( - int(args["from_address"], 16), - args["to_address"], + int(args["fromAddress"], 16), + args["toAddress"], args["selector"], args["payload"], args["nonce"], diff --git a/src/starkware/starknet/testing/postman_test.py b/src/starkware/starknet/testing/postman_test.py index e22c60eb..c3eddf45 100644 --- a/src/starkware/starknet/testing/postman_test.py +++ b/src/starkware/starknet/testing/postman_test.py @@ -92,7 +92,7 @@ async def test_postman_l1_to_l2_positive_flow( async def test_postman_l1_to_l2_another_mock_starknet_messaging_contract( postman: Postman, eth_test_utils: EthTestUtils ): - other_messaging_contract = eth_test_utils.accounts[0].deploy(MockStarknetMessaging) + other_messaging_contract = eth_test_utils.accounts[0].deploy(MockStarknetMessaging, 0) INVALID_L2_ADDRESS = 0 INVALID_SELECTOR = 2 # This message is sent into another StarknetMessaging contract and therefore shouldn't be diff --git a/src/starkware/starknet/testing/starknet.py b/src/starkware/starknet/testing/starknet.py index e71f92a6..bceb580c 100644 --- a/src/starkware/starknet/testing/starknet.py +++ b/src/starkware/starknet/testing/starknet.py @@ -6,6 +6,7 @@ from starkware.starknet.services.api.contract_definition import ContractDefinition, EntryPointType from starkware.starknet.services.api.messages import StarknetMessageToL1 from starkware.starknet.testing.contract import StarknetContract +from starkware.starknet.testing.contract_utils import get_abi from starkware.starknet.testing.objects import StarknetTransactionExecutionInfo from starkware.starknet.testing.state import CastableToAddress, CastableToAddressSalt, StarknetState @@ -56,14 +57,13 @@ async def deploy( contract_address_salt=contract_address_salt, constructor_calldata=[] if constructor_calldata is None else constructor_calldata, ) - assert contract_def.abi is not None, "Missing ABI." deploy_execution_info = StarknetTransactionExecutionInfo.from_internal( tx_execution_info=execution_info, result=(), main_call_events=[] ) return StarknetContract( state=self.state, - abi=contract_def.abi, + abi=get_abi(contract_definition=contract_def), contract_address=address, deploy_execution_info=deploy_execution_info, ) @@ -85,6 +85,7 @@ async def send_message_to_l2( to_address: CastableToAddress, selector: Union[int, str], payload: List[int], + max_fee: int = 0, nonce: Optional[int] = None, ) -> TransactionExecutionInfo: """ @@ -102,6 +103,7 @@ async def send_message_to_l2( selector=selector, calldata=[from_address, *payload], caller_address=0, + max_fee=max_fee, entry_point_type=EntryPointType.L1_HANDLER, nonce=nonce, ) diff --git a/src/starkware/starknet/testing/state.py b/src/starkware/starknet/testing/state.py index 404acc1b..b44e05e1 100644 --- a/src/starkware/starknet/testing/state.py +++ b/src/starkware/starknet/testing/state.py @@ -112,6 +112,7 @@ async def invoke_raw( selector: Union[int, str], calldata: List[int], caller_address: int, + max_fee: int, signature: Optional[List[int]] = None, entry_point_type: EntryPointType = EntryPointType.EXTERNAL, nonce: Optional[int] = None, @@ -138,14 +139,15 @@ async def invoke_raw( signature = [] tx = InternalInvokeFunction.create( - general_config=self.general_config, contract_address=contract_address, entry_point_selector=selector, entry_point_type=entry_point_type, calldata=calldata, + max_fee=max_fee, signature=signature, caller_address=caller_address, nonce=nonce, + chain_id=self.general_config.chain_id.value, ) with self.state.copy_and_apply() as state_copy: @@ -154,7 +156,7 @@ async def invoke_raw( ) # Add messages. - for message in tx_execution_info.l2_to_l1_messages: + for message in tx_execution_info.get_sorted_l2_to_l1_messages(): starknet_message = StarknetMessageToL1( from_address=message.from_address, to_address=message.to_address, diff --git a/src/starkware/starknet/third_party/open_zeppelin/Account.cairo b/src/starkware/starknet/third_party/open_zeppelin/Account.cairo index b6cfc4a0..d97a3a27 100644 --- a/src/starkware/starknet/third_party/open_zeppelin/Account.cairo +++ b/src/starkware/starknet/third_party/open_zeppelin/Account.cairo @@ -1,25 +1,46 @@ +# SPDX-License-Identifier: MIT +# OpenZeppelin Cairo Contracts v0.1.0 (account/Account.cairo) + %lang starknet -%builtins pedersen range_check ecdsa +from starkware.cairo.common.alloc import alloc from starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin from starkware.cairo.common.hash_state import ( hash_finalize, hash_init, hash_update, hash_update_single) +from starkware.cairo.common.memcpy import memcpy from starkware.cairo.common.registers import get_fp_and_pc from starkware.cairo.common.signature import verify_ecdsa_signature from starkware.starknet.common.syscalls import ( - call_contract, get_caller_address, get_contract_address, get_tx_signature) + call_contract, get_caller_address, get_contract_address, get_tx_info) +from starkware.starknet.third_party.open_zeppelin.utils.constants import PREFIX_TRANSACTION # # Structs # -struct Message: - member sender : felt +struct MultiCall: + member account : felt + member calls_len : felt + member calls : Call* + member nonce : felt + member max_fee : felt + member version : felt +end + +struct Call: member to : felt member selector : felt + member calldata_len : felt member calldata : felt* - member calldata_size : felt - member nonce : felt +end + +# Tmp struct introduced while we wait for Cairo +# to support passing `[Call]` to __execute__ +struct CallArray: + member to : felt + member selector : felt + member data_offset : felt + member data_len : felt end # @@ -39,7 +60,7 @@ end # @view -func assert_only_self{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr}(): +func assert_only_self{syscall_ptr : felt*}(): let (self) = get_contract_address() let (caller) = get_caller_address() assert self = caller @@ -110,15 +131,15 @@ end @external @raw_output -func execute{ +func __execute__{ syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr : SignatureBuiltin*}( - to : felt, selector : felt, calldata_len : felt, calldata : felt*, nonce : felt) -> ( - retdata_size : felt, retdata : felt*): + call_array_len : felt, call_array : CallArray*, calldata_len : felt, calldata : felt*, + nonce : felt) -> (retdata_size : felt, retdata : felt*): alloc_locals let (__fp__, _) = get_fp_and_pc() - let (_address) = get_contract_address() + let (tx_info) = get_tx_info() let (_current_nonce) = current_nonce.read() # validate nonce @@ -126,59 +147,74 @@ func execute{ assert nonce = _current_nonce end - local message : Message = Message( - _address, - to, - selector, - calldata, - calldata_size=calldata_len, - _current_nonce + # TMP: Convert `CallArray` to 'Call'. + let (calls : Call*) = alloc() + from_call_array_to_call(call_array_len, call_array, calldata, calls) + let calls_len = call_array_len + + local multicall : MultiCall = MultiCall( + tx_info.account_contract_address, + calls_len, + calls, + _current_nonce, + tx_info.max_fee, + tx_info.version ) # validate transaction - let (hash) = hash_message(&message) - let (signature_len, signature) = get_tx_signature() - is_valid_signature(hash, signature_len, signature) + + is_valid_signature(tx_info.transaction_hash, tx_info.signature_len, tx_info.signature) # bump nonce current_nonce.write(_current_nonce + 1) # execute call - let response = call_contract( - contract_address=message.to, - function_selector=message.selector, - calldata_size=message.calldata_size, - calldata=message.calldata) + let (response : felt*) = alloc() + let (response_len) = execute_list(multicall.calls_len, multicall.calls, response) - return (retdata_size=response.retdata_size, retdata=response.retdata) + return (retdata_size=response_len, retdata=response) end -func hash_message{pedersen_ptr : HashBuiltin*}(message : Message*) -> (res : felt): +func execute_list{syscall_ptr : felt*}(calls_len : felt, calls : Call*, response : felt*) -> ( + response_len : felt): alloc_locals - # we need to make `res_calldata` local - # to prevent the reference from being revoked - let (local res_calldata) = hash_calldata(message.calldata, message.calldata_size) - let hash_ptr = pedersen_ptr - with hash_ptr: - let (hash_state_ptr) = hash_init() - # first three iterations are 'sender', 'to', and 'selector' - let (hash_state_ptr) = hash_update(hash_state_ptr, message, 3) - let (hash_state_ptr) = hash_update_single(hash_state_ptr, res_calldata) - let (hash_state_ptr) = hash_update_single(hash_state_ptr, message.nonce) - let (res) = hash_finalize(hash_state_ptr) - let pedersen_ptr = hash_ptr - return (res=res) + + # if no more calls + if calls_len == 0: + return (0) end + + # do the current call + let this_call : Call = [calls] + let res = call_contract( + contract_address=this_call.to, + function_selector=this_call.selector, + calldata_size=this_call.calldata_len, + calldata=this_call.calldata) + # copy the result in response + memcpy(response, res.retdata, res.retdata_size) + # do the next calls recursively + let (response_len) = execute_list(calls_len - 1, calls + Call.SIZE, response + res.retdata_size) + return (response_len + res.retdata_size) end -func hash_calldata{pedersen_ptr : HashBuiltin*}(calldata : felt*, calldata_size : felt) -> ( - res : felt): - let hash_ptr = pedersen_ptr - with hash_ptr: - let (hash_state_ptr) = hash_init() - let (hash_state_ptr) = hash_update(hash_state_ptr, calldata, calldata_size) - let (res) = hash_finalize(hash_state_ptr) - let pedersen_ptr = hash_ptr - return (res=res) +func from_call_array_to_call{syscall_ptr : felt*}( + call_array_len : felt, call_array : CallArray*, calldata : felt*, calls : Call*): + # if no more calls + if call_array_len == 0: + return () end + + # parse the current call + assert [calls] = Call( + to=[call_array].to, + selector=[call_array].selector, + calldata_len=[call_array].data_len, + calldata=calldata + [call_array].data_offset + ) + + # parse the remaining calls recursively + from_call_array_to_call( + call_array_len - 1, call_array + CallArray.SIZE, calldata, calls + Call.SIZE) + return () end diff --git a/src/starkware/starknet/third_party/open_zeppelin/CMakeLists.txt b/src/starkware/starknet/third_party/open_zeppelin/CMakeLists.txt index 534fbbfc..38f5a762 100644 --- a/src/starkware/starknet/third_party/open_zeppelin/CMakeLists.txt +++ b/src/starkware/starknet/third_party/open_zeppelin/CMakeLists.txt @@ -1,4 +1,4 @@ -starknet_compile(compile_open_zeppelin_account account.json Account.cairo "") +starknet_compile(compile_open_zeppelin_account account.json Account.cairo "--account_contract") python_lib(open_zeppelin_contracts_lib PREFIX starkware/starknet/third_party/open_zeppelin diff --git a/src/starkware/starknet/third_party/open_zeppelin/utils/constants.cairo b/src/starkware/starknet/third_party/open_zeppelin/utils/constants.cairo new file mode 100644 index 00000000..a49ea4e5 --- /dev/null +++ b/src/starkware/starknet/third_party/open_zeppelin/utils/constants.cairo @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: MIT +# OpenZeppelin Cairo Contracts v0.1.0 (utils/constants.cairo) + +%lang starknet + +# +# Booleans +# + +const TRUE = 1 +const FALSE = 0 + +# +# Hashing Transactions +# + +const PREFIX_TRANSACTION = 'StarkNet Transaction' + +# +# Numbers +# + +const UINT8_MAX = 256 diff --git a/src/starkware/starknet/wallets/CMakeLists.txt b/src/starkware/starknet/wallets/CMakeLists.txt index 084ddd1b..289eed71 100644 --- a/src/starkware/starknet/wallets/CMakeLists.txt +++ b/src/starkware/starknet/wallets/CMakeLists.txt @@ -6,6 +6,7 @@ python_lib(starknet_wallets_lib starknet_context.py LIBS + starknet_definitions_lib starknet_feeder_gateway_client_lib starknet_gateway_client_lib starkware_crypto_lib @@ -22,6 +23,7 @@ python_lib(starknet_standard_wallets_lib open_zeppelin_contracts_lib starknet_abi_lib starknet_definitions_lib + starknet_transaction_hash_lib starknet_transaction_lib starknet_wallets_lib starkware_crypto_lib diff --git a/src/starkware/starknet/wallets/account.py b/src/starkware/starknet/wallets/account.py index 190be364..248a9ee7 100644 --- a/src/starkware/starknet/wallets/account.py +++ b/src/starkware/starknet/wallets/account.py @@ -12,6 +12,7 @@ class WrappedMethod: address: int selector: int calldata: List[int] + max_fee: int signature: List[int] @@ -32,7 +33,13 @@ async def deploy(self): @abstractmethod async def sign_invoke_transaction( - self, contract_address: int, selector: int, calldata: List[int], nonce: Optional[int] + self, + contract_address: int, + selector: int, + calldata: List[int], + chain_id: int, + max_fee: int, + nonce: Optional[int], ) -> WrappedMethod: """ Given a transaction to execute (or call) within the context of the account, diff --git a/src/starkware/starknet/wallets/open_zeppelin.py b/src/starkware/starknet/wallets/open_zeppelin.py index db57ce63..d39ff5b8 100644 --- a/src/starkware/starknet/wallets/open_zeppelin.py +++ b/src/starkware/starknet/wallets/open_zeppelin.py @@ -3,10 +3,13 @@ import shutil from typing import List, Optional -from starkware.cairo.common.hash_state import compute_hash_on_elements from starkware.crypto.signature.signature import get_random_private_key, private_to_stark_key, sign -from starkware.starknet.definitions import fields -from starkware.starknet.public.abi import get_selector_from_name +from starkware.starknet.core.os.transaction_hash import ( + TransactionHashPrefix, + calculate_transaction_hash_common, +) +from starkware.starknet.definitions import constants, fields +from starkware.starknet.public.abi import EXECUTE_ENTRY_POINT_SELECTOR, get_selector_from_name from starkware.starknet.services.api.gateway.transaction import Deploy, InvokeFunction from starkware.starknet.third_party.open_zeppelin.starknet_contracts import account_contract from starkware.starknet.wallets.account import Account, WrappedMethod @@ -14,7 +17,6 @@ from starkware.starkware_utils.error_handling import StarkErrorCode ACCOUNT_FILE_NAME = "starknet_open_zeppelin_accounts.json" -EXECUTE_SELECTOR = get_selector_from_name("execute") GET_NONCE_SELECTOR = get_selector_from_name("get_nonce") @@ -84,6 +86,9 @@ async def deploy(self): f"""\ Sent deploy account contract transaction. +NOTE: This is a modified version of the OpenZeppelin account contract. The signature is computed +differently. + Contract address: 0x{contract_address:064x} Public key: 0x{public_key:064x} Transaction hash: {gateway_response['transaction_hash']} @@ -95,7 +100,13 @@ async def deploy(self): f.write("\n") async def sign_invoke_transaction( - self, contract_address: int, selector: int, calldata: List[int], nonce: Optional[int] + self, + contract_address: int, + selector: int, + calldata: List[int], + chain_id: int, + max_fee: Optional[int], + nonce: Optional[int], ) -> WrappedMethod: # Read the account information. assert os.path.exists(self.account_file), ( @@ -120,17 +131,31 @@ async def sign_invoke_transaction( # previous transaction was accepted. nonce = await self.get_current_nonce(account_address=account_address) - calldata_hash = compute_hash_on_elements(calldata) - message_hash = compute_hash_on_elements( - [account_address, contract_address, selector, calldata_hash, nonce] + data_offset = 0 + data_len = len(calldata) + call_entry = [contract_address, selector, data_offset, data_len] + call_array_len = 1 + wrapped_method_calldata = [call_array_len, *call_entry, len(calldata), *calldata, nonce] + max_fee = 0 if max_fee is None else max_fee + hash_value = calculate_transaction_hash_common( + tx_hash_prefix=TransactionHashPrefix.INVOKE, + version=constants.TRANSACTION_VERSION, + contract_address=account_address, + entry_point_selector=EXECUTE_ENTRY_POINT_SELECTOR, + calldata=wrapped_method_calldata, + max_fee=max_fee, + chain_id=chain_id, + additional_data=[], ) - signature = sign(msg_hash=message_hash, priv_key=private_key) + + signature = list(sign(msg_hash=hash_value, priv_key=private_key)) return WrappedMethod( address=account_address, - selector=EXECUTE_SELECTOR, - calldata=[contract_address, selector, len(calldata), *calldata, nonce], - signature=list(signature), + selector=EXECUTE_ENTRY_POINT_SELECTOR, + calldata=wrapped_method_calldata, + max_fee=max_fee, + signature=signature, ) async def get_current_nonce(self, account_address: int) -> int: @@ -138,6 +163,8 @@ async def get_current_nonce(self, account_address: int) -> int: contract_address=account_address, entry_point_selector=GET_NONCE_SELECTOR, calldata=[], + max_fee=0, + version=0, signature=[], ) res = await self.starknet_context.feeder_gateway_client.call_contract( diff --git a/src/starkware/starkware_utils/CMakeLists.txt b/src/starkware/starkware_utils/CMakeLists.txt index 503be1bc..7899a92b 100644 --- a/src/starkware/starkware_utils/CMakeLists.txt +++ b/src/starkware/starkware_utils/CMakeLists.txt @@ -8,7 +8,6 @@ python_lib(starkware_utils_lib commitment_tree/binary_fact_tree_da_utils.py commitment_tree/binary_fact_tree_node.py commitment_tree/calculation.py - commitment_tree/inner_node_fact.py commitment_tree/merkle_tree/traverse_tree.py commitment_tree/patricia_tree/virtual_calculation_node.py commitment_tree/patricia_tree/nodes.py @@ -20,6 +19,7 @@ python_lib(starkware_utils_lib ${STARKWARE_UTILS_LIBS_ADDITIONAL_FILES} LIBS + starkware_commitment_tree_facts_lib starkware_config_utils_lib starkware_custom_dict_utils_lib starkware_dataclasses_utils_lib @@ -39,6 +39,7 @@ full_python_test(patricia_tree_test FILES commitment_tree/patricia_tree/nodes_test.py commitment_tree/patricia_tree/patricia_tree_test.py + commitment_tree/patricia_tree/virtual_calculation_node_test.py commitment_tree/patricia_tree/virtual_patricia_node_test.py LIBS diff --git a/src/starkware/starkware_utils/CMakeLists_common.txt b/src/starkware/starkware_utils/CMakeLists_common.txt index 55a7648f..8c2c6d13 100644 --- a/src/starkware/starkware_utils/CMakeLists_common.txt +++ b/src/starkware/starkware_utils/CMakeLists_common.txt @@ -61,3 +61,14 @@ python_lib(starkware_config_utils_lib pip_marshmallow pip_pyyaml ) + +python_lib(starkware_commitment_tree_facts_lib + PREFIX starkware/starkware_utils + + FILES + commitment_tree/inner_node_fact.py + commitment_tree/leaf_fact.py + + LIBS + starkware_storage_lib +) diff --git a/src/starkware/starkware_utils/commitment_tree/binary_fact_tree.py b/src/starkware/starkware_utils/commitment_tree/binary_fact_tree.py index 0364be68..4e87ae17 100644 --- a/src/starkware/starkware_utils/commitment_tree/binary_fact_tree.py +++ b/src/starkware/starkware_utils/commitment_tree/binary_fact_tree.py @@ -2,15 +2,16 @@ from dataclasses import field from importlib import import_module from logging import Logger -from typing import Collection, Dict, Optional, Tuple, Type, TypeVar +from typing import Collection, Dict, Optional, Tuple, Type import marshmallow_dataclass +from starkware.starkware_utils.commitment_tree.inner_node_fact import InnerNodeFact +from starkware.starkware_utils.commitment_tree.leaf_fact import LeafFact, TLeafFact from starkware.starkware_utils.validated_dataclass import ValidatedMarshmallowDataclass from starkware.starkware_utils.validated_fields import bytes_as_hex_metadata -from starkware.storage.storage import Fact, FactFetchingContext +from starkware.storage.storage import FactFetchingContext -TFact = TypeVar("TFact", bound=Fact) BinaryFactDict = Dict[int, Tuple[int, ...]] @@ -29,29 +30,47 @@ class BinaryFactTree(ValidatedMarshmallowDataclass): @classmethod @abstractmethod async def empty_tree( - cls, ffc: FactFetchingContext, height: int, leaf_fact: Fact + cls, ffc: FactFetchingContext, height: int, leaf_fact: LeafFact ) -> "BinaryFactTree": """ Initializes an empty BinaryFactTree of the given height. """ - @abstractmethod async def get_leaves( self, ffc: FactFetchingContext, indices: Collection[int], - fact_cls: Type[TFact], + fact_cls: Type[TLeafFact], + facts: Optional[BinaryFactDict] = None, + ) -> Dict[int, TLeafFact]: + """ + Returns the values of the leaves whose indices are given. + """ + assert not issubclass(fact_cls, InnerNodeFact), ( + f"Leaf fact class object {fact_cls.__name__} must not inherit from " + f"{InnerNodeFact.__name__}." + ) + + return await self._get_leaves(ffc=ffc, indices=indices, fact_cls=fact_cls, facts=facts) + + @abstractmethod + async def _get_leaves( + self, + ffc: FactFetchingContext, + indices: Collection[int], + fact_cls: Type[TLeafFact], facts: Optional[BinaryFactDict] = None, - ) -> Dict[int, TFact]: + ) -> Dict[int, TLeafFact]: """ Returns the values of the leaves whose indices are given. + This is the implementation of the specific virtual node. """ @abstractmethod async def update( self, ffc: FactFetchingContext, - modifications: Collection[Tuple[int, Fact]], + modifications: Collection[Tuple[int, LeafFact]], facts: Optional[BinaryFactDict] = None, ) -> "BinaryFactTree": """ @@ -62,7 +81,9 @@ async def update( by the facts of their paths from the leaves up. """ - async def get_leaf(self, ffc: FactFetchingContext, index: int, fact_cls: Type[TFact]) -> TFact: + async def get_leaf( + self, ffc: FactFetchingContext, index: int, fact_cls: Type[TLeafFact] + ) -> TLeafFact: """ Returns the value of a single leaf whose index is given. """ diff --git a/src/starkware/starkware_utils/commitment_tree/binary_fact_tree_da_utils.py b/src/starkware/starkware_utils/commitment_tree/binary_fact_tree_da_utils.py index d253cf1b..ae28bee1 100644 --- a/src/starkware/starkware_utils/commitment_tree/binary_fact_tree_da_utils.py +++ b/src/starkware/starkware_utils/commitment_tree/binary_fact_tree_da_utils.py @@ -1,24 +1,66 @@ import dataclasses from dataclasses import field +from typing import Type, TypeVar from starkware.python.utils import from_bytes from starkware.starkware_utils.commitment_tree.binary_fact_tree import BinaryFactTree from starkware.starkware_utils.validated_dataclass import ValidatedDataclass from starkware.starkware_utils.validated_fields import int_as_hex_metadata +TBinaryFactTreeDiff = TypeVar("TBinaryFactTreeDiff", bound="BinaryFactTreeDiffBase") + @dataclasses.dataclass(frozen=True) -class BinaryFactTreeDiff(ValidatedDataclass): +class BinaryFactTreeDiffBase(ValidatedDataclass): initial_root: int = field(metadata=int_as_hex_metadata(validated_field=None)) final_root: int = field(metadata=int_as_hex_metadata(validated_field=None)) + + @classmethod + def from_trees( + cls: Type[TBinaryFactTreeDiff], initial_tree: BinaryFactTree, final_tree: BinaryFactTree + ) -> TBinaryFactTreeDiff: + raise NotImplementedError + + +@dataclasses.dataclass(frozen=True) +class BinaryFactTreeDiffVersion1(BinaryFactTreeDiffBase): height: int @classmethod def from_trees( - cls, initial_tree: BinaryFactTree, final_tree: BinaryFactTree - ) -> "BinaryFactTreeDiff": + cls: Type["BinaryFactTreeDiffVersion1"], + initial_tree: BinaryFactTree, + final_tree: BinaryFactTree, + ) -> "BinaryFactTreeDiffVersion1": return cls( initial_root=from_bytes(initial_tree.root), final_root=from_bytes(final_tree.root), height=final_tree.height, ) + + +@dataclasses.dataclass(frozen=True) +class BinaryFactTreeDiffVersion2(BinaryFactTreeDiffBase): + @classmethod + def from_trees( + cls: Type["BinaryFactTreeDiffVersion2"], + initial_tree: BinaryFactTree, + final_tree: BinaryFactTree, + ) -> "BinaryFactTreeDiffVersion2": + return cls( + initial_root=from_bytes(initial_tree.root), final_root=from_bytes(final_tree.root) + ) + + @classmethod + def from_v1(cls, tree_diff_v1: BinaryFactTreeDiffVersion1) -> "BinaryFactTreeDiffVersion2": + return BinaryFactTreeDiffVersion2( + initial_root=tree_diff_v1.initial_root, final_root=tree_diff_v1.final_root + ) + + def to_v1(self, height: int) -> BinaryFactTreeDiffVersion1: + return BinaryFactTreeDiffVersion1( + initial_root=self.initial_root, final_root=self.final_root, height=height + ) + + +BinaryFactTreeDiff = BinaryFactTreeDiffVersion2 diff --git a/src/starkware/starkware_utils/commitment_tree/binary_fact_tree_node.py b/src/starkware/starkware_utils/commitment_tree/binary_fact_tree_node.py index 6f6d8dcd..668c4a69 100644 --- a/src/starkware/starkware_utils/commitment_tree/binary_fact_tree_node.py +++ b/src/starkware/starkware_utils/commitment_tree/binary_fact_tree_node.py @@ -3,7 +3,7 @@ from typing import AsyncIterator, Collection, Dict, List, Optional, Tuple, Type, TypeVar from starkware.python.utils import from_bytes -from starkware.starkware_utils.commitment_tree.binary_fact_tree import BinaryFactDict, TFact +from starkware.starkware_utils.commitment_tree.binary_fact_tree import BinaryFactDict, TLeafFact from starkware.starkware_utils.commitment_tree.inner_node_fact import InnerNodeFact from starkware.starkware_utils.commitment_tree.merkle_tree.traverse_tree import traverse_tree from starkware.storage.storage import FactFetchingContext @@ -74,9 +74,9 @@ async def _get_leaves( self, ffc: FactFetchingContext, indices: Collection[int], - fact_cls: Type[TFact], + fact_cls: Type[TLeafFact], facts: Optional[BinaryFactDict] = None, - ) -> Dict[int, TFact]: + ) -> Dict[int, TLeafFact]: """ Returns the values of the leaves whose indices are given. @@ -86,29 +86,32 @@ async def _get_leaves( This method is to be called by a get_leaves() method of a specific tree implementation (derived class of BinaryFactTree). """ - assert not issubclass(fact_cls, InnerNodeFact), ( - f"Leaf fact class object {fact_cls.__name__} must not inherit from " - f"{InnerNodeFact.__name__}." - ) - - def unify_leaves( - left_leaves: Dict[int, TFact], right_leaves: Dict[int, TFact] - ) -> Dict[int, TFact]: - return {**left_leaves, **{x + mid: y for x, y in right_leaves.items()}} - if len(indices) == 0: return {} if self.is_leaf: - assert set(indices) == {0}, f"Commitment tree indices out of range: {indices}." - leaf = await fact_cls.get_or_fail(storage=ffc.storage, suffix=self.leaf_hash) + return await self._get_leaf(ffc=ffc, indices=indices, fact_cls=fact_cls) - return {0: leaf} + return await self._get_binary_node_leaves( + ffc=ffc, indices=indices, fact_cls=fact_cls, facts=facts + ) + async def _get_binary_node_leaves( + self, + ffc: FactFetchingContext, + indices: Collection[int], + fact_cls: Type[TLeafFact], + facts: Optional[BinaryFactDict], + ) -> Dict[int, TLeafFact]: + """ + Returns the values of the leaves whose indices are given. + """ + # Partition indices. mid = 2 ** (self.get_height_in_tree() - 1) left_indices = [index for index in indices if index < mid] right_indices = [(index - mid) for index in indices if index >= mid] + # Get children. left_child, right_child = await self.get_children(ffc=ffc, facts=facts) # Optimizations in order to avoid a redundant asyncio.gather call that postpones the @@ -117,33 +120,34 @@ def unify_leaves( right_leaves = await right_child._get_leaves( ffc=ffc, indices=right_indices, fact_cls=fact_cls, facts=facts ) - return unify_leaves(right_leaves=right_leaves, left_leaves={}) + return unify_binary_leaves(middle_index=mid, right_leaves=right_leaves, left_leaves={}) if len(right_indices) == 0: left_leaves = await left_child._get_leaves( ffc=ffc, indices=left_indices, fact_cls=fact_cls, facts=facts ) - return unify_leaves(right_leaves={}, left_leaves=left_leaves) + return unify_binary_leaves(middle_index=mid, right_leaves={}, left_leaves=left_leaves) left_leaves, right_leaves = await asyncio.gather( left_child._get_leaves(ffc=ffc, indices=left_indices, fact_cls=fact_cls, facts=facts), right_child._get_leaves(ffc=ffc, indices=right_indices, fact_cls=fact_cls, facts=facts), ) - - return unify_leaves(left_leaves=left_leaves, right_leaves=right_leaves) + return unify_binary_leaves( + middle_index=mid, left_leaves=left_leaves, right_leaves=right_leaves + ) async def get_diff_between_trees( self, other: TBinaryFactTreeNode, ffc: FactFetchingContext, - fact_cls: Type[TFact], + fact_cls: Type[TLeafFact], facts: Optional[BinaryFactDict] = None, - ) -> List[Tuple[int, TFact, TFact]]: + ) -> List[Tuple[int, TLeafFact, TLeafFact]]: """ Returns a list of (key, old_fact, new_fact) that are different between this tree and another. - The height of the two trees must be equel. + The height of the two trees must be equal. If the 'facts' argument is not None, this dictionary is filled with facts read from the DB. """ @@ -151,7 +155,7 @@ async def get_diff_between_trees( f"Tree heights must be equal. Got: {other.get_height_in_tree()} for 'other'; " f"expected: {self.get_height_in_tree()}." ) - result: List[Tuple[int, TFact, TFact]] = [] + result: List[Tuple[int, TLeafFact, TLeafFact]] = [] async def get_children_callback( node: _BinaryFactTreeDiff, @@ -186,6 +190,23 @@ async def get_children_callback( return result + async def _get_leaf( + self, ffc: FactFetchingContext, indices: Collection[int], fact_cls: Type[TLeafFact] + ) -> Dict[int, TLeafFact]: + assert set(indices) == {0}, f"Commitment tree indices out of range: {indices}." + leaf = await fact_cls.get_or_fail(storage=ffc.storage, suffix=self.leaf_hash) + + return {0: leaf} + + +# Utilities. + + +def unify_binary_leaves( + middle_index: int, left_leaves: Dict[int, TLeafFact], right_leaves: Dict[int, TLeafFact] +) -> Dict[int, TLeafFact]: + return {**left_leaves, **{x + middle_index: y for x, y in right_leaves.items()}} + async def read_node_fact( ffc: FactFetchingContext, diff --git a/src/starkware/starkware_utils/commitment_tree/calculation.py b/src/starkware/starkware_utils/commitment_tree/calculation.py index a1e90f90..dda69644 100644 --- a/src/starkware/starkware_utils/commitment_tree/calculation.py +++ b/src/starkware/starkware_utils/commitment_tree/calculation.py @@ -10,6 +10,7 @@ TBinaryFactTreeNode, TInnerNodeFact, ) +from starkware.starkware_utils.commitment_tree.leaf_fact import LeafFact from starkware.storage.storage import FactFetchingContext, HashFunctionType T = TypeVar("T") @@ -167,6 +168,31 @@ def get_dependency_calculations(self) -> List[Calculation[T]]: HashCalculation = Calculation[bytes] +# NOTE: We avoid using ValidatedDataclass here for performance. +@dataclasses.dataclass(frozen=True) +class LeafFactCalculation(HashCalculation): + """ + A calculation that contains a LeafFact and produces its hash. It doesn't depend on any other + calculations. + """ + + fact: LeafFact + + def calculate( + self, + dependency_results: list, + hash_func: HashFunctionType, + fact_nodes: NodeFactDict, + ) -> bytes: + assert len(dependency_results) == 0, "LeafFactCalculation has no dependencies." + hash_result = self.fact._hash(hash_func=hash_func) + fact_nodes[hash_result] = self.fact + return hash_result + + def get_dependency_calculations(self) -> List[Calculation[bytes]]: + return [] + + class CalculationNode(Calculation[TBinaryFactTreeNode], ABC): """ A calculation that produces a BinaryFactTreeNode. The calculation can be created from either a @@ -190,8 +216,20 @@ async def combine( @classmethod @abstractmethod - def create(cls: Type[TCalculationNode], node: TBinaryFactTreeNode) -> TCalculationNode: + def create_from_node( + cls: Type[TCalculationNode], node: TBinaryFactTreeNode + ) -> TCalculationNode: """ Creates a Calculation object from a node. It will produce the node and will have no dependencies. + This will be used in order to create calculations that represent unchanged subtrees. + """ + + @classmethod + @abstractmethod + def create_from_fact(cls: Type[TCalculationNode], fact: LeafFact) -> TCalculationNode: + """ + Creates a Calculation object from a fact. It will calculate the fact's hash and produce a + node with the hash result. It will have no dependencies. + This will be used in order to create calculations that represent changed leaves. """ diff --git a/src/starkware/starkware_utils/commitment_tree/leaf_fact.py b/src/starkware/starkware_utils/commitment_tree/leaf_fact.py new file mode 100644 index 00000000..6f8ee251 --- /dev/null +++ b/src/starkware/starkware_utils/commitment_tree/leaf_fact.py @@ -0,0 +1,20 @@ +from abc import abstractmethod +from typing import TypeVar + +from starkware.storage.storage import Fact + + +class LeafFact(Fact): + """ + A fact that represents a leaf in a commitment tree. + """ + + @property + @abstractmethod + def is_empty(self) -> bool: + """ + Returns true iff the fact represents a leaf that has no value or was deleted. + """ + + +TLeafFact = TypeVar("TLeafFact", bound=LeafFact) diff --git a/src/starkware/starkware_utils/commitment_tree/merkle_tree/traverse_tree.py b/src/starkware/starkware_utils/commitment_tree/merkle_tree/traverse_tree.py index 4ee04669..c63af553 100644 --- a/src/starkware/starkware_utils/commitment_tree/merkle_tree/traverse_tree.py +++ b/src/starkware/starkware_utils/commitment_tree/merkle_tree/traverse_tree.py @@ -49,7 +49,7 @@ async def worker_func(): finally: queue.task_done() - async def closer(): + async def closer(n_workers: int): # Wait for all tasks to be marked with task_done. This guarantees that all tasks were # completed, and no new task will be created. await queue.join() @@ -57,5 +57,6 @@ async def closer(): for _ in range(n_workers): await queue.put(AbortWorker()) - await asyncio.gather(closer(), *(worker_func() for _ in range(n_workers))) + assert n_workers is not None + await asyncio.gather(closer(n_workers=n_workers), *(worker_func() for _ in range(n_workers))) assert queue.empty() diff --git a/src/starkware/starkware_utils/commitment_tree/patricia_tree/patricia_tree.py b/src/starkware/starkware_utils/commitment_tree/patricia_tree/patricia_tree.py index 898f6b55..c745d3b0 100644 --- a/src/starkware/starkware_utils/commitment_tree/patricia_tree/patricia_tree.py +++ b/src/starkware/starkware_utils/commitment_tree/patricia_tree/patricia_tree.py @@ -5,8 +5,9 @@ from starkware.starkware_utils.commitment_tree.binary_fact_tree import ( BinaryFactDict, BinaryFactTree, - TFact, + TLeafFact, ) +from starkware.starkware_utils.commitment_tree.leaf_fact import LeafFact from starkware.starkware_utils.commitment_tree.patricia_tree.nodes import EmptyNodeFact from starkware.starkware_utils.commitment_tree.patricia_tree.virtual_calculation_node import ( VirtualCalculationNode, @@ -15,7 +16,7 @@ VirtualPatriciaNode, ) from starkware.starkware_utils.commitment_tree.update_tree import update_tree -from starkware.storage.storage import Fact, FactFetchingContext +from starkware.storage.storage import FactFetchingContext @marshmallow_dataclass.dataclass(frozen=True) @@ -26,7 +27,7 @@ class PatriciaTree(BinaryFactTree): @classmethod async def empty_tree( - cls, ffc: FactFetchingContext, height: int, leaf_fact: Fact + cls, ffc: FactFetchingContext, height: int, leaf_fact: LeafFact ) -> "PatriciaTree": """ Initializes an empty PatriciaTree of the given height. @@ -39,13 +40,13 @@ async def empty_tree( return PatriciaTree(root=EmptyNodeFact.EMPTY_NODE_HASH, height=height) - async def get_leaves( + async def _get_leaves( self, ffc: FactFetchingContext, indices: Collection[int], - fact_cls: Type[TFact], + fact_cls: Type[TLeafFact], facts: Optional[BinaryFactDict] = None, - ) -> Dict[int, TFact]: + ) -> Dict[int, TLeafFact]: """ Returns the values of the leaves whose indices are given. """ @@ -57,7 +58,7 @@ async def get_leaves( async def update( self, ffc: FactFetchingContext, - modifications: Collection[Tuple[int, Fact]], + modifications: Collection[Tuple[int, LeafFact]], facts: Optional[BinaryFactDict] = None, ) -> "PatriciaTree": """ diff --git a/src/starkware/starkware_utils/commitment_tree/patricia_tree/patricia_tree_test.py b/src/starkware/starkware_utils/commitment_tree/patricia_tree/patricia_tree_test.py index 98f42330..8526a474 100644 --- a/src/starkware/starkware_utils/commitment_tree/patricia_tree/patricia_tree_test.py +++ b/src/starkware/starkware_utils/commitment_tree/patricia_tree/patricia_tree_test.py @@ -16,7 +16,7 @@ ) from starkware.starkware_utils.commitment_tree.patricia_tree.patricia_tree import PatriciaTree from starkware.storage.storage import FactFetchingContext -from starkware.storage.storage_utils import LeafFact +from starkware.storage.storage_utils import SimpleLeafFact from starkware.storage.test_utils import MockStorage @@ -86,11 +86,13 @@ async def test_update_and_decommit( """ Builds a Patricia tree using update(), and tests that the facts stored suffice to decommit. """ - tree = await PatriciaTree.empty_tree(ffc=ffc, height=height, leaf_fact=LeafFact(value=0)) + tree = await PatriciaTree.empty_tree(ffc=ffc, height=height, leaf_fact=SimpleLeafFact(value=0)) # Create some random modifications, store the facts and update the tree. # Note that leaves with value 0 are not modifications (hence, range(1, ...)). - leaves = [LeafFact(value=value) for value in random_object.choices(range(1, 1000), k=n_leaves)] + leaves = [ + SimpleLeafFact(value=value) for value in random_object.choices(range(1, 1000), k=n_leaves) + ] leaf_hashes_bytes = await asyncio.gather(*(leaf_fact.set_fact(ffc=ffc) for leaf_fact in leaves)) leaf_hashes = [from_bytes(leaf_hash_bytes) for leaf_hash_bytes in leaf_hashes_bytes] indices = random_object.sample(range(2 ** height), k=n_leaves) diff --git a/src/starkware/starkware_utils/commitment_tree/patricia_tree/virtual_calculation_node.py b/src/starkware/starkware_utils/commitment_tree/patricia_tree/virtual_calculation_node.py index 1409959d..70fd06db 100644 --- a/src/starkware/starkware_utils/commitment_tree/patricia_tree/virtual_calculation_node.py +++ b/src/starkware/starkware_utils/commitment_tree/patricia_tree/virtual_calculation_node.py @@ -8,8 +8,10 @@ CalculationNode, ConstantCalculation, HashCalculation, + LeafFactCalculation, NodeFactDict, ) +from starkware.starkware_utils.commitment_tree.leaf_fact import LeafFact from starkware.starkware_utils.commitment_tree.patricia_tree.nodes import ( BinaryNodeFact, EdgeNodeFact, @@ -94,7 +96,7 @@ def __post_init__(self): verify_path_value(path=self.path, length=self.length) @classmethod - def create(cls, node: VirtualPatriciaNode): + def create_from_node(cls, node: VirtualPatriciaNode): if node.is_empty: return cls.empty_node(height=node.height) @@ -105,6 +107,10 @@ def create(cls, node: VirtualPatriciaNode): height=node.height, ) + @classmethod + def create_from_fact(cls, fact: LeafFact): + return cls(bottom_calculation=LeafFactCalculation(fact=fact), path=0, length=0, height=0) + @classmethod def empty_node(cls, height: int) -> "VirtualCalculationNode": return cls( @@ -120,6 +126,9 @@ def is_empty(self) -> bool: if isinstance(self.bottom_calculation, ConstantCalculation): return self.bottom_calculation.value == EmptyNodeFact.EMPTY_NODE_HASH + if isinstance(self.bottom_calculation, LeafFactCalculation): + return self.bottom_calculation.fact.is_empty + return False @property diff --git a/src/starkware/starkware_utils/commitment_tree/patricia_tree/virtual_calculation_node_test.py b/src/starkware/starkware_utils/commitment_tree/patricia_tree/virtual_calculation_node_test.py new file mode 100644 index 00000000..53bdaf7e --- /dev/null +++ b/src/starkware/starkware_utils/commitment_tree/patricia_tree/virtual_calculation_node_test.py @@ -0,0 +1,121 @@ +import random + +import pytest + +from starkware.starkware_utils.commitment_tree.calculation import ConstantCalculation +from starkware.starkware_utils.commitment_tree.patricia_tree.virtual_calculation_node import ( + BinaryCalculation, + EdgeCalculation, + VirtualCalculationNode, +) +from starkware.storage.storage import FactFetchingContext +from starkware.storage.storage_utils import SimpleLeafFact +from starkware.storage.test_utils import MockStorage, hash_func + + +@pytest.fixture +def ffc() -> FactFetchingContext: + return FactFetchingContext(storage=MockStorage(), hash_func=hash_func) + + +@pytest.fixture +async def leaf_calculation(ffc: FactFetchingContext) -> ConstantCalculation: + leaf_hash = await SimpleLeafFact(value=random.randrange(1, 100)).set_fact(ffc=ffc) + return ConstantCalculation(leaf_hash) + + +@pytest.fixture +async def leaf_calculation2(ffc: FactFetchingContext) -> ConstantCalculation: + leaf_hash = await SimpleLeafFact(value=random.randrange(100, 200)).set_fact(ffc=ffc) + return ConstantCalculation(leaf_hash) + + +def test_invalid_length(leaf_calculation: ConstantCalculation): + with pytest.raises(AssertionError, match="Edge path must be at most of length 0; got: 0b1."): + VirtualCalculationNode(bottom_calculation=leaf_calculation, path=1, length=0, height=1) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("height", [0, 7]) +async def test_combine_two_empty(ffc: FactFetchingContext, height: int): + child = VirtualCalculationNode.empty_node(height=height) + parent = await VirtualCalculationNode.combine(ffc=ffc, left=child, right=child, facts=None) + assert parent == VirtualCalculationNode.empty_node(height=height + 1) + + +@pytest.mark.asyncio +async def test_combine_unmatching_height(ffc: FactFetchingContext): + empty_node_0 = VirtualCalculationNode.empty_node(height=0) + empty_node_1 = VirtualCalculationNode.empty_node(height=1) + with pytest.raises( + AssertionError, match="Only trees of same height can be combined; got: left=0 right=1." + ): + await VirtualCalculationNode.combine( + ffc=ffc, left=empty_node_0, right=empty_node_1, facts=None + ) + + +@pytest.mark.asyncio +async def test_combine_left_empty_right_leaf( + ffc: FactFetchingContext, leaf_calculation: ConstantCalculation +): + left = VirtualCalculationNode.empty_node(height=0) + right = VirtualCalculationNode(bottom_calculation=leaf_calculation, path=0, length=0, height=0) + parent = await VirtualCalculationNode.combine(ffc=ffc, left=left, right=right, facts=None) + assert parent == VirtualCalculationNode( + bottom_calculation=leaf_calculation, path=1, length=1, height=1 + ) + + +@pytest.mark.asyncio +async def test_combine_left_leaf_right_empty( + ffc: FactFetchingContext, leaf_calculation: ConstantCalculation +): + left = VirtualCalculationNode(bottom_calculation=leaf_calculation, path=0, length=0, height=0) + right = VirtualCalculationNode.empty_node(height=0) + parent = await VirtualCalculationNode.combine(ffc=ffc, left=left, right=right, facts=None) + assert parent == VirtualCalculationNode( + bottom_calculation=leaf_calculation, path=0, length=1, height=1 + ) + + +@pytest.mark.asyncio +async def test_combine_left_empty_right_virtual_edge( + ffc: FactFetchingContext, leaf_calculation: ConstantCalculation +): + left = VirtualCalculationNode.empty_node(height=1) + right = VirtualCalculationNode(bottom_calculation=leaf_calculation, path=0, length=1, height=1) + parent = await VirtualCalculationNode.combine(ffc=ffc, left=left, right=right, facts=None) + assert parent == VirtualCalculationNode( + bottom_calculation=leaf_calculation, path=0b10, length=2, height=2 + ) + + +@pytest.mark.asyncio +async def test_combine_left_virtual_edge_right_empty( + ffc: FactFetchingContext, leaf_calculation: ConstantCalculation +): + left = VirtualCalculationNode(bottom_calculation=leaf_calculation, path=1, length=1, height=1) + right = VirtualCalculationNode.empty_node(height=1) + parent = await VirtualCalculationNode.combine(ffc=ffc, left=left, right=right, facts=None) + assert parent == VirtualCalculationNode( + bottom_calculation=leaf_calculation, path=0b01, length=2, height=2 + ) + + +@pytest.mark.asyncio +async def test_combine_two_virtual_edges( + ffc: FactFetchingContext, leaf_calculation: ConstantCalculation, leaf_calculation2 +): + left = VirtualCalculationNode(bottom_calculation=leaf_calculation, path=1, length=1, height=1) + right = VirtualCalculationNode(bottom_calculation=leaf_calculation2, path=0, length=1, height=1) + parent = await VirtualCalculationNode.combine(ffc=ffc, left=left, right=right, facts=None) + assert parent == VirtualCalculationNode( + bottom_calculation=BinaryCalculation( + left=EdgeCalculation(bottom=leaf_calculation, path=1, length=1), + right=EdgeCalculation(bottom=leaf_calculation2, path=0, length=1), + ), + path=0, + length=0, + height=2, + ) diff --git a/src/starkware/starkware_utils/commitment_tree/patricia_tree/virtual_patricia_node.py b/src/starkware/starkware_utils/commitment_tree/patricia_tree/virtual_patricia_node.py index d226dc36..e410892c 100644 --- a/src/starkware/starkware_utils/commitment_tree/patricia_tree/virtual_patricia_node.py +++ b/src/starkware/starkware_utils/commitment_tree/patricia_tree/virtual_patricia_node.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Optional, Tuple +from typing import Collection, Dict, Optional, Tuple, Type from starkware.starkware_utils.commitment_tree.binary_fact_tree import BinaryFactDict from starkware.starkware_utils.commitment_tree.binary_fact_tree_node import ( @@ -7,6 +7,7 @@ read_node_fact, write_node_fact, ) +from starkware.starkware_utils.commitment_tree.leaf_fact import TLeafFact from starkware.starkware_utils.commitment_tree.patricia_tree.nodes import ( BinaryNodeFact, EdgeNodeFact, @@ -172,3 +173,81 @@ def __eq__(self, other: object) -> bool: and self.length == other.length and self.height == other.height ) + + async def _get_leaves( + self, + ffc: FactFetchingContext, + indices: Collection[int], + fact_cls: Type[TLeafFact], + facts: Optional[BinaryFactDict] = None, + ) -> Dict[int, TLeafFact]: + """ + See base class for documentation. + """ + if len(indices) == 0: + return {} + + if self.is_leaf: + return await self._get_leaf(ffc=ffc, indices=indices, fact_cls=fact_cls) + + if self.is_virtual_edge: + return await self._get_edge_node_leaves( + ffc=ffc, indices=indices, fact_cls=fact_cls, facts=facts + ) + + return await self._get_binary_node_leaves( + ffc=ffc, indices=indices, fact_cls=fact_cls, facts=facts + ) + + async def _get_edge_node_leaves( + self, + ffc: FactFetchingContext, + indices: Collection[int], + fact_cls: Type[TLeafFact], + facts: Optional[BinaryFactDict] = None, + ) -> Dict[int, TLeafFact]: + """ + Returns the values of the leaves whose indices are given. + """ + # Partition indices. + path_suffix_width = self.height - self.length + path_prefix = self.path << path_suffix_width + bottom_subtree_indices = [ + index - path_prefix for index in indices if (index >> path_suffix_width) == self.path + ] + empty_indices = [index for index in indices if (index >> path_suffix_width) != self.path] + + # Get bottom subtree root. + bottom_subtree_root = self.from_hash(hash_value=self.bottom_node, height=path_suffix_width) + bottom_subtree_leaves = await bottom_subtree_root._get_leaves( + ffc=ffc, indices=bottom_subtree_indices, fact_cls=fact_cls, facts=facts + ) + empty_leaves = await get_empty_leaves(ffc=ffc, indices=empty_indices, fact_cls=fact_cls) + return unify_edge_leaves( + path_prefix=path_prefix, + bottom_subtree_leaves=bottom_subtree_leaves, + empty_leaves=empty_leaves, + ) + + +# Utilities. + + +async def get_empty_leaves( + ffc: FactFetchingContext, indices: Collection[int], fact_cls: Type[TLeafFact] +) -> Dict[int, TLeafFact]: + if len(indices) == 0: + return {} + + empty_leaf = await fact_cls.get_or_fail( + storage=ffc.storage, suffix=EmptyNodeFact.EMPTY_NODE_HASH + ) + return {index: empty_leaf for index in indices} + + +def unify_edge_leaves( + path_prefix: int, + bottom_subtree_leaves: Dict[int, TLeafFact], + empty_leaves: Dict[int, TLeafFact], +) -> Dict[int, TLeafFact]: + return {**empty_leaves, **{x + path_prefix: y for x, y in bottom_subtree_leaves.items()}} diff --git a/src/starkware/starkware_utils/commitment_tree/patricia_tree/virtual_patricia_node_test.py b/src/starkware/starkware_utils/commitment_tree/patricia_tree/virtual_patricia_node_test.py index 03bb708a..977c72df 100644 --- a/src/starkware/starkware_utils/commitment_tree/patricia_tree/virtual_patricia_node_test.py +++ b/src/starkware/starkware_utils/commitment_tree/patricia_tree/virtual_patricia_node_test.py @@ -1,11 +1,13 @@ import asyncio -from typing import Collection +import random +from typing import Collection, Dict import pytest from starkware.cairo.common.patricia_utils import compute_patricia_from_leaves_for_test from starkware.crypto.signature.fast_pedersen_hash import pedersen_hash, pedersen_hash_func -from starkware.python.utils import to_bytes +from starkware.python.random_test import random_test +from starkware.python.utils import safe_zip, to_bytes from starkware.starkware_utils.commitment_tree.patricia_tree.nodes import ( BinaryNodeFact, EdgeNodeFact, @@ -19,15 +21,20 @@ ) from starkware.starkware_utils.commitment_tree.update_tree import update_tree from starkware.storage.storage import FactFetchingContext -from starkware.storage.storage_utils import LeafFact +from starkware.storage.storage_utils import SimpleLeafFact from starkware.storage.test_utils import MockStorage +# Fixtures. + @pytest.fixture def ffc() -> FactFetchingContext: return FactFetchingContext(storage=MockStorage(), hash_func=pedersen_hash_func) +# Utilities. + + async def make_virtual_edge_non_canonical( ffc: FactFetchingContext, node: VirtualPatriciaNode ) -> VirtualPatriciaNode: @@ -47,6 +54,42 @@ def verify_root(leaves: Collection[int], expected_root_hash: bytes): assert expected_root_hash == to_bytes(root_hash) +async def build_empty_patricia_virtual_node( + ffc: FactFetchingContext, height: int +) -> VirtualPatriciaNode: + # Done manually, since PatriciaTree.empty() is in charge of that and is not used here. + await SimpleLeafFact.empty().set_fact(ffc=ffc) + + # Build empty tree. + return VirtualPatriciaNode.empty_node(height=height) + + +async def build_patricia_virtual_node( + ffc: FactFetchingContext, height: int, leaves: Dict[int, SimpleLeafFact] +) -> VirtualPatriciaNode: + # Build empty tree. + tree = await build_empty_patricia_virtual_node(ffc=ffc, height=height) + return await update_tree( + tree=tree, + ffc=ffc, + modifications=leaves.items(), + calculation_node_cls=VirtualCalculationNode, + ) + + +async def sample_and_verify_leaf_values( + ffc: FactFetchingContext, tree: VirtualPatriciaNode, expected_leaves: Dict[int, SimpleLeafFact] +): + sampled_indices = list(expected_leaves.keys()) + actual_leaves = await tree._get_leaves( + ffc=ffc, indices=sampled_indices, fact_cls=SimpleLeafFact + ) + assert actual_leaves == expected_leaves + + +# Tests. + + @pytest.mark.asyncio async def test_get_children(ffc: FactFetchingContext): """ @@ -58,14 +101,16 @@ async def test_get_children(ffc: FactFetchingContext): # 0 12 0 0 0 0 30 0 """ # Create empty trees and write their facts to DB. - await LeafFact(value=0).set_fact(ffc=ffc) - empty_tree_0 = VirtualPatriciaNode.empty_node(height=0) - empty_tree_1 = VirtualPatriciaNode.empty_node(height=1) + empty_tree_0 = await build_empty_patricia_virtual_node(ffc=ffc, height=0) + empty_tree_1 = await build_empty_patricia_virtual_node(ffc=ffc, height=1) assert await empty_tree_1.get_children(ffc=ffc) == (empty_tree_0, empty_tree_0) # Create leaves and write their facts to DB. leaf_hash_12, leaf_hash_30 = await asyncio.gather( - *(leaf_fact.set_fact(ffc=ffc) for leaf_fact in (LeafFact(value=12), LeafFact(value=30))) + *( + leaf_fact.set_fact(ffc=ffc) + for leaf_fact in (SimpleLeafFact(value=12), SimpleLeafFact(value=30)) + ) ) leaf_12 = VirtualPatriciaNode(bottom_node=leaf_hash_12, path=0, length=0, height=0) leaf_30 = VirtualPatriciaNode(bottom_node=leaf_hash_30, path=0, length=0, height=0) @@ -131,18 +176,19 @@ async def test_update_and_get_leaves(ffc: FactFetchingContext): Builds a Patricia tree of length 3 with the following values in the leaves: 1 -> 12, 6 -> 30. This is the same tree as in the test above, but in this test built using _update(). """ - # Done manually, since PatriciaTree.empty() is in charge of that and is not used here. - await LeafFact(value=0).set_fact(ffc=ffc) - # Build empty tree. - tree = VirtualPatriciaNode.empty_node(height=3) + tree = await build_empty_patricia_virtual_node(ffc=ffc, height=3) # Compare empty root to test util result. leaves_range = range(8) verify_root(leaves=[0 for _ in leaves_range], expected_root_hash=tree.bottom_node) # Update leaf values. - leaves = {1: LeafFact(value=12), 4: LeafFact(value=1000), 6: LeafFact(value=30)} + leaves = { + 1: SimpleLeafFact(value=12), + 4: SimpleLeafFact(value=1000), + 6: SimpleLeafFact(value=30), + } tree = await update_tree( tree=tree, ffc=ffc, @@ -152,12 +198,10 @@ async def test_update_and_get_leaves(ffc: FactFetchingContext): # Check get_leaves(). expected_leaves = { - leaf_id: leaves[leaf_id] if leaf_id in leaves else LeafFact(value=0) + leaf_id: leaves[leaf_id] if leaf_id in leaves else SimpleLeafFact.empty() for leaf_id in leaves_range } - assert ( - await tree._get_leaves(ffc=ffc, indices=leaves_range, fact_cls=LeafFact) == expected_leaves - ) + await sample_and_verify_leaf_values(ffc=ffc, tree=tree, expected_leaves=expected_leaves) # Compare to test util result. verify_root( @@ -167,10 +211,10 @@ async def test_update_and_get_leaves(ffc: FactFetchingContext): # Update leaf values again: new leaves contain addition, deletion and updating a key. updated_leaves = { - 0: LeafFact(value=2), - 1: LeafFact(value=20), - 3: LeafFact(value=6), - 6: LeafFact(value=0), + 0: SimpleLeafFact(value=2), + 1: SimpleLeafFact(value=20), + 3: SimpleLeafFact(value=6), + 6: SimpleLeafFact.empty(), } tree = await update_tree( tree=tree, @@ -181,9 +225,7 @@ async def test_update_and_get_leaves(ffc: FactFetchingContext): # Check get_leaves(). updated_leaves = {**expected_leaves, **updated_leaves} - assert ( - await tree._get_leaves(ffc=ffc, indices=leaves_range, fact_cls=LeafFact) == updated_leaves - ) + await sample_and_verify_leaf_values(ffc=ffc, tree=tree, expected_leaves=updated_leaves) # Compare to test util result. sorted_by_index_leaf_values = [updated_leaves[leaf_id].value for leaf_id in leaves_range] @@ -194,13 +236,15 @@ async def test_update_and_get_leaves(ffc: FactFetchingContext): @pytest.mark.asyncio async def test_binary_fact_tree_node_create_diff(ffc: FactFetchingContext): # All tree values ​​are zero. - empty_tree = await PatriciaTree.empty_tree(ffc=ffc, height=251, leaf_fact=LeafFact(value=0)) + empty_tree = await PatriciaTree.empty_tree( + ffc=ffc, height=251, leaf_fact=SimpleLeafFact.empty() + ) virtual_empty_tree_node = VirtualPatriciaNode.from_hash( hash_value=empty_tree.root, height=empty_tree.height ) # All tree values ​​are zero except for the fifth leaf, which has a value of 8. - one_change_tree = await empty_tree.update(ffc=ffc, modifications=[(5, LeafFact(value=8))]) + one_change_tree = await empty_tree.update(ffc=ffc, modifications=[(5, SimpleLeafFact(value=8))]) virtual_one_change_node = VirtualPatriciaNode.from_hash( hash_value=one_change_tree.root, height=empty_tree.height ) @@ -208,7 +252,7 @@ async def test_binary_fact_tree_node_create_diff(ffc: FactFetchingContext): # All tree values ​​are zero except for the fifth leaf, which has a value of 8. # and the 58th leaf, which is 81. two_change_tree = await one_change_tree.update( - ffc=ffc, modifications=[(58, LeafFact(value=81))] + ffc=ffc, modifications=[(58, SimpleLeafFact(value=81))] ) virtual_two_change_node = VirtualPatriciaNode.from_hash( hash_value=two_change_tree.root, height=empty_tree.height @@ -217,17 +261,45 @@ async def test_binary_fact_tree_node_create_diff(ffc: FactFetchingContext): # The difference between the tree whose values are all zero and the tree that has # all values zero except two values is exactly the 2 values. diff_result = await virtual_empty_tree_node.get_diff_between_trees( - other=virtual_two_change_node, ffc=ffc, fact_cls=LeafFact + other=virtual_two_change_node, ffc=ffc, fact_cls=SimpleLeafFact ) assert diff_result == [ - (5, LeafFact(value=0), LeafFact(value=8)), - (58, LeafFact(value=0), LeafFact(value=81)), + (5, SimpleLeafFact.empty(), SimpleLeafFact(value=8)), + (58, SimpleLeafFact.empty(), SimpleLeafFact(value=81)), ] # The difference between the tree whose values are zero except for the fifth leaf # and the tree whose values are all zero except for the fifth leaf (there they are equal) # and for the 58th leaf is exactly the 58th leaf. diff_result = await virtual_one_change_node.get_diff_between_trees( - other=virtual_two_change_node, ffc=ffc, fact_cls=LeafFact + other=virtual_two_change_node, ffc=ffc, fact_cls=SimpleLeafFact + ) + assert diff_result == [(58, SimpleLeafFact.empty(), SimpleLeafFact(value=81))] + + +@random_test() +@pytest.mark.asyncio +async def test_get_leaves(seed: int, ffc: FactFetchingContext): + # Build random tree. + height = 100 + n_leaves = random.randint(1, 5) * 100 + leaf_values = random.choices(range(1, 1000), k=n_leaves) + leaf_indices = [random.getrandbits(height) for _ in range(n_leaves)] + leaves = dict(safe_zip(leaf_indices, (SimpleLeafFact(value=value) for value in leaf_values))) + tree = await build_patricia_virtual_node(ffc=ffc, height=height, leaves=leaves) + + # Sample random subset of initialized leaves. + n_sampled_leaves = random.randint(1, n_leaves) + sampled_indices = random.sample(leaf_indices, k=n_sampled_leaves) + await sample_and_verify_leaf_values( + ffc=ffc, + tree=tree, + expected_leaves={index: leaf for index, leaf in leaves.items() if index in sampled_indices}, + ) + + # Sample random subset of empty leaves (almost zero prob. they will land on initialize ones). + empty_leaf = SimpleLeafFact.empty() + sampled_indices = [random.getrandbits(height) for _ in range(10)] + await sample_and_verify_leaf_values( + ffc=ffc, tree=tree, expected_leaves={index: empty_leaf for index in sampled_indices} ) - assert diff_result == [(58, LeafFact(value=0), LeafFact(value=81))] diff --git a/src/starkware/starkware_utils/commitment_tree/update_tree.py b/src/starkware/starkware_utils/commitment_tree/update_tree.py index 81a9e0ab..1d807b66 100644 --- a/src/starkware/starkware_utils/commitment_tree/update_tree.py +++ b/src/starkware/starkware_utils/commitment_tree/update_tree.py @@ -7,13 +7,15 @@ TBinaryFactTreeNode, ) from starkware.starkware_utils.commitment_tree.calculation import CalculationNode, NodeFactDict +from starkware.starkware_utils.commitment_tree.inner_node_fact import InnerNodeFact +from starkware.starkware_utils.commitment_tree.leaf_fact import LeafFact from starkware.starkware_utils.commitment_tree.merkle_tree.traverse_tree import traverse_tree from starkware.starkware_utils.executor import executor_ctx_var -from starkware.storage.storage import Fact, FactFetchingContext +from starkware.storage.storage import FactFetchingContext # Should be Tuple["UpdateTree", "UpdateTree"], but recursive types are not supported in mypy: # https://github.com/python/mypy/issues/731. -UpdateTree = Optional[Union[Tuple[Any, Any], Fact]] +UpdateTree = Optional[Union[Tuple[Any, Any], LeafFact]] NodeType = NamedTuple( "NodeType", [("index", int), ("tree", BinaryFactTreeNode), ("update", UpdateTree)] ) @@ -22,7 +24,7 @@ async def update_tree( tree: TBinaryFactTreeNode, ffc: FactFetchingContext, - modifications: Collection[Tuple[int, Fact]], + modifications: Collection[Tuple[int, LeafFact]], calculation_node_cls: Type[CalculationNode], facts: Optional[BinaryFactDict] = None, ) -> TBinaryFactTreeNode: @@ -31,7 +33,8 @@ async def update_tree( storage and returns a new BinaryFactTree representing the fact of the root of the new tree. If facts argument is not None, this dictionary is filled during building the new tree - by the facts of their paths from the leaves up. + by the facts of the modified nodes (the modified leaves won't enter to this dict as they are + already known to the function caller). This method is to be called by a update() method of a specific tree implementation (derived class of BinaryFactTree). @@ -70,20 +73,15 @@ async def update_necessary(node_index: int): del updated_nodes[2 * node_index + 1] async def update_if_possible(node_index: int, binary_fact_tree_node: BinaryFactTreeNode): - updated_nodes[node_index] = calculation_node_cls.create( + updated_nodes[node_index] = calculation_node_cls.create_from_node( node=binary_fact_tree_node, ) await update_necessary(node_index=node_index) - async def set_fact( - new_fact: UpdateTree, node_index: int, binary_fact_tree_node: BinaryFactTreeNode - ): - assert isinstance(new_fact, Fact) + async def set_fact(new_fact: UpdateTree, node_index: int): + assert isinstance(new_fact, LeafFact) - leaf_hash = await new_fact.set_fact(ffc=ffc) - updated_nodes[node_index] = calculation_node_cls.create( - node=binary_fact_tree_node.create_leaf(hash_value=leaf_hash) - ) + updated_nodes[node_index] = calculation_node_cls.create_from_fact(fact=new_fact) await update_necessary(node_index=node_index) async def traverse_node(node: NodeType) -> AsyncIterator[NodeType]: @@ -104,11 +102,7 @@ async def traverse_node(node: NodeType) -> AsyncIterator[NodeType]: if binary_fact_tree_node.is_leaf: # Leaf update. - await set_fact( - new_fact=update_subtree, - node_index=node_index, - binary_fact_tree_node=binary_fact_tree_node, - ) + await set_fact(new_fact=update_subtree, node_index=node_index) return # Inner node with updates. @@ -141,12 +135,14 @@ async def build_updated_calculation() -> CalculationNode: if facts is not None: for fact_hash, node_fact in new_facts.items(): - facts[from_bytes(fact_hash)] = node_fact.to_tuple() + # The leaves aren't stored in `facts`. Only nodes are stored there. + if isinstance(node_fact, InnerNodeFact): + facts[from_bytes(fact_hash)] = node_fact.to_tuple() return root_node -def build_update_tree(height: int, modifications: Collection[Tuple[int, Fact]]) -> UpdateTree: +def build_update_tree(height: int, modifications: Collection[Tuple[int, LeafFact]]) -> UpdateTree: """ Constructs a tree from leaf updates. This is not a full binary tree. It is just the subtree induced by the modification leaves. @@ -165,7 +161,7 @@ def build_update_tree(height: int, modifications: Collection[Tuple[int, Fact]]) for _ in range(height): parents = set(index // 2 for index in layer.keys()) - # Note that dictionary.get(key) is None if the the key is not in the dictionary. + # Note that dictionary.get(key) is None if the key is not in the dictionary. layer = {index: (layer.get(index * 2), layer.get(index * 2 + 1)) for index in parents} # We reached layer_height=0, the top layer with only the root (with index 0). diff --git a/src/starkware/starkware_utils/config_base.py b/src/starkware/starkware_utils/config_base.py index de74d65e..d31798b2 100644 --- a/src/starkware/starkware_utils/config_base.py +++ b/src/starkware/starkware_utils/config_base.py @@ -61,6 +61,15 @@ def load(cls: Type[TConfig], data: dict) -> TConfig: def remove_none_values(self, data, many=False): return {key: value for key, value in data.items() if value is not None} + @classmethod + def from_file( + cls: Type[TConfig], config_file_path: str, load_logging_config: Optional[bool] = True + ) -> TConfig: + raw_config = load_config( + config_file_path=config_file_path, load_logging_config=load_logging_config + ) + return cls.load(data=raw_config) + def log_fields(config: Config): for field in dataclasses.fields(config): diff --git a/src/starkware/starkware_utils/error_handling.py b/src/starkware/starkware_utils/error_handling.py index 7bc3b401..ed6f2bd8 100644 --- a/src/starkware/starkware_utils/error_handling.py +++ b/src/starkware/starkware_utils/error_handling.py @@ -28,6 +28,8 @@ class StarkErrorCode(ErrorCode): BATCH_ABORTED = auto() #: Connection error with the node (for example, Infura too many requests). CONNECTION_ERROR = auto() + #: Duplicate order. + DUPLICATE_ORDER = auto() #: Fact not registered in fact registry. FACT_NOT_REGISTERED = auto() #: Multi-Transaction with zero transactions. @@ -80,8 +82,12 @@ class StarkErrorCode(ErrorCode): MALFORMED_REQUEST = auto() #: Pipeline object is missing because it was migrated from an older version object. MIGRATED_PIPELINE_OBJECT_MISSING = auto() + #: The chain ID does not exist in storage. + MISSING_BLOCKCHAIN_ID = auto() #: One of the fee objects is missing while the other exists. MISSING_FEE_OBJECT = auto() + #: Nested multi-transaction (multi-transaction inside multi-transaction) + NESTED_MULTI_TRANSACTION = auto() #: The order is expired. ORDER_OVERDUE = auto() #: Positive amount value is out of range. diff --git a/src/starkware/starkware_utils/marshmallow_dataclass_fields.py b/src/starkware/starkware_utils/marshmallow_dataclass_fields.py index e539808a..da11be38 100644 --- a/src/starkware/starkware_utils/marshmallow_dataclass_fields.py +++ b/src/starkware/starkware_utils/marshmallow_dataclass_fields.py @@ -29,6 +29,9 @@ ) +# Class definitions. + + class IntAsStr(mfields.Field): """ A field that behaves like an integer, but serializes to a string. Some amount fields are @@ -89,6 +92,7 @@ def _serialize(self, value, attr, obj, **kwargs): if value is None: return None assert isinstance(value, int) + assert value >= 0, "IntAsHex does not support negative values." return hex(value) def _deserialize(self, value, attr, data, **kwargs): @@ -191,3 +195,10 @@ def enum_field_metadata( boolean_field_metadata: Dict[str, Any] = dict(marshmallow_field=RequiredBoolean()) optional_field_metadata: Dict[str, Any] = dict(allow_none=True, load_default=None) + + +# Utilities. + + +def load_int_value(field_metadata: Dict[str, Any], value: str) -> int: + return field_metadata["marshmallow_field"]._deserialize(value=value, attr=None, data=None) diff --git a/src/starkware/starkware_utils/time/fastforward.py b/src/starkware/starkware_utils/time/fastforward.py index f9f244c6..c26dd291 100644 --- a/src/starkware/starkware_utils/time/fastforward.py +++ b/src/starkware/starkware_utils/time/fastforward.py @@ -1,13 +1,14 @@ import asyncio from selectors import DefaultSelector +from typing import Optional class FFSelector(DefaultSelector): - def __init__(self, start_time, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, start_time: float): + super().__init__() self._current_time = start_time - def select(self, timeout): + def select(self, timeout: Optional[float] = None): # There are tasks to be scheduled. Continue simulating. if timeout is None: # If timeout is infinity, just wait without increasing _current_time. @@ -16,9 +17,16 @@ def select(self, timeout): return DefaultSelector.select(self, 0) -class FFEventLoop(asyncio.SelectorEventLoop): # type: ignore - def __init__(self, start_time: int = 0): +class FFEventLoop(asyncio.SelectorEventLoop): + # SelectorEventLoop is platform-dependent; on unix systems the _selector field exists. To make + # mypy happy, also define a class variable (overridden by instance variable in supertype + # constructor) with the type annotation. + # See https://docs.python.org/3.6/library/asyncio-eventloops.html#asyncio.SelectorEventLoop for + # details. + _selector: FFSelector + + def __init__(self, start_time: float = 0): super().__init__(selector=FFSelector(start_time=start_time)) - def time(self): + def time(self) -> float: return self._selector._current_time diff --git a/src/starkware/storage/batch_store_test.py b/src/starkware/storage/batch_store_test.py index 9956e481..b9579b63 100644 --- a/src/starkware/storage/batch_store_test.py +++ b/src/starkware/storage/batch_store_test.py @@ -18,9 +18,9 @@ async def set_value(val_id): await storage.set_value(f"key{val_id}".encode("ascii"), f"value{val_id}".encode("ascii")) async def get_value(val_id): - assert await storage.get_value(f"key{val_id}".encode("ascii")) == f"value{val_id}".encode( - "ascii" - ) + assert await storage.get_value_or_fail( + key=f"key{val_id}".encode("ascii") + ) == f"value{val_id}".encode("ascii") tasks = [asyncio.create_task(set_value(i)) for i in range(4)] await asyncio.sleep(0.02) diff --git a/src/starkware/storage/gated_storage_test.py b/src/starkware/storage/gated_storage_test.py index 2ff876a1..19bf4fef 100644 --- a/src/starkware/storage/gated_storage_test.py +++ b/src/starkware/storage/gated_storage_test.py @@ -12,9 +12,9 @@ async def test_gated_storage(): for k, v in keys_values: assert await storage.get_value(key=k) is None await storage.set_value(key=k, value=v) - assert await storage.get_value(key=k) == v + assert await storage.get_value_or_fail(key=k) == v assert not await storage.setnx_value(key=k, value=b"wrong") - assert await storage.get_value(key=k) == v + assert await storage.get_value_or_fail(key=k) == v assert storage.storage0.db.keys() == {b"k0", b"k1"} assert len(storage.storage1.db.keys()) == 1 @@ -35,7 +35,7 @@ async def test_magic_header_gated_storage(): storage = GatedStorage(limit=1000, storage0=MockStorage(), storage1=MockStorage()) key, value = (b"k0", MAGIC_HEADER + b"v0") await storage.set_value(key=key, value=value) - assert await storage.get_value(key=key) == value + assert await storage.get_value_or_fail(key=key) == value assert storage.storage0.db.keys() == {b"k0"} assert len(storage.storage1.db.keys()) == 1 await storage.del_value(key=key) diff --git a/src/starkware/storage/internal_proxy_storage.py b/src/starkware/storage/internal_proxy_storage.py deleted file mode 100644 index 1a34497e..00000000 --- a/src/starkware/storage/internal_proxy_storage.py +++ /dev/null @@ -1,19 +0,0 @@ -from starkware.storage.storage import Storage - - -class InternalProxyStorage(Storage): - """ - Local storage that communicates with an internal client. - """ - - def __init__(self, internal_client): - self.internal_client = internal_client - - async def set_value(self, key, value): - raise NotImplementedError("Cannot set storage values in this version.") - - async def del_value(self, key): - raise NotImplementedError("Cannot delete storage values in this version.") - - async def get_value(self, key): - return await self.internal_client.get_value(key) diff --git a/src/starkware/storage/internal_proxy_storage_test.py b/src/starkware/storage/internal_proxy_storage_test.py deleted file mode 100644 index 3ae30093..00000000 --- a/src/starkware/storage/internal_proxy_storage_test.py +++ /dev/null @@ -1,34 +0,0 @@ -import asyncio -import logging - -import pytest - -from starkware.storage.internal_proxy_storage import InternalProxyStorage - -logger = logging.getLogger(__name__) - - -class MockInternalClient: - async def get_value(self, key): - return str(key) + "_result" - - -@pytest.mark.asyncio -async def test_internal_proxy_storage(): - storage = InternalProxyStorage(internal_client=MockInternalClient()) - - async def get_value(val_id): - assert await storage.get_value(f"key{val_id}") == f"key{val_id}_result" - - tasks = [asyncio.create_task(get_value(i)) for i in range(4)] - await asyncio.sleep(0.02) - for task in tasks: - await task - - # Make sure deletions don't work. - with pytest.raises(NotImplementedError): - await storage.del_value(1) - - # Make sure value setting doesn't work. - with pytest.raises(NotImplementedError): - await storage.set_value(1, 2) diff --git a/src/starkware/storage/storage.cmake b/src/starkware/storage/storage.cmake index 2730e054..adf598b0 100644 --- a/src/starkware/storage/storage.cmake +++ b/src/starkware/storage/storage.cmake @@ -17,7 +17,6 @@ python_lib(starkware_storage_lib dict_storage.py gated_storage.py imm_storage.py - internal_proxy_storage.py names.py storage.py @@ -38,6 +37,7 @@ python_lib(starkware_storage_utils_lib storage_utils.py LIBS + starkware_commitment_tree_facts_lib starkware_python_utils_lib starkware_storage_lib ) @@ -60,7 +60,6 @@ full_python_test(starkware_storage_test FILES batch_store_test.py gated_storage_test.py - internal_proxy_storage_test.py storage_test.py LIBS diff --git a/src/starkware/storage/storage.py b/src/starkware/storage/storage.py index 138695aa..63a32a2f 100644 --- a/src/starkware/storage/storage.py +++ b/src/starkware/storage/storage.py @@ -5,6 +5,7 @@ from abc import ABC, abstractmethod from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, TypeVar +from starkware.python.object_utils import generic_object_repr from starkware.python.utils import from_bytes, get_exception_repr, to_bytes from starkware.starkware_utils.config_base import get_object_by_path from starkware.starkware_utils.serializable import Serializable @@ -50,60 +51,78 @@ async def del_value(self, key: bytes): pass async def mset(self, updates: Dict[bytes, bytes]): - await asyncio.gather(*(self.set_value(*item) for item in updates.items())) + await asyncio.gather( + *(self.set_value(key=key, value=value) for key, value in updates.items()) + ) async def mget(self, keys: Sequence[bytes]) -> Tuple[Optional[bytes], ...]: - return tuple(await asyncio.gather(*(self.get_value(key) for key in keys))) + return tuple(await asyncio.gather(*(self.get_value(key=key) for key in keys))) + + async def get_value_or_fail(self, key: bytes) -> bytes: + assert isinstance(key, bytes) + result = await self.get_value(key=key) + assert result is not None, f"Key {key!r} unexpectedly does not appear in storage." + return result async def set_int(self, key: bytes, value: int): assert isinstance(key, bytes) assert isinstance(value, int) value_bytes = str(value).encode("ascii") - await self.set_value(key, value_bytes) + await self.set_value(key=key, value=value_bytes) async def setnx_int(self, key: bytes, value: int) -> bool: assert isinstance(key, bytes) assert isinstance(value, int) value_bytes = str(value).encode("ascii") - return await self.setnx_value(key, value_bytes) + return await self.setnx_value(key=key, value=value_bytes) - async def get_int(self, key: bytes, default=None) -> Optional[int]: + async def get_int(self, key: bytes) -> Optional[int]: assert isinstance(key, bytes) - result = await self.get_value(key) + result = await self.get_value(key=key) + return None if result is None else int(result) + + async def get_int_or_default(self, key: bytes, default: int) -> int: + assert isinstance(key, bytes) + result = await self.get_value(key=key) return default if result is None else int(result) + async def get_int_or_fail(self, key: bytes) -> int: + assert isinstance(key, bytes) + result = await self.get_value_or_fail(key=key) + return int(result) + async def set_float(self, key: bytes, value: float): assert isinstance(key, bytes) assert isinstance(value, float) value_bytes = str(value).encode("ascii") - await self.set_value(key, value_bytes) + await self.set_value(key=key, value=value_bytes) async def setnx_float(self, key: bytes, value: float) -> bool: assert isinstance(key, bytes) assert isinstance(value, float) value_bytes = str(value).encode("ascii") - return await self.setnx_value(key, value_bytes) + return await self.setnx_value(key=key, value=value_bytes) async def get_float(self, key: bytes, default=None) -> Optional[float]: assert isinstance(key, bytes) - result = await self.get_value(key) + result = await self.get_value(key=key) return default if result is None else float(result) async def set_str(self, key: bytes, value: str): assert isinstance(key, bytes) assert isinstance(value, str) value_bytes = value.encode("ascii") - await self.set_value(key, value_bytes) + await self.set_value(key=key, value=value_bytes) async def setnx_str(self, key: bytes, value: str) -> bool: assert isinstance(key, bytes) assert isinstance(value, str) value_bytes = value.encode("ascii") - return await self.setnx_value(key, value_bytes) + return await self.setnx_value(key=key, value=value_bytes) async def get_str(self, key: bytes, default=None) -> Optional[str]: assert isinstance(key, bytes) - result = await self.get_value(key) + result = await self.get_value(key=key) return default if result is None else result.decode("ascii") async def setnx_value(self, key: bytes, value: bytes) -> bool: @@ -112,11 +131,11 @@ async def setnx_value(self, key: bytes, value: bytes) -> bool: async def setnx_time(self, key: bytes, time: float): assert isinstance(key, bytes) assert isinstance(time, float) - await self.setnx_float(key, time) + await self.setnx_float(key=key, value=time) async def get_time(self, key: bytes) -> Optional[float]: assert isinstance(key, bytes) - return await self.get_float(key) + return await self.get_float(key=key) TDBObject = TypeVar("TDBObject", bound="DBObject") @@ -147,10 +166,9 @@ async def get_or_fail(cls: Type[TDBObject], storage: Storage, suffix: bytes) -> If key does not exist, raises an exception. """ db_key = cls.db_key(suffix=suffix) - result = await storage.get_value(key=db_key) - assert result is not None, f"Key {db_key!r} does not appear in storage." + result = await storage.get_value_or_fail(key=db_key) - return cls.deserialize(result) + return cls.deserialize(data=result) async def set(self, storage: Storage, suffix: bytes): serialized = await asyncio.get_event_loop().run_in_executor(None, self.serialize) @@ -158,7 +176,7 @@ async def set(self, storage: Storage, suffix: bytes): async def setnx(self, storage: Storage, suffix: bytes) -> bool: serialized = await asyncio.get_event_loop().run_in_executor(None, self.serialize) - return await storage.setnx_value(self.db_key(suffix), serialized) + return await storage.setnx_value(self.db_key(suffix=suffix), value=serialized) def get_update_for_mset(self, suffix: bytes) -> Tuple[bytes, bytes]: """ @@ -169,7 +187,7 @@ def get_update_for_mset(self, suffix: bytes) -> Tuple[bytes, bytes]: *[obj.get_indexed_update_for_mset(suffix) for key, obj in obj_updates.items()], )) """ - return (self.db_key(suffix), self.serialize()) + return (self.db_key(suffix=suffix), self.serialize()) TIndexedDBObject = TypeVar("TIndexedDBObject", bound="IndexedDBObject") @@ -182,19 +200,30 @@ class IndexedDBObject(DBObject): @classmethod def key(cls, index: int) -> bytes: - return cls.db_key(str(index).encode("ascii")) + return cls.db_key(suffix=str(index).encode("ascii")) @classmethod async def get_obj( cls: Type[TIndexedDBObject], storage: Storage, index: int ) -> Optional[TIndexedDBObject]: - return await cls.get(storage, str(index).encode("ascii")) + return await cls.get(storage=storage, suffix=str(index).encode("ascii")) + + @classmethod + async def get_obj_or_fail( + cls: Type[TIndexedDBObject], storage: Storage, index: int + ) -> TIndexedDBObject: + db_object_or_aborted = await cls.get_obj(storage=storage, index=index) + assert ( + db_object_or_aborted is not None + ), f"{cls.__name__} at index {index} does not exist in storage." + + return db_object_or_aborted async def set_obj(self, storage: Storage, index: int): - await self.set(storage, str(index).encode("ascii")) + await self.set(storage=storage, suffix=str(index).encode("ascii")) async def setnx_obj(self, storage: Storage, index: int) -> bool: - return await self.setnx(storage, str(index).encode("ascii")) + return await self.setnx(storage=storage, suffix=str(index).encode("ascii")) def get_indexed_update_for_mset(self, index: int) -> Tuple[bytes, bytes]: """ @@ -256,10 +285,7 @@ def __init__( self.n_workers = n_workers def __repr__(self) -> str: - return ( - f"{type(self)}(storage={self.storage!r}, hash_func={self.hash_func!r}, " - f"n_workers={self.n_workers!r})" - ) + return generic_object_repr(obj=self) class Fact(DBObject): diff --git a/src/starkware/storage/storage_utils.py b/src/starkware/storage/storage_utils.py index f54618b0..20a96575 100644 --- a/src/starkware/storage/storage_utils.py +++ b/src/starkware/storage/storage_utils.py @@ -1,11 +1,12 @@ import dataclasses from starkware.python.utils import from_bytes, to_bytes -from starkware.storage.storage import Fact, HashFunctionType +from starkware.starkware_utils.commitment_tree.leaf_fact import LeafFact +from starkware.storage.storage import HashFunctionType @dataclasses.dataclass(frozen=True) -class LeafFact(Fact): +class SimpleLeafFact(LeafFact): value: int @classmethod @@ -19,9 +20,13 @@ def _hash(self, hash_func: HashFunctionType) -> bytes: return self.serialize() @classmethod - def deserialize(cls, data: bytes) -> "LeafFact": + def deserialize(cls, data: bytes) -> "SimpleLeafFact": return cls(from_bytes(data)) @classmethod - def empty(cls) -> "LeafFact": + def empty(cls) -> "SimpleLeafFact": return cls(value=0) + + @property + def is_empty(self) -> bool: + return self.value == 0