diff --git a/LICENSE.txt b/LICENSE.txt index 275b21c8..b4daf1ea 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -1,6 +1,6 @@ Cairo Toolchain License (Source Available) -Version 1.0 dated December 22, 2020 +Version 1.1 dated January 18, 2022 This license contains the terms and conditions under which StarkWare Industries, Ltd ("StarkWare") makes available its Cairo Toolchain @@ -8,10 +8,9 @@ Industries, Ltd ("StarkWare") makes available its Cairo Toolchain conditions. StarkWare grants you ("Licensee") a license to use the Toolchain, only -for the purpose of developing and compiling Cairo programs. Licensee's +for the purpose of developing and compiling Cairo programs. Licensee's other use of the Toolchain is limited to non-commercial use, which means academic, -scientific, or research and development use, or evaluating the Cairo -language and Toolchain. +scientific, or research use, or evaluating the Cairo language and Toolchain. StarkWare grants Licensee a license to modify the Toolchain, only as necessary to fix errors. Licensee may, but is not obligated to, provide diff --git a/README.md b/README.md index c2263254..b4d51645 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ Once the docker image is built, you can fetch the python package zip file using: ```bash > container_id=$(docker create cairo) -> docker cp ${container_id}:/app/cairo-lang-0.7.0.zip . +> docker cp ${container_id}:/app/cairo-lang-0.7.1.zip . > docker rm -v ${container_id} ``` diff --git a/scripts/requirements-deps.json b/scripts/requirements-deps.json index 04cbe9a1..060c7c86 100644 --- a/scripts/requirements-deps.json +++ b/scripts/requirements-deps.json @@ -16,7 +16,7 @@ "required_version": ">=1.1.2" }, { - "installed_version": "4.0.1", + "installed_version": "4.0.2", "key": "async-timeout", "package_name": "async-timeout", "required_version": ">=4.0.0a3,<5.0" @@ -28,31 +28,31 @@ "required_version": "==0.13.0" }, { - "installed_version": "21.2.0", + "installed_version": "21.4.0", "key": "attrs", "package_name": "attrs", "required_version": ">=17.3.0" }, { - "installed_version": "2.0.8", + "installed_version": "2.0.11", "key": "charset-normalizer", "package_name": "charset-normalizer", "required_version": ">=2.0,<3.0" }, { - "installed_version": "1.2.0", + "installed_version": "1.3.0", "key": "frozenlist", "package_name": "frozenlist", "required_version": ">=1.1.1" }, { - "installed_version": "5.2.0", + "installed_version": "6.0.2", "key": "multidict", "package_name": "multidict", "required_version": ">=4.5,<7.0" }, { - "installed_version": "3.10.0.2", + "installed_version": "4.0.1", "key": "typing-extensions", "package_name": "typing-extensions", "required_version": ">=3.7.4" @@ -73,7 +73,7 @@ { "dependencies": [ { - "installed_version": "1.2.0", + "installed_version": "1.3.0", "key": "frozenlist", "package_name": "frozenlist", "required_version": ">=1.1.0" @@ -88,14 +88,14 @@ { "dependencies": [ { - "installed_version": "3.10.0.2", + "installed_version": "4.0.1", "key": "typing-extensions", "package_name": "typing-extensions", "required_version": ">=3.6.5" } ], "package": { - "installed_version": "4.0.1", + "installed_version": "4.0.2", "key": "async-timeout", "package_name": "async-timeout" } @@ -111,7 +111,7 @@ { "dependencies": [], "package": { - "installed_version": "21.2.0", + "installed_version": "21.4.0", "key": "attrs", "package_name": "attrs" } @@ -135,7 +135,7 @@ { "dependencies": [], "package": { - "installed_version": "4.2.4", + "installed_version": "5.0.0", "key": "cachetools", "package_name": "cachetools" } @@ -151,7 +151,7 @@ { "dependencies": [], "package": { - "installed_version": "2.0.8", + "installed_version": "2.0.11", "key": "charset-normalizer", "package_name": "charset-normalizer" } @@ -189,13 +189,13 @@ { "dependencies": [ { - "installed_version": "2.2.2", + "installed_version": "2.3.0", "key": "eth-typing", "package_name": "eth-typing", "required_version": ">=2.0.0,<3.0.0" }, { - "installed_version": "1.9.5", + "installed_version": "1.10.0", "key": "eth-utils", "package_name": "eth-utils", "required_version": ">=1.2.0,<2.0.0" @@ -234,10 +234,10 @@ "required_version": ">=0.5.0,<0.6.0" }, { - "installed_version": "0.3.3", + "installed_version": "0.3.4", "key": "eth-keys", "package_name": "eth-keys", - "required_version": ">=0.2.1,<0.4.0,!=0.3.2" + "required_version": ">=0.3.4,<0.4.0" }, { "installed_version": "0.2.1", @@ -246,7 +246,7 @@ "required_version": ">=0.1.2,<2" }, { - "installed_version": "1.9.5", + "installed_version": "1.10.0", "key": "eth-utils", "package_name": "eth-utils", "required_version": ">=1.3.0,<2" @@ -265,7 +265,7 @@ } ], "package": { - "installed_version": "0.5.6", + "installed_version": "0.5.7", "key": "eth-account", "package_name": "eth-account" } @@ -273,7 +273,7 @@ { "dependencies": [], "package": { - "installed_version": "0.2.0", + "installed_version": "0.3.2", "key": "eth-hash", "package_name": "eth-hash" } @@ -287,19 +287,19 @@ "required_version": ">=0.9.0,<1.0.0" }, { - "installed_version": "0.3.3", + "installed_version": "0.3.4", "key": "eth-keys", "package_name": "eth-keys", "required_version": ">=0.1.0-beta.4,<1.0.0" }, { - "installed_version": "1.9.5", + "installed_version": "1.10.0", "key": "eth-utils", "package_name": "eth-utils", "required_version": ">=1.0.0-beta.1,<2.0.0" }, { - "installed_version": "3.11.0", + "installed_version": "3.14.1", "key": "pycryptodome", "package_name": "pycryptodome", "required_version": ">=3.4.7,<4.0.0" @@ -314,20 +314,20 @@ { "dependencies": [ { - "installed_version": "2.2.2", + "installed_version": "2.3.0", "key": "eth-typing", "package_name": "eth-typing", "required_version": ">=2.2.1,<3.0.0" }, { - "installed_version": "1.9.5", + "installed_version": "1.10.0", "key": "eth-utils", "package_name": "eth-utils", - "required_version": ">=1.3.0,<2.0.0" + "required_version": ">=1.8.2,<2.0.0" } ], "package": { - "installed_version": "0.3.3", + "installed_version": "0.3.4", "key": "eth-keys", "package_name": "eth-keys" } @@ -335,7 +335,7 @@ { "dependencies": [ { - "installed_version": "1.9.5", + "installed_version": "1.10.0", "key": "eth-utils", "package_name": "eth-utils", "required_version": ">=1.0.1,<2" @@ -362,7 +362,7 @@ { "dependencies": [], "package": { - "installed_version": "2.2.2", + "installed_version": "2.3.0", "key": "eth-typing", "package_name": "eth-typing" } @@ -376,20 +376,20 @@ "required_version": ">=0.10.1,<1.0.0" }, { - "installed_version": "0.2.0", + "installed_version": "0.3.2", "key": "eth-hash", "package_name": "eth-hash", - "required_version": ">=0.1.0,<1.0.0" + "required_version": ">=0.3.1,<0.4.0" }, { - "installed_version": "2.2.2", + "installed_version": "2.3.0", "key": "eth-typing", "package_name": "eth-typing", "required_version": ">=2.2.1,<3.0.0" } ], "package": { - "installed_version": "1.9.5", + "installed_version": "1.10.0", "key": "eth-utils", "package_name": "eth-utils" } @@ -413,7 +413,7 @@ { "dependencies": [], "package": { - "installed_version": "1.2.0", + "installed_version": "1.3.0", "key": "frozenlist", "package_name": "frozenlist" } @@ -437,20 +437,20 @@ { "dependencies": [ { - "installed_version": "3.10.0.2", + "installed_version": "4.0.1", "key": "typing-extensions", "package_name": "typing-extensions", "required_version": ">=3.6.4" }, { - "installed_version": "3.6.0", + "installed_version": "3.7.0", "key": "zipp", "package_name": "zipp", "required_version": ">=0.5" } ], "package": { - "installed_version": "4.8.2", + "installed_version": "4.10.1", "key": "importlib-metadata", "package_name": "importlib-metadata" } @@ -472,7 +472,7 @@ "required_version": ">=0.0.7" }, { - "installed_version": "2.26.0", + "installed_version": "2.27.1", "key": "requests", "package_name": "requests", "required_version": ">=2.11" @@ -487,25 +487,25 @@ { "dependencies": [ { - "installed_version": "21.2.0", + "installed_version": "21.4.0", "key": "attrs", "package_name": "attrs", "required_version": ">=17.4.0" }, { - "installed_version": "4.8.2", + "installed_version": "4.10.1", "key": "importlib-metadata", "package_name": "importlib-metadata", "required_version": null }, { - "installed_version": "0.18.0", + "installed_version": "0.18.1", "key": "pyrsistent", "package_name": "pyrsistent", "required_version": ">=0.14.0" }, { - "installed_version": "58.4.0", + "installed_version": "60.5.0", "key": "setuptools", "package_name": "setuptools", "required_version": null @@ -642,7 +642,7 @@ { "dependencies": [], "package": { - "installed_version": "5.2.0", + "installed_version": "6.0.2", "key": "multidict", "package_name": "multidict" } @@ -666,7 +666,7 @@ { "dependencies": [], "package": { - "installed_version": "1.21.4", + "installed_version": "1.21.5", "key": "numpy", "package_name": "numpy" } @@ -674,7 +674,7 @@ { "dependencies": [ { - "installed_version": "3.0.6", + "installed_version": "3.0.7", "key": "pyparsing", "package_name": "pyparsing", "required_version": ">=2.0.2,!=3.0.5" @@ -719,7 +719,7 @@ } ], "package": { - "installed_version": "2.2.0", + "installed_version": "2.2.1", "key": "pipdeptree", "package_name": "pipdeptree" } @@ -727,7 +727,7 @@ { "dependencies": [ { - "installed_version": "4.8.2", + "installed_version": "4.10.1", "key": "importlib-metadata", "package_name": "importlib-metadata", "required_version": ">=0.12" @@ -742,7 +742,7 @@ { "dependencies": [], "package": { - "installed_version": "0.12.0", + "installed_version": "0.13.1", "key": "prometheus-client", "package_name": "prometheus-client" } @@ -750,7 +750,7 @@ { "dependencies": [], "package": { - "installed_version": "3.19.1", + "installed_version": "3.19.4", "key": "protobuf", "package_name": "protobuf" } @@ -766,7 +766,7 @@ { "dependencies": [], "package": { - "installed_version": "3.11.0", + "installed_version": "3.14.1", "key": "pycryptodome", "package_name": "pycryptodome" } @@ -774,7 +774,7 @@ { "dependencies": [], "package": { - "installed_version": "3.0.6", + "installed_version": "3.0.7", "key": "pyparsing", "package_name": "pyparsing" } @@ -782,7 +782,7 @@ { "dependencies": [], "package": { - "installed_version": "0.18.0", + "installed_version": "0.18.1", "key": "pyrsistent", "package_name": "pyrsistent" } @@ -790,13 +790,13 @@ { "dependencies": [ { - "installed_version": "21.2.0", + "installed_version": "21.4.0", "key": "attrs", "package_name": "attrs", "required_version": ">=19.2.0" }, { - "installed_version": "4.8.2", + "installed_version": "4.10.1", "key": "importlib-metadata", "package_name": "importlib-metadata", "required_version": ">=0.12" @@ -826,14 +826,14 @@ "required_version": ">=1.8.2" }, { - "installed_version": "0.10.2", - "key": "toml", - "package_name": "toml", - "required_version": null + "installed_version": "2.0.1", + "key": "tomli", + "package_name": "tomli", + "required_version": ">=1.0.0" } ], "package": { - "installed_version": "6.2.5", + "installed_version": "7.0.0", "key": "pytest", "package_name": "pytest" } @@ -841,14 +841,20 @@ { "dependencies": [ { - "installed_version": "6.2.5", + "installed_version": "7.0.0", "key": "pytest", "package_name": "pytest", - "required_version": ">=5.4.0" + "required_version": ">=6.1.0" + }, + { + "installed_version": "4.0.1", + "key": "typing-extensions", + "package_name": "typing-extensions", + "required_version": ">=3.7.2" } ], "package": { - "installed_version": "0.16.0", + "installed_version": "0.18.0", "key": "pytest-asyncio", "package_name": "pytest-asyncio" } @@ -862,7 +868,7 @@ "required_version": ">=2017.4.17" }, { - "installed_version": "2.0.8", + "installed_version": "2.0.11", "key": "charset-normalizer", "package_name": "charset-normalizer", "required_version": "~=2.0.0" @@ -874,14 +880,14 @@ "required_version": ">=2.5,<4" }, { - "installed_version": "1.26.7", + "installed_version": "1.26.8", "key": "urllib3", "package_name": "urllib3", "required_version": ">=1.21.1,<1.27" } ], "package": { - "installed_version": "2.26.0", + "installed_version": "2.27.1", "key": "requests", "package_name": "requests" } @@ -889,7 +895,7 @@ { "dependencies": [ { - "installed_version": "1.9.5", + "installed_version": "1.10.0", "key": "eth-utils", "package_name": "eth-utils", "required_version": ">=1.0.2,<2" @@ -904,7 +910,7 @@ { "dependencies": [], "package": { - "installed_version": "58.4.0", + "installed_version": "60.5.0", "key": "setuptools", "package_name": "setuptools" } @@ -935,9 +941,9 @@ { "dependencies": [], "package": { - "installed_version": "0.10.2", - "key": "toml", - "package_name": "toml" + "installed_version": "2.0.1", + "key": "tomli", + "package_name": "tomli" } }, { @@ -951,7 +957,7 @@ { "dependencies": [], "package": { - "installed_version": "2.13.2", + "installed_version": "2.13.3", "key": "typeguard", "package_name": "typeguard" } @@ -959,7 +965,7 @@ { "dependencies": [], "package": { - "installed_version": "3.10.0.2", + "installed_version": "4.0.1", "key": "typing-extensions", "package_name": "typing-extensions" } @@ -973,7 +979,7 @@ "required_version": ">=0.3.0" }, { - "installed_version": "3.10.0.2", + "installed_version": "4.0.1", "key": "typing-extensions", "package_name": "typing-extensions", "required_version": ">=3.7.4" @@ -988,7 +994,7 @@ { "dependencies": [], "package": { - "installed_version": "1.26.7", + "installed_version": "1.26.8", "key": "urllib3", "package_name": "urllib3" } @@ -1016,25 +1022,25 @@ "required_version": ">=2.0.0b6,<3.0.0" }, { - "installed_version": "0.5.6", + "installed_version": "0.5.7", "key": "eth-account", "package_name": "eth-account", - "required_version": ">=0.5.6,<0.6.0" + "required_version": ">=0.5.7,<0.6.0" }, { - "installed_version": "0.2.0", + "installed_version": "0.3.2", "key": "eth-hash", "package_name": "eth-hash", "required_version": ">=0.2.0,<1.0.0" }, { - "installed_version": "2.2.2", + "installed_version": "2.3.0", "key": "eth-typing", "package_name": "eth-typing", "required_version": ">=2.0.0,<3.0.0" }, { - "installed_version": "1.9.5", + "installed_version": "1.10.0", "key": "eth-utils", "package_name": "eth-utils", "required_version": ">=1.9.5,<2.0.0" @@ -1064,22 +1070,22 @@ "required_version": ">=1.1.6,<2.0.0" }, { - "installed_version": "3.19.1", + "installed_version": "3.19.4", "key": "protobuf", "package_name": "protobuf", "required_version": ">=3.10.0,<4" }, { - "installed_version": "2.26.0", + "installed_version": "2.27.1", "key": "requests", "package_name": "requests", "required_version": ">=2.16.0,<3.0.0" }, { - "installed_version": "3.10.0.2", + "installed_version": "4.0.1", "key": "typing-extensions", "package_name": "typing-extensions", - "required_version": ">=3.7.4.1,<4" + "required_version": ">=3.7.4.1,<5" }, { "installed_version": "9.1", @@ -1089,7 +1095,7 @@ } ], "package": { - "installed_version": "5.25.0", + "installed_version": "5.27.0", "key": "web3", "package_name": "web3" } @@ -1105,7 +1111,7 @@ { "dependencies": [], "package": { - "installed_version": "0.37.0", + "installed_version": "0.37.1", "key": "wheel", "package_name": "wheel" } @@ -1119,13 +1125,13 @@ "required_version": ">=2.0" }, { - "installed_version": "5.2.0", + "installed_version": "6.0.2", "key": "multidict", "package_name": "multidict", "required_version": ">=4.0" }, { - "installed_version": "3.10.0.2", + "installed_version": "4.0.1", "key": "typing-extensions", "package_name": "typing-extensions", "required_version": ">=3.7.4" @@ -1140,7 +1146,7 @@ { "dependencies": [], "package": { - "installed_version": "3.6.0", + "installed_version": "3.7.0", "key": "zipp", "package_name": "zipp" } diff --git a/scripts/requirements-gen.txt b/scripts/requirements-gen.txt index f8fe93a0..28c6b202 100644 --- a/scripts/requirements-gen.txt +++ b/scripts/requirements-gen.txt @@ -1,7 +1,7 @@ aiohttp cachetools ecdsa -eth-hash[pycryptodome]==0.2.0 +eth-hash[pycryptodome] fastecdsa frozendict==1.2 lark-parser diff --git a/scripts/requirements.txt b/scripts/requirements.txt index cb2114f1..3ffc9835 100644 --- a/scripts/requirements.txt +++ b/scripts/requirements.txt @@ -2,30 +2,30 @@ aiohttp==3.8.1 aiosignal==1.2.0 -async-timeout==4.0.1 +async-timeout==4.0.2 asynctest==0.13.0 -attrs==21.2.0 +attrs==21.4.0 base58==2.1.1 bitarray==1.2.2 -cachetools==4.2.4 +cachetools==5.0.0 certifi==2021.10.8 -charset-normalizer==2.0.8 +charset-normalizer==2.0.11 cytoolz==0.11.2 ecdsa==0.17.0 eth-abi==2.1.1 -eth-account==0.5.6 -eth-hash[pycryptodome]==0.2.0 +eth-account==0.5.7 +eth-hash[pycryptodome]==0.3.2 eth-keyfile==0.5.1 -eth-keys==0.3.3 +eth-keys==0.3.4 eth-rlp==0.2.1 -eth-typing==2.2.2 -eth-utils==1.9.5 +eth-typing==2.3.0 +eth-utils==1.10.0 fastecdsa==2.2.3 frozendict==1.2 -frozenlist==1.2.0 +frozenlist==1.3.0 hexbytes==0.2.2 idna==3.3 -importlib-metadata==4.8.2 +importlib-metadata==4.10.1 iniconfig==1.1.1 ipfshttpclient==0.8.0a2 jsonschema==3.2.0 @@ -37,38 +37,38 @@ marshmallow-enum==1.5.1 marshmallow-oneofschema==3.0.1 mpmath==1.2.1 multiaddr==0.0.9 -multidict==5.2.0 +multidict==6.0.2 mypy-extensions==0.4.3 netaddr==0.8.0 -numpy==1.21.4 +numpy==1.21.5 packaging==21.3 parsimonious==0.8.1 pip==21.3.1 -pipdeptree==2.2.0 +pipdeptree==2.2.1 pluggy==1.0.0 -prometheus-client==0.12.0 -protobuf==3.19.1 +prometheus-client==0.13.1 +protobuf==3.19.4 py==1.11.0 -pycryptodome==3.11.0 -pyparsing==3.0.6 -pyrsistent==0.18.0 -pytest==6.2.5 -pytest-asyncio==0.16.0 +pycryptodome==3.14.1 +pyparsing==3.0.7 +pyrsistent==0.18.1 +pytest==7.0.0 +pytest-asyncio==0.18.0 PyYAML==6.0 -requests==2.26.0 +requests==2.27.1 rlp==2.0.1 -setuptools==58.4.0 +setuptools==60.5.0 six==1.16.0 sympy==1.9 -toml==0.10.2 +tomli==2.0.1 toolz==0.11.2 -typeguard==2.13.2 -typing-extensions==3.10.0.2 +typeguard==2.13.3 typing-inspect==0.7.1 -urllib3==1.26.7 +typing_extensions==4.0.1 +urllib3==1.26.8 varint==1.0.2 -web3==5.25.0 +web3==5.27.0 websockets==9.1 -wheel==0.37.0 +wheel==0.37.1 yarl==1.7.2 -zipp==3.6.0 +zipp==3.7.0 diff --git a/src/cmake_utils/pip_rules.cmake b/src/cmake_utils/pip_rules.cmake index 98bcd607..3f4e98f3 100644 --- a/src/cmake_utils/pip_rules.cmake +++ b/src/cmake_utils/pip_rules.cmake @@ -31,27 +31,44 @@ function(python_pip TARGET) # The filename will have '==' in it. set(STAMP_FILE ${CMAKE_BINARY_DIR}/python_pip/${TARGET}_${INTERPRETER}_${REQ}.stamp) - # Build wheel and prepare library directory. - add_custom_command( - OUTPUT ${STAMP_FILE} - # Download or build wheel. - COMMENT "Building wheel ${REQ} for ${INTERPRETER}" - COMMAND ${CMAKE_COMMAND} -E make_directory ${LIB_DIR} - COMMAND ${CMAKE_COMMAND} -E make_directory ${DOWNLOAD_DIR} - COMMAND - ${INTERPRETER} -m pip wheel --no-deps -w ${DOWNLOAD_DIR}/ ${REQ} ${PIP_INSTALL_ARGS_${INTERPRETER}} - # Extract wheel. - COMMAND cd ${LIB_DIR} && ${CMAKE_COMMAND} -E tar xzf ${DOWNLOAD_DIR}/*.whl - # Some wheels may put their files at /{name}-{version}.data/(pure|plat)lib/, instead of under - # the root directory. See https://www.python.org/dev/peps/pep-0427/#id24. - # Copy the files from there. Suppress errors, which happen most of the times when this - # subdirectory does not exist. - COMMAND cp -r ${LIB_DIR}/*.data/*lib/* ${LIB_DIR}/ > /dev/null 2>&1 || true - # Cleanup download. - COMMAND ${CMAKE_COMMAND} -E remove_directory ${DOWNLOAD_DIR} - # Timestamp. - COMMAND ${CMAKE_COMMAND} -E touch ${STAMP_FILE} - ) + # Creating library directory. + if (${REQ} MATCHES "==local$") + string(REPLACE "==" "-" PACKAGE_NAME ${REQ}) + set(ZIP_FILE "${PROJECT_SOURCE_DIR}/${PACKAGE_NAME}.zip") + add_custom_command( + OUTPUT ${STAMP_FILE} + COMMENT "Building ${REQ} from a local copy." + COMMAND rm -rf ${LIB_DIR}/* + COMMAND unzip ${ZIP_FILE} -d ${LIB_DIR} > /dev/null + # We don't know if the directory in the zip has the same name as the package. + COMMAND ls ${LIB_DIR} | grep -v -x ${PACKAGE_NAME} | xargs -r -I {} mv ${LIB_DIR}/{} ${LIB_DIR}/${PACKAGE_NAME} + COMMAND mv ${LIB_DIR}/${PACKAGE_NAME}/* ${LIB_DIR} + COMMAND rm -rf ${LIB_DIR}/${PACKAGE_NAME}/ + COMMAND ${CMAKE_COMMAND} -E touch ${STAMP_FILE} + DEPENDS ${ZIP_FILE} + ) + else() + add_custom_command( + OUTPUT ${STAMP_FILE} + # Download or build wheel. + COMMENT "Building wheel ${REQ} for ${INTERPRETER}" + COMMAND ${CMAKE_COMMAND} -E make_directory ${LIB_DIR} + COMMAND ${CMAKE_COMMAND} -E make_directory ${DOWNLOAD_DIR} + COMMAND + ${INTERPRETER} -m pip wheel --no-deps -w ${DOWNLOAD_DIR}/ ${REQ} ${PIP_INSTALL_ARGS_${INTERPRETER}} + # Extract wheel. + COMMAND cd ${LIB_DIR} && ${CMAKE_COMMAND} -E tar xzf ${DOWNLOAD_DIR}/*.whl + # Some wheels may put their files at /{name}-{version}.data/(pure|plat)lib/, instead of + # under the root directory. See https://www.python.org/dev/peps/pep-0427/#id24. + # Copy the files from there. Suppress errors, which happen most of the times when this + # subdirectory does not exist. + COMMAND cp -r ${LIB_DIR}/*.data/*lib/* ${LIB_DIR}/ > /dev/null 2>&1 || true + # Cleanup download. + COMMAND ${CMAKE_COMMAND} -E remove_directory ${DOWNLOAD_DIR} + # Timestamp. + COMMAND ${CMAKE_COMMAND} -E touch ${STAMP_FILE} + ) + endif() set(ALL_STAMPS ${ALL_STAMPS} ${STAMP_FILE}) set(ALL_LIB_DIRS ${ALL_LIB_DIRS} "${INTERPRETER}:${LIB_DIR}") diff --git a/src/cmake_utils/python_rules.cmake b/src/cmake_utils/python_rules.cmake index 9be862ed..b45a0449 100644 --- a/src/cmake_utils/python_rules.cmake +++ b/src/cmake_utils/python_rules.cmake @@ -9,6 +9,8 @@ function(get_lib_info_file OUTPUT_VARIABLE LIB) set(${OUTPUT_VARIABLE} ${PY_LIB_INFO_GLOBAL_DIR}/${LIB}.info PARENT_SCOPE) endfunction() +add_custom_target(all_python_libs_dryrun) + # Creates a python library target. # Caller should make this target depend on artifact targets (using add_dependencies()) # to force correct build order. @@ -91,21 +93,30 @@ function(python_lib LIB) get_lib_info_file(INFO_FILE ${LIB}) file(RELATIVE_PATH CMAKE_DIR ${CMAKE_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}) + set(GEN_PY_LIB_COMMAND + ${GEN_PY_LIB_EXECUTABLE} + --name ${LIB} + --lib_dir ${LIB_DIR_ROOT} + --files ${LIB_FILES} + --lib_deps ${ARGS_LIBS} + --py_exe_deps ${ARGS_PY_EXE_DEPENDENCIES} + --cmake_dir ${CMAKE_DIR} + --prefix ${ARGS_PREFIX} + ) add_custom_command( OUTPUT ${INFO_FILE} - COMMAND ${GEN_PY_LIB_EXECUTABLE} - --name ${LIB} - --lib_dir ${LIB_DIR_ROOT} - --files ${LIB_FILES} - --lib_deps ${ARGS_LIBS} - --output ${INFO_FILE} - --py_exe_deps ${ARGS_PY_EXE_DEPENDENCIES} - --cmake_dir ${CMAKE_DIR} - --prefix ${ARGS_PREFIX} + COMMAND ${GEN_PY_LIB_COMMAND} --output ${INFO_FILE} DEPENDS ${GEN_PY_LIB_EXECUTABLE} ${DEP_INFO} ${UNITED_LIBS} ${ARGS_PY_EXE_DEPENDENCIES} ${ALL_FILE_DEPS} ${LIB}_copy_files ) add_custom_target(${LIB} ALL DEPENDS ${INFO_FILE}) + add_custom_command( + OUTPUT ${INFO_FILE}.dryrun + COMMAND ${GEN_PY_LIB_COMMAND} --output ${INFO_FILE}.dryrun + DEPENDS ${GEN_PY_LIB_EXECUTABLE} + ) + add_custom_target(${LIB}_dryrun DEPENDS ${INFO_FILE}.dryrun) + add_dependencies(all_python_libs_dryrun ${LIB}_dryrun) endfunction() # Creates a virtual environment target. diff --git a/src/services/everest/api/feeder_gateway/CMakeLists.txt b/src/services/everest/api/feeder_gateway/CMakeLists.txt index 7fedfc08..e6ee4a2f 100644 --- a/src/services/everest/api/feeder_gateway/CMakeLists.txt +++ b/src/services/everest/api/feeder_gateway/CMakeLists.txt @@ -7,3 +7,14 @@ python_lib(everest_feeder_gateway_client_lib LIBS services_external_api_lib ) + +python_lib(everest_feeder_gateway_response_objects_lib + PREFIX services/everest/api/feeder_gateway + + FILES + response_objects.py + + LIBS + starkware_dataclasses_utils_lib + pip_marshmallow +) diff --git a/src/services/everest/api/feeder_gateway/response_objects.py b/src/services/everest/api/feeder_gateway/response_objects.py new file mode 100644 index 00000000..29849672 --- /dev/null +++ b/src/services/everest/api/feeder_gateway/response_objects.py @@ -0,0 +1,16 @@ +from typing import Any, Dict + +import marshmallow + +from starkware.starkware_utils.validated_dataclass import ValidatedMarshmallowDataclass + + +class BaseResponseObject(ValidatedMarshmallowDataclass): + """ + Contains common functionality to response objects from the FeederGateway. + This class must not contain a marshmallow schema and should not be directly (de)serialized. + """ + + @marshmallow.post_dump + def remove_none_values(self, data: Dict[Any, Any], many: bool = False) -> Dict[Any, Any]: + return {key: value for key, value in data.items() if value is not None} diff --git a/src/services/everest/business_logic/CMakeLists.txt b/src/services/everest/business_logic/CMakeLists.txt index 8d196b48..f85a929a 100644 --- a/src/services/everest/business_logic/CMakeLists.txt +++ b/src/services/everest/business_logic/CMakeLists.txt @@ -27,3 +27,15 @@ python_lib(everest_internal_transaction_lib pip_marshmallow_enum pip_marshmallow_oneofschema ) + +python_lib(everest_transaction_execution_objects_lib + PREFIX services/everest/business_logic + + FILES + transaction_execution_objects.py + + LIBS + starkware_dataclasses_utils_lib + pip_marshmallow + pip_marshmallow_dataclass +) diff --git a/src/services/everest/business_logic/transaction_execution_objects.py b/src/services/everest/business_logic/transaction_execution_objects.py new file mode 100644 index 00000000..946f449c --- /dev/null +++ b/src/services/everest/business_logic/transaction_execution_objects.py @@ -0,0 +1,28 @@ +from typing import Any, Dict, Optional + +import marshmallow +import marshmallow_dataclass + +from starkware.starkware_utils.validated_dataclass import ValidatedMarshmallowDataclass + + +@marshmallow_dataclass.dataclass(frozen=True) +class TransactionFailureReason(ValidatedMarshmallowDataclass): + """ + Contains the failure reason (error code and error message) of an invalid + transaction. + """ + + tx_id: int + code: str + error_message: Optional[str] + + @marshmallow.decorators.post_dump + def truncate_error_message(self, data: Dict[str, Any], many: bool, **kwargs) -> Dict[str, Any]: + error_message = data["error_message"] + if error_message is None: + # Do nothing. + return data + + data["error_message"] = error_message[:5000] + return data diff --git a/src/services/everest/definitions/fields.py b/src/services/everest/definitions/fields.py index 4bcea682..1388a47a 100644 --- a/src/services/everest/definitions/fields.py +++ b/src/services/everest/definitions/fields.py @@ -55,8 +55,8 @@ def format_invalid_value_error_message(self, value: str, name: Optional[str] = N ) # Serialization. - def get_marshmallow_field(self) -> mfields.Field: - return mfields.String(required=True) + def get_marshmallow_field(self, required: bool, load_default: Any) -> mfields.Field: + return mfields.String(required=required, load_default=load_default) def convert_valid_to_checksum(self, value: str) -> ChecksumAddress: self.validate(value=value) diff --git a/src/starkware/cairo/common/cairo_function_runner.py b/src/starkware/cairo/common/cairo_function_runner.py index 0234acc3..42ca7f00 100644 --- a/src/starkware/cairo/common/cairo_function_runner.py +++ b/src/starkware/cairo/common/cairo_function_runner.py @@ -11,7 +11,6 @@ RangeCheckBuiltinRunner, ) from starkware.cairo.lang.builtins.signature.signature_builtin_runner import SignatureBuiltinRunner -from starkware.cairo.lang.compiler.identifier_definition import LabelDefinition from starkware.cairo.lang.compiler.program import Program from starkware.cairo.lang.compiler.scoped_name import ScopedName from starkware.cairo.lang.tracer.tracer import trace_runner @@ -19,9 +18,9 @@ from starkware.cairo.lang.vm.crypto import pedersen_hash from starkware.cairo.lang.vm.output_builtin_runner import OutputBuiltinRunner from starkware.cairo.lang.vm.relocatable import MaybeRelocatable, RelocatableValue -from starkware.cairo.lang.vm.security import SecurityError, verify_secure_runner +from starkware.cairo.lang.vm.security import verify_secure_runner from starkware.cairo.lang.vm.utils import RunResources -from starkware.cairo.lang.vm.vm_exceptions import VmException +from starkware.cairo.lang.vm.vm_exceptions import SecurityError, VmException class CairoFunctionRunner(CairoRunner): @@ -126,26 +125,18 @@ def run( verify_secure - Run verify_secure_runner to do extra verifications. trace_on_failure - Run the tracer in case of failure to help debugging. apply_modulo_to_args - Apply modulo operation on integer arguments. - use_full_name - Treat func_name as a fully qualified identifer name, instance of a relative - one. + use_full_name - Treat 'func_name' as a fully qualified identifier name, rather than a + relative one. """ assert isinstance(self.program, Program) + entrypoint = self.program.get_label(func_name, full_name_lookup=use_full_name) + structs_factory = CairoStructFactory.from_program(program=self.program) full_args_struct = structs_factory.build_func_args( func=ScopedName.from_string(scope=func_name) ) all_args = full_args_struct(*args, **kwargs) - entrypoint: Union[str, int] - if use_full_name: - identifier = self.program.identifiers.get_by_full_name( - name=ScopedName.from_string(scope=func_name) - ) - assert isinstance(identifier, LabelDefinition) - entrypoint = identifier.pc - else: - entrypoint = func_name - try: self.run_from_entrypoint( entrypoint, diff --git a/src/starkware/cairo/lang/VERSION b/src/starkware/cairo/lang/VERSION index faef31a4..39e898a4 100644 --- a/src/starkware/cairo/lang/VERSION +++ b/src/starkware/cairo/lang/VERSION @@ -1 +1 @@ -0.7.0 +0.7.1 diff --git a/src/starkware/cairo/lang/compiler/CMakeLists.txt b/src/starkware/cairo/lang/compiler/CMakeLists.txt index 2a8ec00f..f6f32325 100644 --- a/src/starkware/cairo/lang/compiler/CMakeLists.txt +++ b/src/starkware/cairo/lang/compiler/CMakeLists.txt @@ -111,6 +111,7 @@ python_lib(cairo_compile_test_utils_lib LIBS cairo_compile_lib + starkware_python_utils_lib pip_pytest ) diff --git a/src/starkware/cairo/lang/compiler/parser_transformer.py b/src/starkware/cairo/lang/compiler/parser_transformer.py index 85ece23c..5e6e200e 100644 --- a/src/starkware/cairo/lang/compiler/parser_transformer.py +++ b/src/starkware/cairo/lang/compiler/parser_transformer.py @@ -91,6 +91,8 @@ from starkware.cairo.lang.compiler.instruction import Register from starkware.cairo.lang.compiler.scoped_name import ScopedName +DEFAULT_SHORT_STRING_MAX_LENGTH = 31 + @dataclasses.dataclass class ParserContext: @@ -98,7 +100,7 @@ class ParserContext: Represents information that affects the parsing process. """ - short_string_max_length: int = 31 + short_string_max_length: int = DEFAULT_SHORT_STRING_MAX_LENGTH parent_location: Optional[ParentLocation] = None # If True, treat type identifiers as resolved. diff --git a/src/starkware/cairo/lang/compiler/preprocessor/compound_expressions_test.py b/src/starkware/cairo/lang/compiler/preprocessor/compound_expressions_test.py index 1a8fd147..670274f5 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/compound_expressions_test.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/compound_expressions_test.py @@ -301,11 +301,11 @@ def test_compound_expressions_tempvars(): [ap] = [ap + (-1)] + [ap + (-1)]; ap++ -%{ memory[ap] = int(val) %} +%{ memory[ap] = to_felt_or_relocatable(val) %} ap += 1 [ap] = [ap + (-1)] * 15; ap++ [ap] = [ap + (-1)] + 5; ap++ -%{ memory[ap] = int(1) %} +%{ memory[ap] = to_felt_or_relocatable(1) %} ap += 1 [ap] = [ap + (-2)] + [ap + (-1)]; ap++ """.replace( diff --git a/src/starkware/cairo/lang/compiler/preprocessor/preprocessor.py b/src/starkware/cairo/lang/compiler/preprocessor/preprocessor.py index 235278c0..385dc2e3 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/preprocessor.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/preprocessor.py @@ -1041,14 +1041,15 @@ def visit_CodeElementTemporaryVariable(self, elm: CodeElementTemporaryVariable): src_size = self.get_size(dest_type) if isinstance(elm.expr, ExprHint): - if not isinstance(dest_type, TypeFelt): + if not isinstance(dest_type, (TypeFelt, TypePointer)): raise PreprocessorError( - "Hint tempvars must be of type felt.", location=elm.expr.location + "Hint tempvars must be of type felt or a pointer.", + location=elm.expr.location, ) self.visit( CodeElementHint( hint=ExprHint( - hint_code=f"memory[ap] = int({elm.expr.hint_code})", + hint_code=f"memory[ap] = to_felt_or_relocatable({elm.expr.hint_code})", n_prefix_newlines=0, location=elm.location, ), @@ -1117,16 +1118,16 @@ def visit_CodeElementCompoundAssertEq(self, instruction: CodeElementCompoundAsse if self.auxiliary_info is not None: self.auxiliary_info.start_compound_assert_eq(lhs=instruction.a, rhs=instruction.b) - src_exprs = self.simplified_expr_to_felt_expr_list(expr=expr_a, expr_type=expr_type_a) - dst_exprs = self.simplified_expr_to_felt_expr_list(expr=expr_b, expr_type=expr_type_b) + dst_exprs = self.simplified_expr_to_felt_expr_list(expr=expr_a, expr_type=expr_type_a) + src_exprs = self.simplified_expr_to_felt_expr_list(expr=expr_b, expr_type=expr_type_b) original_ap_tracking = self.flow_tracking.get_ap_tracking() - for src, dst in safe_zip(src_exprs, dst_exprs): + for dst, src in safe_zip(dst_exprs, src_exprs): ap_diff = self.flow_tracking.get_ap_tracking() - original_ap_tracking - src = self.simplifier.visit(translate_ap(src, ap_diff)) dst = self.simplifier.visit(translate_ap(dst, ap_diff)) + src = self.simplifier.visit(translate_ap(src, ap_diff)) compound_expressions_code_elements, (expr_a, expr_b) = process_compound_assert( - src, dst, self._compound_expression_context + dst, src, self._compound_expression_context ) assert_eq = CodeElementInstruction( instruction=InstructionAst( diff --git a/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test.py b/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test.py index 8540e074..5d17646c 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test.py @@ -220,7 +220,8 @@ def test_temporary_variable(): tempvar q : T assert q.t = 0 tempvar w -tempvar h = nondet %{ 5**i %} +tempvar h1 = nondet %{ 5**i %} +tempvar h2 : felt* = cast(nondet %{ segments.add_temp_segment() %}, felt*) + 3 """ program = preprocess_str(code=code, prime=PRIME) assert ( @@ -238,8 +239,11 @@ def test_temporary_variable(): ap += 2 [ap + (-1)] = 0 ap += 1 -%{ memory[ap] = int(5**i) %} +%{ memory[ap] = to_felt_or_relocatable(5**i) %} ap += 1 +%{ memory[ap] = to_felt_or_relocatable(segments.add_temp_segment()) %} +ap += 1 +[ap] = [ap + (-1)] + 3; ap++ """ ) @@ -274,7 +278,7 @@ def test_temporary_variable_failures(): tempvar a : T = nondet %{ 1 %} """, """ -file:?:?: Hint tempvars must be of type felt. +file:?:?: Hint tempvars must be of type felt or a pointer. tempvar a : T = nondet %{ 1 %} ^************^ """, diff --git a/src/starkware/cairo/lang/compiler/program.py b/src/starkware/cairo/lang/compiler/program.py index ed4d7a5b..b6c12fea 100644 --- a/src/starkware/cairo/lang/compiler/program.py +++ b/src/starkware/cairo/lang/compiler/program.py @@ -24,7 +24,7 @@ from starkware.cairo.lang.compiler.references import Reference from starkware.cairo.lang.compiler.scoped_name import ScopedName, ScopedNameAsStr from starkware.starkware_utils.marshmallow_dataclass_fields import IntAsHex -from starkware.starkware_utils.validated_dataclass import SerializableMarshmallowDataclass +from starkware.starkware_utils.serializable_dataclass import SerializableMarshmallowDataclass @dataclasses.dataclass diff --git a/src/starkware/cairo/lang/compiler/test_utils.py b/src/starkware/cairo/lang/compiler/test_utils.py index abe4e6c2..78952029 100644 --- a/src/starkware/cairo/lang/compiler/test_utils.py +++ b/src/starkware/cairo/lang/compiler/test_utils.py @@ -1,6 +1,26 @@ +from starkware.cairo.lang.compiler.parser_transformer import DEFAULT_SHORT_STRING_MAX_LENGTH +from starkware.python.utils import to_ascii_string + + def read_file_from_dict(dct): """ Given a dictionary from a package name (a.b.c) to a file content returns a function that can be passed to collect_imports. """ return lambda x: (dct[x], x) + + +def short_string_to_felt(short_string: str) -> int: + """ + Returns a felt representation of the given short string. + """ + if len(short_string) > DEFAULT_SHORT_STRING_MAX_LENGTH: + raise ValueError( + f"Short string (e.g., 'abc') length must be at most {DEFAULT_SHORT_STRING_MAX_LENGTH}." + ) + try: + text_bytes = short_string.encode("ascii") + except UnicodeEncodeError: + raise ValueError(f"Expected an ascii string. Found: {to_ascii_string(short_string)}.") + + return int.from_bytes(text_bytes, "big") diff --git a/src/starkware/cairo/lang/ide/vscode-cairo/package.json b/src/starkware/cairo/lang/ide/vscode-cairo/package.json index 8df713f0..f414fe1a 100644 --- a/src/starkware/cairo/lang/ide/vscode-cairo/package.json +++ b/src/starkware/cairo/lang/ide/vscode-cairo/package.json @@ -2,7 +2,7 @@ "name": "cairo", "displayName": "Cairo", "description": "Support Cairo syntax", - "version": "0.7.0", + "version": "0.7.1", "engines": { "vscode": "^1.30.0" }, diff --git a/src/starkware/cairo/lang/lang.cmake b/src/starkware/cairo/lang/lang.cmake index 30e81852..3c4e6bef 100644 --- a/src/starkware/cairo/lang/lang.cmake +++ b/src/starkware/cairo/lang/lang.cmake @@ -28,6 +28,7 @@ python_venv(cairo_lang_package_venv cairo_bootloader_generate_fact_lib cairo_common_lib cairo_compile_lib + cairo_compile_test_utils_lib cairo_hash_program_lib cairo_run_lib cairo_script_lib @@ -37,6 +38,7 @@ python_venv(cairo_lang_package_venv starknet_block_hash_lib starknet_script_lib starknet_testing_lib + starkware_eth_test_utils_lib ) python_lib(cairo_instances_lib diff --git a/src/starkware/cairo/lang/setup.py b/src/starkware/cairo/lang/setup.py index 67da9b60..f16f7838 100644 --- a/src/starkware/cairo/lang/setup.py +++ b/src/starkware/cairo/lang/setup.py @@ -30,6 +30,7 @@ "starkware.starknet": ["common/*.cairo"], "starkware.starknet.core.os": ["*.cairo", "*.json"], "starkware.starknet.security": ["whitelists/*.json"], + "starkware.starknet.testing": ["*.json"], "starkware.starknet.third_party.open_zeppelin": ["account.json"], }, scripts=[ diff --git a/src/starkware/cairo/lang/tracer/tracer_data.py b/src/starkware/cairo/lang/tracer/tracer_data.py index 4eaabb90..77a8f510 100644 --- a/src/starkware/cairo/lang/tracer/tracer_data.py +++ b/src/starkware/cairo/lang/tracer/tracer_data.py @@ -202,7 +202,7 @@ def from_files( """ Factory method constructing TracerData from files. """ - program = Program.Schema().load(json.load(open(program_path))) + program = Program.load(data=json.load(open(program_path))) field_bytes = math.ceil(program.prime.bit_length() / 8) memory = read_memory(memory_path, field_bytes) trace = read_trace(trace_path) @@ -215,7 +215,7 @@ def from_files( public_input = None debug_info = ( - DebugInfo.Schema().load(json.load(open(debug_info_path))) + DebugInfo.load(data=json.load(open(debug_info_path))) if debug_info_path is not None else None ) diff --git a/src/starkware/cairo/lang/vm/cairo_run.py b/src/starkware/cairo/lang/vm/cairo_run.py index 6266c5d4..f92b497e 100644 --- a/src/starkware/cairo/lang/vm/cairo_run.py +++ b/src/starkware/cairo/lang/vm/cairo_run.py @@ -213,7 +213,7 @@ def load_program(program) -> ProgramBase: "Did you compile the code before running it? " f"Error: '{err}'" ) - return Program.Schema().load(program_json) + return Program.load(data=program_json) def cairo_run(args): @@ -505,7 +505,7 @@ def write_air_public_input( def write_debug_info(debug_info_file: IO[str], debug_info: DebugInfo): - json.dump(obj=DebugInfo.Schema().dump(debug_info), fp=debug_info_file) + json.dump(obj=debug_info.dump(), fp=debug_info_file) debug_info_file.flush() diff --git a/src/starkware/cairo/lang/vm/cairo_runner.py b/src/starkware/cairo/lang/vm/cairo_runner.py index f86ee7b5..7429a014 100644 --- a/src/starkware/cairo/lang/vm/cairo_runner.py +++ b/src/starkware/cairo/lang/vm/cairo_runner.py @@ -372,6 +372,20 @@ def read_return_values(self): range(pointer - self.execution_base, self.vm.run_context.ap - self.execution_base) ) + def mark_as_accessed(self, address: RelocatableValue, size: int): + """ + Marks the memory range [address, address + size) as accessed. + + This is useful when a memory range is not accessed in a partial scenario + but is known to be accessed in the real use case. + + For example, a StarkNet contract entry point might not use all the information provided by + the StarkNet OS. + """ + assert self.accessed_addresses is not None + for i in range(size): + self.accessed_addresses.add(address + i) + def check_used_cells(self): """ Returns True if there are enough allocated cells for the builtins. diff --git a/src/starkware/cairo/lang/vm/memory_segments.py b/src/starkware/cairo/lang/vm/memory_segments.py index ae93d402..f844972c 100644 --- a/src/starkware/cairo/lang/vm/memory_segments.py +++ b/src/starkware/cairo/lang/vm/memory_segments.py @@ -3,6 +3,7 @@ from starkware.cairo.lang.vm.memory_dict import MemoryDict from starkware.cairo.lang.vm.relocatable import MaybeRelocatable, RelocatableValue +from starkware.cairo.lang.vm.vm_exceptions import SecurityError FIRST_MEMORY_ADDR = 1 @@ -85,9 +86,11 @@ def compute_effective_sizes(self, include_tmp_segments: bool = False): index: 0 for index in range(first_segment_index, self.n_segments) } for addr in self.memory: - assert isinstance( - addr, RelocatableValue - ), f"Expected memory address to be relocatable value. Found: {addr}." + if not isinstance(addr, RelocatableValue): + raise SecurityError( + f"Expected memory address to be relocatable value. Found: {addr}." + ) + previous_max_size = self._segment_used_sizes[addr.segment_index] self._segment_used_sizes[addr.segment_index] = max(previous_max_size, addr.offset + 1) diff --git a/src/starkware/cairo/lang/vm/relocatable.py b/src/starkware/cairo/lang/vm/relocatable.py index 333a9274..01661e98 100644 --- a/src/starkware/cairo/lang/vm/relocatable.py +++ b/src/starkware/cairo/lang/vm/relocatable.py @@ -108,6 +108,15 @@ def to_tuple(value: MaybeRelocatable) -> Tuple[int, ...]: else: raise NotImplementedError(f"Expected MaybeRelocatable, got: {type(value).__name__}.") + @staticmethod + def to_felt_or_relocatable(value: T): + """ + Converts to int unless value is RelocatableValue, otherwise return value as is. + """ + if isinstance(value, RelocatableValue): + return value + return int(value) + @classmethod def from_tuple(cls, value: Tuple[int, ...]) -> MaybeRelocatable: """ diff --git a/src/starkware/cairo/lang/vm/relocatable_test.py b/src/starkware/cairo/lang/vm/relocatable_test.py index 32324f30..e3f7d1dd 100644 --- a/src/starkware/cairo/lang/vm/relocatable_test.py +++ b/src/starkware/cairo/lang/vm/relocatable_test.py @@ -59,6 +59,16 @@ def test_to_tuple_from_tuple(): assert RelocatableValue.from_tuple((1, 2)) == x +def test_to_felt_or_relocatable(): + assert RelocatableValue.to_felt_or_relocatable(5) == 5 + + x = RelocatableValue(1, 2) + assert RelocatableValue.to_felt_or_relocatable(x) == x + + assert RelocatableValue.to_felt_or_relocatable(True) == 1 + assert RelocatableValue.to_felt_or_relocatable(False) == 0 + + def test_relocatable_value_frozen(): x = RelocatableValue(1, 2) with pytest.raises( diff --git a/src/starkware/cairo/lang/vm/security.py b/src/starkware/cairo/lang/vm/security.py index 14c97c1f..21bed281 100644 --- a/src/starkware/cairo/lang/vm/security.py +++ b/src/starkware/cairo/lang/vm/security.py @@ -1,9 +1,6 @@ from starkware.cairo.lang.vm.cairo_runner import CairoRunner from starkware.cairo.lang.vm.relocatable import RelocatableValue - - -class SecurityError(Exception): - pass +from starkware.cairo.lang.vm.vm_exceptions import SecurityError def verify_secure_runner(runner: CairoRunner, verify_builtins=True): diff --git a/src/starkware/cairo/lang/vm/security_test.py b/src/starkware/cairo/lang/vm/security_test.py index a279789e..89d429a6 100644 --- a/src/starkware/cairo/lang/vm/security_test.py +++ b/src/starkware/cairo/lang/vm/security_test.py @@ -3,7 +3,8 @@ from starkware.cairo.lang.vm.cairo_runner import get_runner_from_code from starkware.cairo.lang.vm.crypto import get_crypto_lib_context_manager from starkware.cairo.lang.vm.relocatable import RelocatableValue -from starkware.cairo.lang.vm.security import SecurityError, verify_secure_runner +from starkware.cairo.lang.vm.security import verify_secure_runner +from starkware.cairo.lang.vm.vm_exceptions import SecurityError PRIME = 2 ** 251 + 17 * 2 ** 192 + 1 diff --git a/src/starkware/cairo/lang/vm/virtual_machine_base.py b/src/starkware/cairo/lang/vm/virtual_machine_base.py index 326404bf..1ee6e24a 100644 --- a/src/starkware/cairo/lang/vm/virtual_machine_base.py +++ b/src/starkware/cairo/lang/vm/virtual_machine_base.py @@ -164,6 +164,7 @@ def __init__( "fis_quad_residue": lambda a, p=self.prime: math_utils.is_quad_residue(a, p), "fsqrt": lambda a, p=self.prime: math_utils.sqrt(a, p), "safe_div": math_utils.safe_div, + "to_felt_or_relocatable": RelocatableValue.to_felt_or_relocatable, } ) diff --git a/src/starkware/cairo/lang/vm/vm_core.py b/src/starkware/cairo/lang/vm/vm_core.py index 73da0055..52938b84 100644 --- a/src/starkware/cairo/lang/vm/vm_core.py +++ b/src/starkware/cairo/lang/vm/vm_core.py @@ -374,11 +374,10 @@ def decode_instruction(encoded_inst: int, imm: Optional[int] = None) -> Instruct def decode_current_instruction(self) -> Instruction: try: instruction_encoding, imm = self.run_context.get_instruction_encoding() + instruction = self.decode_instruction(instruction_encoding, imm) except Exception as exc: raise self.as_vm_exception(exc) from None - instruction = self.decode_instruction(instruction_encoding, imm) - return instruction def opcode_assertions(self, instruction: Instruction, operands: Operands): diff --git a/src/starkware/cairo/lang/vm/vm_exceptions.py b/src/starkware/cairo/lang/vm/vm_exceptions.py index 1f023390..50f98883 100644 --- a/src/starkware/cairo/lang/vm/vm_exceptions.py +++ b/src/starkware/cairo/lang/vm/vm_exceptions.py @@ -6,6 +6,10 @@ from starkware.cairo.lang.compiler.error_handling import LocationError +class SecurityError(Exception): + pass + + class VmExceptionBase(Exception): """ Base class for exceptions thrown by the Cairo VM. diff --git a/src/starkware/cairo/lang/vm/vm_test.py b/src/starkware/cairo/lang/vm/vm_test.py index e49f9930..ded0de7b 100644 --- a/src/starkware/cairo/lang/vm/vm_test.py +++ b/src/starkware/cairo/lang/vm/vm_test.py @@ -262,6 +262,15 @@ def test_hint_between_references(): run_single(code=code, steps=1) +def test_nondet_hint_pointer(): + code = """ +%{ from starkware.cairo.lang.vm.relocatable import RelocatableValue %} +tempvar x : felt* = cast(nondet %{ RelocatableValue(12, 34) %}, felt*) + 3 +""" + vm = run_single(code=code, steps=2) + assert vm.run_context.memory[101] == RelocatableValue(12, 37) + + def test_hint_exception(): code = """ # Some comment. @@ -679,6 +688,23 @@ def test_call_unknown(): run_single(code, 1) +def test_invalid_instruction(): + code = """ + dw -1 + """ + with pytest.raises(VmException) as exc_info: + run_single(code, 1) + + assert str(exc_info.value) == ( + """\ +:2:5: Error at pc=0:10: +Unsupported instruction. + dw -1 + ^***^\ +""" + ) + + def test_call_wrong_operands(): code = """ call rel 0 diff --git a/src/starkware/cairo/sharp/sharp_client.py b/src/starkware/cairo/sharp/sharp_client.py index b758fae3..aed298c8 100755 --- a/src/starkware/cairo/sharp/sharp_client.py +++ b/src/starkware/cairo/sharp/sharp_client.py @@ -58,7 +58,7 @@ def compile_cairo(self, source_code_path: str, flags: Optional[List[str]] = None ] + used_flags ) - program = Program.Schema().load(json.load(open(compiled_program_file.name, "r"))) + program = Program.load(data=json.load(open(compiled_program_file.name, "r"))) return program def run_program(self, program: Program, program_input_path: Optional[str]) -> CairoPie: @@ -69,7 +69,7 @@ def run_program(self, program: Program, program_input_path: Optional[str]) -> Ca with tempfile.NamedTemporaryFile("w") as cairo_pie_file, tempfile.NamedTemporaryFile( "w" ) as program_file: - json.dump(Program.Schema().dump(program), program_file, indent=4, sort_keys=True) + json.dump(program.dump(), program_file, indent=4, sort_keys=True) program_file.flush() cairo_run_cmd = list( filter( @@ -199,7 +199,7 @@ def submit(args, command_args): cairo_pie = CairoPie.from_file(args.cairo_pie) else: if args.program is not None: - program = Program.Schema().load(json.load(open(args.program))) + program = Program.load(data=json.load(open(args.program))) else: assert args.source is not None print("Compiling...", file=sys.stderr) diff --git a/src/starkware/eth/eth_test_utils.py b/src/starkware/eth/eth_test_utils.py index 8fcec016..59c52868 100644 --- a/src/starkware/eth/eth_test_utils.py +++ b/src/starkware/eth/eth_test_utils.py @@ -10,6 +10,8 @@ import pytest import web3.exceptions from web3 import HTTPProvider, Web3 +from web3 import types as web3_types +from web3.contract import Contract # Max timeout for web3 requests in seconds. TIMEOUT_FOR_WEB3_REQUESTS = 120 # Seconds. @@ -97,14 +99,14 @@ class EthAccount: Represents an account in the system. """ - def __init__(self, w3, address: str): + def __init__(self, w3: Web3, address: str): self.address = address self.w3 = w3 def __repr__(self): return f"{type(self).__name__}({self.address})" - def deploy(self, contract_json, *constructor_args): + def deploy(self, contract_json, *constructor_args) -> "EthContract": """ Deploys a contract. contract_json should be the compiled json, including the "abi" and "bytecode" keys. @@ -128,13 +130,15 @@ def deploy(self, contract_json, *constructor_args): return EthContract( w3=self.w3, address=contract_address, - w3_contract=self.w3.eth.contract(abi=abi, address=contract_address), + w3_contract=self.w3.eth.contract(address=contract_address, abi=abi), abi=abi, deployer=self, ) def transfer(self, to: "EthAccount", value: int): - self.w3.eth.send_transaction({"from": self.address, "to": to.address, "value": value}) + self.w3.eth.send_transaction( + {"from": self.address, "to": to.address, "value": web3_types.Wei(value)} + ) @property def balance(self) -> int: @@ -146,22 +150,25 @@ class EthContract: Represents an Ethereum contract. """ - def __init__(self, w3, address: str, w3_contract, abi: Abi, deployer: EthAccount): + def __init__( + self, w3: Web3, address: str, w3_contract: Contract, abi: Abi, deployer: EthAccount + ): self.w3 = w3 self.address = address self.w3_contract = w3_contract self.abi = abi self.deployer = deployer - def __getattr__(self, name): + def __getattr__(self, name: str) -> "EthContractFunction": return EthContractFunction(contract=self, name=name) - def replace_abi(self, abi): - w3_contract = self.w3.eth.contract(abi=abi, address=self.address) + def replace_abi(self, abi: Abi) -> "EthContract": + w3_contract = self.w3.eth.contract(address=Web3.toChecksumAddress(self.address), abi=abi) + return EthContract( w3=self.w3, address=self.address, - w3_contract=w3_contract, + w3_contract=w3_contract, # type: ignore[arg-type] abi=abi, deployer=self.deployer, ) @@ -195,9 +202,9 @@ def __init__(self, contract: EthContract, name: str): def _func(self): return getattr(self.contract.w3_contract.functions, self.name) - def transact(self, *args, transact_args=None): + def transact(self, *args, transact_args: Optional[Dict[str, Any]] = None) -> "EthReceipt": transact_args = prepare_transact_args( - transact_args, default_from=self.contract.deployer.address + transact_args=transact_args, default_from=self.contract.deployer.address ) args = fix_tx_args(args) diff --git a/src/starkware/python/utils.py b/src/starkware/python/utils.py index bafaa59b..ecbdd9f5 100644 --- a/src/starkware/python/utils.py +++ b/src/starkware/python/utils.py @@ -1,9 +1,12 @@ import asyncio +import contextlib import itertools +import logging import os import random import re import subprocess +import time from collections import UserDict from typing import Any, AsyncIterable, Awaitable, Iterable, List, Optional, TypeVar @@ -341,3 +344,26 @@ def _all_subclasses(cls: type) -> List[type]: def get_exception_repr(exception: Exception) -> str: return f"{type(exception).__name__}({exception})" + + +@contextlib.contextmanager +def log_time(logger: logging.Logger, name: str): + """ + Logs the elapsed time in seconds. + + Example: + with log_time(logger=logger, name="Foo"): + sleep(1) + """ + start = time.time() + try: + yield + finally: + logger.info(f"Ran '{name}'. Elapsed: {time.time() - start}.") + + +def to_ascii_string(value: str) -> str: + """ + Converts the given string to an ascii-encodeable one by replacing non-ascii characters with '?'. + """ + return value.encode("ascii", "replace").decode("ascii") diff --git a/src/starkware/python/utils_test.py b/src/starkware/python/utils_test.py index b8ece5bb..1966416a 100644 --- a/src/starkware/python/utils_test.py +++ b/src/starkware/python/utils_test.py @@ -1,4 +1,6 @@ +import random import re +import string import pytest @@ -11,6 +13,7 @@ indent, iter_blockify, safe_zip, + to_ascii_string, unique, ) @@ -132,3 +135,17 @@ class F(D): all_subclasses_set = set(all_subclass_objects) assert len(all_subclass_objects) == len(all_subclasses_set) assert all_subclasses_set == {A, C, D, E, F} + + +def test_to_ascii_str(): + # Should not change printable strings. + assert to_ascii_string(value=string.printable) == string.printable + + string_pattern = "Value: {value}." + expected_string = string_pattern.format(value="?") + non_ascii_character_orders = [128, 1_114_111, random.randint(128, 1_114_111)] + # Check that these non-ascii characters are converted as expected (replaced with '?'). + for order in non_ascii_character_orders: + converted_string = to_ascii_string(value=string_pattern.format(value=chr(order))) + assert converted_string.isascii() + assert converted_string == expected_string diff --git a/src/starkware/starknet/business_logic/CMakeLists.txt b/src/starkware/starknet/business_logic/CMakeLists.txt index fb903bee..e6491d01 100644 --- a/src/starkware/starknet/business_logic/CMakeLists.txt +++ b/src/starkware/starknet/business_logic/CMakeLists.txt @@ -35,6 +35,7 @@ python_lib(starknet_internal_transaction_interface_lib everest_internal_transaction_lib everest_transaction_lib starknet_business_logic_lib + starknet_definitions_lib starknet_general_config_lib starknet_transaction_lib starknet_transaction_execution_objects_lib @@ -53,6 +54,7 @@ python_lib(starknet_transaction_execution_objects_lib everest_definitions_lib everest_internal_transaction_lib starknet_business_logic_lib + starknet_contract_definition_lib starknet_definitions_lib starkware_dataclasses_utils_lib pip_marshmallow_dataclass @@ -68,6 +70,8 @@ python_lib(starknet_internal_transaction_lib cairo_function_runner_lib cairo_relocatable_lib cairo_vm_lib + everest_business_logic_lib + everest_internal_transaction_lib everest_transaction_lib starknet_abi_lib starknet_business_logic_lib @@ -76,12 +80,17 @@ python_lib(starknet_internal_transaction_lib starknet_definitions_lib starknet_general_config_lib starknet_internal_transaction_interface_lib + starknet_os_abi_lib starknet_os_utils_lib starknet_storage_lib starknet_transaction_execution_objects_lib + starknet_transaction_hash_lib starknet_transaction_lib starkware_config_utils_lib + starkware_dataclasses_utils_lib starkware_error_handling_lib + starkware_python_utils_lib + starkware_storage_lib pip_marshmallow pip_marshmallow_dataclass pip_marshmallow_enum diff --git a/src/starkware/starknet/business_logic/internal_transaction.py b/src/starkware/starknet/business_logic/internal_transaction.py index 6730e53f..ee597b77 100644 --- a/src/starkware/starknet/business_logic/internal_transaction.py +++ b/src/starkware/starknet/business_logic/internal_transaction.py @@ -4,7 +4,7 @@ import logging from abc import abstractmethod from dataclasses import field -from typing import ClassVar, Dict, List, Optional, Tuple, Type, cast +from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, cast import marshmallow import marshmallow_dataclass @@ -16,9 +16,14 @@ from starkware.cairo.common.cairo_function_runner import CairoFunctionRunner from starkware.cairo.lang.vm.cairo_pie import ExecutionResources from starkware.cairo.lang.vm.relocatable import RelocatableValue -from starkware.cairo.lang.vm.security import SecurityError from starkware.cairo.lang.vm.utils import ResourcesError -from starkware.cairo.lang.vm.vm_exceptions import HintException, VmException, VmExceptionBase +from starkware.cairo.lang.vm.vm_exceptions import ( + HintException, + SecurityError, + VmException, + VmExceptionBase, +) +from starkware.python.utils import to_bytes from starkware.starknet.business_logic.internal_transaction_interface import ( InternalStateTransaction, ) @@ -34,6 +39,12 @@ TransactionExecutionInfo, ) from starkware.starknet.core.os import os_utils, syscall_utils +from starkware.starknet.core.os.contract_hash import compute_contract_hash +from starkware.starknet.core.os.transaction_hash import ( + TransactionHashPrefix, + calculate_deploy_transaction_hash, + calculate_transaction_hash_common, +) from starkware.starknet.definitions import fields from starkware.starknet.definitions.error_codes import StarknetErrorCode from starkware.starknet.definitions.general_config import StarknetGeneralConfig @@ -48,13 +59,10 @@ ContractEntryPoint, EntryPointType, ) -from starkware.starknet.services.api.gateway.contract_address import calculate_contract_address -from starkware.starknet.services.api.gateway.transaction import Deploy, InvokeFunction, Transaction -from starkware.starknet.services.api.gateway.transaction_hash import ( - TransactionHashPrefix, - calculate_deploy_transaction_hash, - calculate_transaction_hash_common, +from starkware.starknet.services.api.gateway.contract_address import ( + calculate_contract_address_from_hash, ) +from starkware.starknet.services.api.gateway.transaction import Deploy, InvokeFunction, Transaction from starkware.starknet.storage.starknet_storage import BusinessLogicStarknetStorage from starkware.starkware_utils.config_base import Config from starkware.starkware_utils.error_handling import ( @@ -62,6 +70,7 @@ stark_assert, wrap_with_stark_exception, ) +from starkware.storage.storage import FactFetchingContext, Storage logger = logging.getLogger(__name__) @@ -256,7 +265,8 @@ class InternalDeploy(InternalTransaction): contract_address: int = field(metadata=fields.contract_address_metadata) contract_address_salt: int = field(metadata=fields.contract_address_salt_metadata) - contract_definition: ContractDefinition + contract_hash: bytes = field(metadata=fields.non_required_contract_hash_metadata) + constructor_calldata: List[int] = field(metadata=fields.call_data_metadata) # A unique identifier of the transaction in the StarkNet network. @@ -269,32 +279,76 @@ class InternalDeploy(InternalTransaction): # The size of the header of the deployment information that is outputted by the StarkNet OS. deployment_info_header_size: ClassVar[int] = 3 + @marshmallow.decorators.pre_load + def replace_contract_definition_with_contract_hash( + schema, data: Dict[str, Any], many: bool, **kwargs + ) -> Dict[str, Any]: + if "contract_hash" in data: + return data + + contract_definition_json = data.pop("contract_definition") + contract_definition = ContractDefinition.load(data=contract_definition_json) + contract_hash = compute_contract_hash(contract_definition=contract_definition) + data["contract_hash"] = to_bytes(contract_hash).hex() + + return data + @classmethod def create( cls, contract_address_salt: int, contract_definition: ContractDefinition, constructor_calldata: List[int], - general_config, + chain_id: int, ): - contract_address = calculate_contract_address( + contract_hash = compute_contract_hash(contract_definition=contract_definition) + contract_address = calculate_contract_address_from_hash( salt=contract_address_salt, - contract_definition=contract_definition, + contract_hash=contract_hash, constructor_calldata=constructor_calldata, caller_address=0, ) return cls( contract_address=contract_address, contract_address_salt=contract_address_salt, - contract_definition=contract_definition, + contract_hash=to_bytes(contract_hash), constructor_calldata=constructor_calldata, hash_value=calculate_deploy_transaction_hash( contract_address=contract_address, constructor_calldata=constructor_calldata, - chain_id=general_config.chain_id.value, + chain_id=chain_id, ), ) + @classmethod + async def create_for_testing( + cls, + ffc: FactFetchingContext, + contract_definition: ContractDefinition, + contract_address_salt: int, + constructor_calldata: List[int], + chain_id: int, + ) -> "InternalDeploy": + """ + Creates an InternalDeploy transaction and writes its contract definition to the DB. + This constructor should only be used in tests. + """ + contract_definition_fact = ContractDefinitionFact(contract_definition=contract_definition) + await contract_definition_fact.set_fact(ffc=ffc) + tx = InternalDeploy.create( + contract_address_salt=contract_address_salt, + contract_definition=contract_definition, + constructor_calldata=constructor_calldata, + chain_id=chain_id, + ) + return tx + + async def get_contract_definition(self, storage: Storage) -> ContractDefinition: + contract_definition_fact = await ContractDefinitionFact.get_or_fail( + storage=storage, suffix=self.contract_hash + ) + return contract_definition_fact.contract_definition + @classmethod def _specific_from_external( cls, external_tx: Transaction, general_config: StarknetGeneralConfig @@ -304,15 +358,11 @@ def _specific_from_external( contract_address_salt=external_tx.contract_address_salt, contract_definition=external_tx.contract_definition, constructor_calldata=external_tx.constructor_calldata, - general_config=general_config, + chain_id=general_config.chain_id.value, ) def to_external(self) -> Deploy: - return Deploy( - contract_address_salt=self.contract_address_salt, - contract_definition=self.contract_definition, - constructor_calldata=self.constructor_calldata, - ) + raise NotImplementedError("Cannot convert internal deploy transaction to external object.") def get_state_selector(self, general_config: Config) -> StateSelector: """ @@ -342,18 +392,15 @@ async def _apply_specific_state_updates( ), ) - self.contract_definition.validate() + contract_definition = await self.get_contract_definition(storage=state.ffc.storage) + contract_definition.validate() - # Set contract definition fact to facts storage. - contract_definition_fact = ContractDefinitionFact( - contract_definition=self.contract_definition - ) - contract_hash = await contract_definition_fact.set_fact(ffc=state.ffc) - state.contract_definitions[contract_hash] = self.contract_definition + # Add contract definition to carried state. + state.contract_definitions[self.contract_hash] = contract_definition # Create updated contract state. newly_deployed_contract_state = await ContractState.create( - contract_hash=contract_hash, + contract_hash=self.contract_hash, storage_commitment_tree=contract_state.storage_commitment_tree, ) state.contract_states[self.contract_address] = ContractCarriedState( @@ -371,7 +418,8 @@ async def _apply_specific_state_updates( async def invoke_constructor( self, state: CarriedState, general_config: StarknetGeneralConfig ) -> TransactionExecutionInfo: - if len(self.contract_definition.entry_points_by_type[EntryPointType.CONSTRUCTOR]) == 0: + contract_definition = await self.get_contract_definition(storage=state.ffc.storage) + if len(contract_definition.entry_points_by_type[EntryPointType.CONSTRUCTOR]) == 0: stark_assert( len(self.constructor_calldata) == 0, code=StarknetErrorCode.TRANSACTION_FAILED, @@ -390,19 +438,11 @@ async def invoke_constructor( signature=[], hash_value=0, caller_address=0, + nonce=None, ) return await tx._apply_specific_state_updates(state=state, general_config=general_config) - def _synchronous_apply_specific_state_updates( - self, - state: CarriedState, - general_config: StarknetGeneralConfig, - loop: asyncio.AbstractEventLoop, - tx_execution_context: TransactionExecutionContext, - ) -> TransactionExecutionInfo: - raise NotImplementedError - @marshmallow_dataclass.dataclass(frozen=True) class InternalInvokeFunction(InternalTransaction): @@ -430,7 +470,7 @@ class InternalInvokeFunction(InternalTransaction): # A unique nonce, added by the StarkNet core contract on L1. # This nonce is used to make the hash_value of transactions that service L1 messages unique. # This field may be set only when entry_point_type is EntryPointType.L1_HANDLER. - nonce: Optional[int] = field(metadata=fields.optional_nonce_metadata, default=None) + nonce: Optional[int] = field(metadata=fields.optional_nonce_metadata) # Class variables. tx_type: ClassVar[TransactionType] = TransactionType.INVOKE_FUNCTION @@ -559,27 +599,32 @@ async def _apply_specific_state_updates( Applies self to 'state' by running _synchronous_apply_specific_state_updates. This is the asynchronous version of the method below. """ + account_contract_address = ( + 0 if self.entry_point_type is EntryPointType.CONSTRUCTOR else self.contract_address + ) + # Pass the running loop before entering to it. It will be used to run asynchronous # tasks, such as fetching data from storage. loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() - _synchronous_apply_specific_state_updates = functools.partial( - self._synchronous_apply_specific_state_updates, + execute_contract_function = functools.partial( + self.execute_contract_function, state=state, general_config=general_config, loop=loop, tx_execution_context=TransactionExecutionContext.create( - n_steps=general_config.invoke_tx_max_n_steps + account_contract_address=account_contract_address, + n_steps=general_config.invoke_tx_max_n_steps, ), ) execution_info = await loop.run_in_executor( executor=None, # Runs on the default executor. - func=_synchronous_apply_specific_state_updates, + func=execute_contract_function, ) return execution_info - def _synchronous_apply_specific_state_updates( + def execute_contract_function( self, state: CarriedState, general_config: StarknetGeneralConfig, @@ -616,26 +661,11 @@ def _synchronous_apply_specific_state_updates( # Update resources usage (for bouncer). state.cairo_usage += runner.get_execution_resources() - # Build transaction execution info. - contract_call_cairo_usage = state.cairo_usage - previous_cairo_usage - call_info = ContractCall( - from_address=self.caller_address, - to_address=self.contract_address, - code_address=self.code_address, - entry_point_selector=self.entry_point_selector, - entry_point_type=self.entry_point_type, - calldata=self.calldata, - signature=self.signature, - cairo_usage=contract_call_cairo_usage, - events=syscall_handler.events, - l2_to_l1_messages=[], - internal_call_responses=syscall_handler.internal_call_responses, - storage_read_values=syscall_handler.starknet_storage.read_values, - storage_accessed_addresses=syscall_handler.starknet_storage.accessed_addresses, - ) - + # Build and return transaction execution info. return TransactionExecutionInfo( - call_info=call_info, + call_info=self._build_call_info( + previous_cairo_usage=previous_cairo_usage, syscall_handler=syscall_handler + ), l2_to_l1_messages=syscall_handler.l2_to_l1_messages, retdata=self._get_return_values(runner=runner), internal_calls=syscall_handler.internal_calls, @@ -759,12 +789,12 @@ def _run( initial_os_context=os_context, ) - # The OS touches all the arguments so they shouldn't be counted as holes. - assert runner.accessed_addresses is not None # When execution starts the stack holds entry_points_args + [ret_fp, ret_pc]. args_ptr = runner.initial_fp - (len(entry_points_args) + 2) - for i in range(len(entry_points_args)): - runner.accessed_addresses.add(args_ptr + i) + + # The arguments are touched by the OS and should not be counted as holes, mark them + # as accessed. + runner.mark_as_accessed(address=args_ptr, size=len(entry_points_args)) return runner, syscall_handler @@ -788,14 +818,16 @@ def _get_selected_entry_point( if ep0.selector == DEFAULT_ENTRY_POINT_SELECTOR: return ep0 + selector_formatter = fields.EntryPointSelectorField.format + address_formatter = fields.ContractAddressField.format # Non-unique entry points are not possible in a ContractDefinition object, thus # len(filtered_entry_points) <= 1. stark_assert( len(filtered_entry_points) == 1, code=StarknetErrorCode.ENTRY_POINT_NOT_FOUND_IN_CONTRACT, message=( - f"Entry point {self.entry_point_selector} not found in contract with address " - f"{self.contract_address}." + f"Entry point {selector_formatter(self.entry_point_selector)} not found in contract" + f" with address {address_formatter(self.contract_address)}." ), ) @@ -819,13 +851,35 @@ async def call(self, state: CarriedState, general_config: StarknetGeneralConfig) loop=loop, caller_address=self.caller_address, tx_execution_context=TransactionExecutionContext.create( - n_steps=general_config.invoke_tx_max_n_steps + account_contract_address=self.contract_address, + n_steps=general_config.invoke_tx_max_n_steps, ), ) runner, _ = await loop.run_in_executor(executor=None, func=_run) return self._get_return_values(runner=runner) + def _build_call_info( + self, + previous_cairo_usage: ExecutionResources, + syscall_handler: syscall_utils.BusinessLogicSysCallHandler, + ) -> ContractCall: + return ContractCall( + from_address=self.caller_address, + to_address=self.contract_address, + code_address=self.code_address, + entry_point_selector=self.entry_point_selector, + entry_point_type=self.entry_point_type, + calldata=self.calldata, + signature=self.signature, + cairo_usage=syscall_handler.state.cairo_usage - previous_cairo_usage, + events=syscall_handler.events, + l2_to_l1_messages=[], + internal_call_responses=syscall_handler.internal_call_responses, + storage_read_values=syscall_handler.starknet_storage.read_values, + storage_accessed_addresses=syscall_handler.starknet_storage.accessed_addresses, + ) + def _get_return_values(self, runner: CairoFunctionRunner) -> List[int]: with wrap_with_stark_exception( code=StarknetErrorCode.INVALID_RETURN_DATA, diff --git a/src/starkware/starknet/business_logic/internal_transaction_interface.py b/src/starkware/starknet/business_logic/internal_transaction_interface.py index 7bd3137f..2f0bdf95 100644 --- a/src/starkware/starknet/business_logic/internal_transaction_interface.py +++ b/src/starkware/starknet/business_logic/internal_transaction_interface.py @@ -1,4 +1,3 @@ -import asyncio import logging from abc import abstractmethod from typing import Iterable, Optional, cast @@ -6,10 +5,7 @@ from services.everest.business_logic.internal_transaction import EverestInternalStateTransaction from services.everest.business_logic.state import CarriedStateBase from starkware.starknet.business_logic.state import CarriedState, StateSelector -from starkware.starknet.business_logic.transaction_execution_objects import ( - TransactionExecutionContext, - TransactionExecutionInfo, -) +from starkware.starknet.business_logic.transaction_execution_objects import TransactionExecutionInfo from starkware.starknet.definitions.error_codes import StarknetErrorCode from starkware.starknet.definitions.general_config import StarknetGeneralConfig from starkware.starkware_utils.config_base import Config @@ -22,7 +18,7 @@ class InternalStateTransaction(EverestInternalStateTransaction): """ StarkNet internal state transaction. This is the API of transactions that update the state, - but do not necessarily have an external transaction counterpart. + but do not necessarily have an external transaction counterpart. See for example, SyntheticTransaction. """ @@ -74,13 +70,3 @@ async def _apply_specific_state_updates( self, state: CarriedState, general_config: StarknetGeneralConfig ) -> Optional[TransactionExecutionInfo]: pass - - @abstractmethod - def _synchronous_apply_specific_state_updates( - self, - state: CarriedState, - general_config: StarknetGeneralConfig, - loop: asyncio.AbstractEventLoop, - tx_execution_context: TransactionExecutionContext, - ) -> Optional[TransactionExecutionInfo]: - pass diff --git a/src/starkware/starknet/business_logic/state.py b/src/starkware/starknet/business_logic/state.py index ee4be1e2..a93a9529 100644 --- a/src/starkware/starknet/business_logic/state.py +++ b/src/starkware/starknet/business_logic/state.py @@ -192,7 +192,7 @@ def create_unfilled( ) @classmethod - async def create_empty_for_test( + async def empty_for_testing( cls, shared_state: Optional["SharedState"], ffc: FactFetchingContext, diff --git a/src/starkware/starknet/business_logic/state_objects.py b/src/starkware/starknet/business_logic/state_objects.py index fe4b0929..dbbe1b88 100644 --- a/src/starkware/starknet/business_logic/state_objects.py +++ b/src/starkware/starknet/business_logic/state_objects.py @@ -123,10 +123,11 @@ def assert_initialized(self, contract_address: int): Takes contract_address as input to improve the error message. """ + address_formatter = fields.ContractAddressField.format stark_assert( self.initialized, code=StarknetErrorCode.UNINITIALIZED_CONTRACT, - message=f"Contract with address {contract_address} is not deployed.", + message=f"Contract with address {address_formatter(contract_address)} is not deployed.", ) async def update( diff --git a/src/starkware/starknet/business_logic/transaction_execution_objects.py b/src/starkware/starknet/business_logic/transaction_execution_objects.py index 3c9e4d0d..ebdacfa3 100644 --- a/src/starkware/starknet/business_logic/transaction_execution_objects.py +++ b/src/starkware/starknet/business_logic/transaction_execution_objects.py @@ -27,13 +27,21 @@ class TransactionExecutionContext(ValidatedDataclass): A context for transaction execution, which is shared between internal calls. """ + # The account contract from which this transaction originates. + account_contract_address: int = field( + metadata=fields.AddressField.metadata(field_name="account_contract_address") + ) run_resources: RunResources # Used for tracking global events order. n_emitted_events: int = field(metadata=sequential_id_metadata("Number of emitted events")) @classmethod - def create(cls, n_steps: int) -> "TransactionExecutionContext": - return cls(run_resources=RunResources(n_steps=n_steps), n_emitted_events=0) + def create(cls, account_contract_address: int, n_steps: int) -> "TransactionExecutionContext": + return cls( + account_contract_address=account_contract_address, + run_resources=RunResources(n_steps=n_steps), + n_emitted_events=0, + ) @dataclasses.dataclass(frozen=True) @@ -160,10 +168,6 @@ def empty(cls, to_address: int) -> "ContractCall": storage_accessed_addresses=set(), ) - @classmethod - def empty_for_tests(cls) -> "ContractCall": - return cls.empty(to_address=0) - @property def state_selector(self) -> StateSelector: code_address = self.to_address if self.code_address is None else self.code_address diff --git a/src/starkware/starknet/cli/CMakeLists.txt b/src/starkware/starknet/cli/CMakeLists.txt index fdfb78fe..6d38fe5a 100644 --- a/src/starkware/starknet/cli/CMakeLists.txt +++ b/src/starkware/starknet/cli/CMakeLists.txt @@ -9,6 +9,7 @@ python_lib(starknet_cli_lib cairo_compile_lib cairo_tracer_lib cairo_version_lib + cairo_vm_crypto_lib cairo_vm_utils_lib everest_definitions_lib services_external_api_lib diff --git a/src/starkware/starknet/cli/starknet_cli.py b/src/starkware/starknet/cli/starknet_cli.py index c88fdf50..3fe7e5a7 100755 --- a/src/starkware/starknet/cli/starknet_cli.py +++ b/src/starkware/starknet/cli/starknet_cli.py @@ -20,8 +20,8 @@ from starkware.cairo.lang.version import __version__ from starkware.cairo.lang.vm.crypto import get_crypto_lib_context_manager from starkware.starknet.cli.reconstruct_starknet_traceback import reconstruct_starknet_traceback -from starkware.starknet.compiler.compile import get_selector_from_name from starkware.starknet.definitions import fields +from starkware.starknet.public.abi import get_selector_from_name from starkware.starknet.public.abi_structs import identifier_manager_from_abi from starkware.starknet.services.api.contract_definition import ContractDefinition from starkware.starknet.services.api.feeder_gateway.feeder_gateway_client import FeederGatewayClient @@ -480,8 +480,8 @@ async def get_transaction(args, command_args): feeder_gateway_client = get_feeder_gateway_client(args) - tx_info_as_dict = await feeder_gateway_client.get_transaction(tx_hash=args.hash) - print(json.dumps(tx_info_as_dict, indent=4, sort_keys=True)) + tx_info = await feeder_gateway_client.get_transaction(tx_hash=args.hash) + print(tx_info.dumps(indent=4, sort_keys=True)) async def get_transaction_receipt(args, command_args): @@ -493,8 +493,8 @@ async def get_transaction_receipt(args, command_args): feeder_gateway_client = get_feeder_gateway_client(args) - tx_receipt_as_dict = await feeder_gateway_client.get_transaction_receipt(tx_hash=args.hash) - print(json.dumps(tx_receipt_as_dict, indent=4, sort_keys=True)) + tx_receipt = await feeder_gateway_client.get_transaction_receipt(tx_hash=args.hash) + print(tx_receipt.dumps(indent=4, sort_keys=True)) def handle_network_param(args): @@ -539,10 +539,23 @@ async def get_block(args, command_args): feeder_gateway_client = get_feeder_gateway_client(args) - block_as_dict = await feeder_gateway_client.get_block( - block_hash=args.hash, block_number=args.number + block = await feeder_gateway_client.get_block(block_hash=args.hash, block_number=args.number) + print(block.dumps(indent=4, sort_keys=True)) + + +async def get_state_update(args, command_args): + parser = argparse.ArgumentParser(description=("Outputs the state update of a given block")) + add_block_identifier_argument( + parser=parser, block_role_description="display", with_block_prefix=True + ) + + parser.parse_args(command_args, namespace=args) + feeder_gateway_client = get_feeder_gateway_client(args) + + block_state_updates = await feeder_gateway_client.get_state_update( + block_hash=args.block_hash, block_number=args.block_number ) - print(json.dumps(block_as_dict, indent=4, sort_keys=True)) + print(json.dumps(block_state_updates, indent=4, sort_keys=True)) async def get_code(args, command_args): @@ -569,6 +582,30 @@ async def get_code(args, command_args): print(json.dumps(code, indent=4, sort_keys=True)) +async def get_full_contract(args, command_args): + parser = argparse.ArgumentParser( + description=( + "Outputs the contract definition of the contract at the given address with respect to " + "a specific block. In case no block ID is given, uses the latest block." + ) + ) + parser.add_argument( + "--contract_address", type=str, help="The address of the contract.", required=True + ) + add_block_identifier_argument(parser=parser, block_role_description="extract information from") + + parser.parse_args(command_args, namespace=args) + + feeder_gateway_client = get_feeder_gateway_client(args) + + contract_definition = await feeder_gateway_client.get_full_contract( + contract_address=int(args.contract_address, 16), + block_hash=args.block_hash, + block_number=args.block_number, + ) + print(json.dumps(contract_definition, indent=4, sort_keys=True)) + + async def get_contract_addresses(args, command_args): argparse.ArgumentParser(description="Outputs the addresses of the StarkNet system contracts.") @@ -634,8 +671,10 @@ async def main(): "deploy": deploy, "deploy_account": deploy_account, "get_block": get_block, + "get_state_update": get_state_update, "get_code": get_code, "get_contract_addresses": get_contract_addresses, + "get_full_contract": get_full_contract, "get_storage_at": get_storage_at, "get_transaction": get_transaction, "get_transaction_receipt": get_transaction_receipt, diff --git a/src/starkware/starknet/common/syscalls.cairo b/src/starkware/starknet/common/syscalls.cairo index f01a4d8e..c1d7a1fe 100644 --- a/src/starkware/starknet/common/syscalls.cairo +++ b/src/starkware/starknet/common/syscalls.cairo @@ -234,8 +234,7 @@ end # Returns the signature information of the transaction. # -# Note that currently a malicious sequencer may choose to return different values each time -# this function is called. +# NOTE: This function is deprecated. Use get_tx_info() instead. func get_tx_signature{syscall_ptr : felt*}() -> (signature_len : felt, signature : felt*): let syscall = [cast(syscall_ptr, GetTxSignature*)] assert syscall.request = GetTxSignatureRequest(selector=GET_TX_SIGNATURE_SELECTOR) @@ -306,3 +305,47 @@ func emit_event{syscall_ptr : felt*}(keys_len : felt, keys : felt*, data_len : f let syscall_ptr = syscall_ptr + EmitEvent.SIZE return () end + +struct TxInfo: + # The version of the transaction. It is fixed (currently, 0) in the OS, and should be + # signed by the account contract. + # This field allows invalidating old transactions, whenever the meaning of the other + # transaction fields is changed (in the OS). + member version : felt + + # The account contract from which this transaction originates. + member account_contract_address : felt + + # The max_fee field of the transaction. + member max_fee : felt + + # The signature of the transaction. + member signature_len : felt + member signature : felt* +end + +const GET_TX_INFO_SELECTOR = 'GetTxInfo' + +# Describes the GetTxInfo system call format. +struct GetTxInfoRequest: + # The system call selector (= GET_TX_INFO_SELECTOR). + member selector : felt +end + +struct GetTxInfoResponse: + member tx_info : TxInfo* +end + +struct GetTxInfo: + member request : GetTxInfoRequest + member response : GetTxInfoResponse +end + +func get_tx_info{syscall_ptr : felt*}() -> (tx_info : TxInfo*): + let syscall = [cast(syscall_ptr, GetTxInfo*)] + assert syscall.request = GetTxInfoRequest(selector=GET_TX_INFO_SELECTOR) + %{ syscall_handler.get_tx_info(segments=segments, syscall_ptr=ids.syscall_ptr) %} + let response = syscall.response + let syscall_ptr = syscall_ptr + GetTxInfo.SIZE + return (tx_info=response.tx_info) +end diff --git a/src/starkware/starknet/compiler/compile.py b/src/starkware/starknet/compiler/compile.py index e4fcb8c5..f927103f 100644 --- a/src/starkware/starknet/compiler/compile.py +++ b/src/starkware/starknet/compiler/compile.py @@ -129,8 +129,7 @@ def compile_starknet_codes( ) # Dump and load program, so that it is converted to the canonical form. - program_schema = program.Schema() - program = program_schema.load(data=program_schema.dump(obj=program)) + program = Program.load(data=program.dump()) assert isinstance(preprocessed, StarknetPreprocessedProgram) return ContractDefinition( diff --git a/src/starkware/starknet/compiler/external_wrapper.py b/src/starkware/starknet/compiler/external_wrapper.py index d3aa4a8e..acd1bbfe 100644 --- a/src/starkware/starknet/compiler/external_wrapper.py +++ b/src/starkware/starknet/compiler/external_wrapper.py @@ -45,7 +45,7 @@ decode_data, struct_to_argument_info_list, ) -from starkware.starknet.definitions.constants import STARKNET_LANG_DIRECTIVE +from starkware.starknet.definitions import constants from starkware.starknet.public.abi import DEFAULT_ENTRY_POINT_NAME, DEFAULT_L1_ENTRY_POINT_NAME from starkware.starknet.services.api.contract_definition import SUPPORTED_BUILTINS @@ -142,7 +142,7 @@ def visit_CodeElementFunction(self, elm: CodeElementFunction): if is_raw_output: self.validate_raw_output_signature(elm=elm) - if self.file_lang != STARKNET_LANG_DIRECTIVE: + if self.file_lang != constants.STARKNET_LANG_DIRECTIVE: raise PreprocessorError( "External decorators can only be used in source files that contain the " '"%lang starknet" directive.', diff --git a/src/starkware/starknet/compiler/starknet_preprocessor.py b/src/starkware/starknet/compiler/starknet_preprocessor.py index 84495634..6825c21e 100644 --- a/src/starkware/starknet/compiler/starknet_preprocessor.py +++ b/src/starkware/starknet/compiler/starknet_preprocessor.py @@ -26,7 +26,7 @@ parse_entry_point_decorators, ) from starkware.starknet.compiler.validation_utils import get_function_attr -from starkware.starknet.definitions.constants import STARKNET_LANG_DIRECTIVE +from starkware.starknet.definitions import constants from starkware.starknet.public.abi_structs import ( prepare_type_for_abi, struct_definition_to_abi_entry, @@ -71,7 +71,7 @@ def visit_BuiltinsDirective(self, directive: BuiltinsDirective): ) def visit_LangDirective(self, directive: LangDirective): - if directive.name != STARKNET_LANG_DIRECTIVE: + if directive.name != constants.STARKNET_LANG_DIRECTIVE: raise PreprocessorError( f"Unsupported %lang directive. Are you using the correct compiler?", location=directive.location, diff --git a/src/starkware/starknet/compiler/validation_utils.py b/src/starkware/starknet/compiler/validation_utils.py index 99937998..1ff9a2ba 100644 --- a/src/starkware/starknet/compiler/validation_utils.py +++ b/src/starkware/starknet/compiler/validation_utils.py @@ -11,7 +11,7 @@ ) from starkware.cairo.lang.compiler.preprocessor.preprocessor_error import PreprocessorError from starkware.starknet.compiler.data_encoder import ArgumentInfo, EncodingType, encode_data -from starkware.starknet.definitions.constants import STARKNET_LANG_DIRECTIVE +from starkware.starknet.definitions import constants TAttr = TypeVar("TAttr") @@ -51,7 +51,7 @@ def verify_starknet_lang( """ Verifies that file_lang equals STARKNET_LANG_DIRECTIVE and raises an exception otherwise. """ - if file_lang != STARKNET_LANG_DIRECTIVE: + if file_lang != constants.STARKNET_LANG_DIRECTIVE: raise PreprocessorError( f"{name_in_error_message} can only be used in source files that contain the " '"%lang starknet" directive.', diff --git a/src/starkware/starknet/core/os/CMakeLists.txt b/src/starkware/starknet/core/os/CMakeLists.txt index 6fe144c8..e2383ba5 100644 --- a/src/starkware/starknet/core/os/CMakeLists.txt +++ b/src/starkware/starknet/core/os/CMakeLists.txt @@ -18,6 +18,7 @@ python_lib(starknet_os_abi_lib starknet_abi_lib starknet_contract_definition_lib starkware_python_utils_lib + pip_cachetools ) python_lib(starknet_os_utils_lib @@ -71,3 +72,34 @@ full_python_test(starknet_os_program_hash_test starkware_python_utils_lib pip_pytest ) + +python_lib(starknet_transaction_hash_lib + PREFIX starkware/starknet/core/os + + FILES + transaction_hash.py + + LIBS + cairo_common_lib + cairo_vm_crypto_lib + starknet_contract_definition_lib + starknet_definitions_lib + starkware_python_utils_lib +) + +full_python_test(starknet_transaction_hash_test + PREFIX starkware/starknet/core/os + PYTHON python3.7 + TESTED_MODULES starkware/starknet/core/os + + FILES + transaction_hash_test.py + + LIBS + cairo_common_lib + starknet_definitions_lib + starknet_transaction_hash_lib + starkware_crypto_lib + pip_pytest + pip_pytest_asyncio +) diff --git a/src/starkware/starknet/core/os/builtins.cairo b/src/starkware/starknet/core/os/builtins.cairo new file mode 100644 index 00000000..9b1b49fb --- /dev/null +++ b/src/starkware/starknet/core/os/builtins.cairo @@ -0,0 +1,52 @@ +from starkware.cairo.common.cairo_builtins import BitwiseBuiltin, HashBuiltin, SignatureBuiltin +from starkware.cairo.common.registers import get_fp_and_pc + +struct BuiltinPointers: + member pedersen : HashBuiltin* + member range_check : felt + member ecdsa : felt + member bitwise : felt +end + +# A struct containing the ASCII encoding of each builtin. +struct BuiltinEncodings: + member pedersen : felt + member range_check : felt + member ecdsa : felt + member bitwise : felt +end + +# A struct containing the instance size of each builtin. +struct BuiltinInstanceSizes: + member pedersen : felt + member range_check : felt + member ecdsa : felt + member bitwise : felt +end + +struct BuiltinParams: + member builtin_encodings : BuiltinEncodings* + member builtin_instance_sizes : BuiltinInstanceSizes* +end + +func get_builtin_params() -> (builtin_params : BuiltinParams*): + alloc_locals + let (local __fp__, _) = get_fp_and_pc() + + local builtin_encodings : BuiltinEncodings = BuiltinEncodings( + pedersen='pedersen', + range_check='range_check', + ecdsa='ecdsa', + bitwise='bitwise') + + local builtin_instance_sizes : BuiltinInstanceSizes = BuiltinInstanceSizes( + pedersen=HashBuiltin.SIZE, + range_check=1, + ecdsa=SignatureBuiltin.SIZE, + bitwise=BitwiseBuiltin.SIZE) + + local builtin_params : BuiltinParams = BuiltinParams( + builtin_encodings=&builtin_encodings, + builtin_instance_sizes=&builtin_instance_sizes) + return (builtin_params=&builtin_params) +end diff --git a/src/starkware/starknet/core/os/contract_hash.py b/src/starkware/starknet/core/os/contract_hash.py index dce1185c..c72d18b1 100644 --- a/src/starkware/starknet/core/os/contract_hash.py +++ b/src/starkware/starknet/core/os/contract_hash.py @@ -1,8 +1,13 @@ +import contextlib import dataclasses import itertools import json import os -from typing import Callable, List +from contextvars import ContextVar +from functools import lru_cache +from typing import Callable, List, Optional + +import cachetools from starkware.cairo.common.cairo_function_runner import CairoFunctionRunner from starkware.cairo.common.structs import CairoStructFactory, CairoStructProxy @@ -20,7 +25,28 @@ CAIRO_FILE = os.path.join(os.path.dirname(__file__), "contracts.cairo") +contract_hash_cache_ctx_var: ContextVar[Optional[cachetools.LRUCache]] = ContextVar( + "contract_hash_cache", default=None +) + + +@contextlib.contextmanager +def set_contract_hash_cache(cache: cachetools.LRUCache): + """ + Sets a cache to be used by compute_contract_hash(). + """ + assert ( + contract_hash_cache_ctx_var.get() is None + ), "Cannot replace an existing contract_hash_cache." + + token = contract_hash_cache_ctx_var.set(cache) + try: + yield + finally: + contract_hash_cache_ctx_var.reset(token) + +@lru_cache() def load_program() -> Program: return compile_cairo_files( [CAIRO_FILE], @@ -31,6 +57,26 @@ def load_program() -> Program: def compute_contract_hash( contract_definition: ContractDefinition, hash_func: Callable[[int, int], int] = pedersen_hash +) -> int: + cache = contract_hash_cache_ctx_var.get() + if cache is None: + return compute_contract_hash_inner( + contract_definition=contract_definition, hash_func=hash_func + ) + + contract_definition_bytes = contract_definition.dumps(sort_keys=True).encode() + key = (starknet_keccak(data=contract_definition_bytes), hash_func) + + if key not in cache: + cache[key] = compute_contract_hash_inner( + contract_definition=contract_definition, hash_func=hash_func + ) + + return cache[key] + + +def compute_contract_hash_inner( + contract_definition: ContractDefinition, hash_func: Callable[[int, int], int] ) -> int: program = load_program() contract_definition_struct = get_contract_definition_struct( @@ -59,9 +105,7 @@ def compute_hinted_contract_definition_hash(contract_definition: ContractDefinit """ Computes the hash of the contract definition, including hints. """ - dumped_program = Program.Schema().dump( - obj=dataclasses.replace(contract_definition.program, debug_info=None) - ) + dumped_program = dataclasses.replace(contract_definition.program, debug_info=None).dump() if len(dumped_program["attributes"]) == 0: # Remove attributes field from raw dictionary, for hash backward compatibility of # contracts deployed prior to adding this feature. diff --git a/src/starkware/starknet/core/os/os.cairo b/src/starkware/starknet/core/os/os.cairo index 65b6790e..86b5b23f 100644 --- a/src/starkware/starknet/core/os/os.cairo +++ b/src/starkware/starknet/core/os/os.cairo @@ -25,7 +25,7 @@ func main{output_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr, ecds %{ from starkware.starknet.core.os.os_input import StarknetOsInput - os_input = StarknetOsInput.Schema().load(program_input) + os_input = StarknetOsInput.load(data=program_input) ids.os_output.initial_outputs.messages_to_l1 = segments.add_temp_segment() ids.os_output.initial_outputs.messages_to_l2 = segments.add_temp_segment() @@ -36,14 +36,14 @@ func main{output_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr, ecds block_timestamp=nondet %{ syscall_handler.block_info.block_timestamp %}, block_number=nondet %{ syscall_handler.block_info.block_number %}) - tempvar outputs : OsCarriedOutputs = os_output.initial_outputs + tempvar outputs : OsCarriedOutputs* = &os_output.initial_outputs with outputs: let (local reserved_range_checks_end, state_changes) = execute_transactions( block_info=&os_output.block_info) end - assert os_output.final_outputs = outputs + assert os_output.final_outputs = [outputs] local ecdsa_ptr = ecdsa_ptr local bitwise_ptr = bitwise_ptr diff --git a/src/starkware/starknet/core/os/os_program.py b/src/starkware/starknet/core/os/os_program.py index dfcd4566..99230f65 100644 --- a/src/starkware/starknet/core/os/os_program.py +++ b/src/starkware/starknet/core/os/os_program.py @@ -11,7 +11,7 @@ @cachetools.cached(cache={}) def get_os_program() -> Program: with open(STARKNET_OS_COMPILED_PATH, "r") as file: - return Program.Schema().loads(json_data=file.read()) + return Program.loads(data=file.read()) @cachetools.cached(cache={}) diff --git a/src/starkware/starknet/core/os/os_utils.py b/src/starkware/starknet/core/os/os_utils.py index 86a504f4..20bd5211 100644 --- a/src/starkware/starknet/core/os/os_utils.py +++ b/src/starkware/starknet/core/os/os_utils.py @@ -6,7 +6,7 @@ from starkware.starknet.core.os import segment_utils, syscall_utils from starkware.starknet.definitions.error_codes import StarknetErrorCode from starkware.starknet.public.abi import SYSCALL_PTR_OFFSET -from starkware.starkware_utils.error_handling import stark_assert, wrap_with_stark_exception +from starkware.starkware_utils.error_handling import wrap_with_stark_exception def update_builtin_pointers( @@ -86,10 +86,4 @@ def validate_and_process_os_context( segment_base_ptr=syscall_base_ptr, segment_stop_ptr=syscall_stop_ptr, ) - - expected_stop_ptr = syscall_handler.expected_syscall_ptr - stark_assert( - syscall_stop_ptr == expected_stop_ptr, - code=StarknetErrorCode.SECURITY_ERROR, - message=f"Bad syscall_stop_ptr, Expected {expected_stop_ptr}, got {syscall_stop_ptr}.", - ) + syscall_handler.post_run(runner=runner, syscall_stop_ptr=syscall_stop_ptr) diff --git a/src/starkware/starknet/core/os/output.cairo b/src/starkware/starknet/core/os/output.cairo index 10418bd0..bd5ec4c7 100644 --- a/src/starkware/starknet/core/os/output.cairo +++ b/src/starkware/starknet/core/os/output.cairo @@ -1,3 +1,4 @@ +from starkware.cairo.common.registers import get_fp_and_pc from starkware.cairo.common.segments import relocate_segment from starkware.cairo.common.serialize import serialize_word from starkware.starknet.core.os.state import CommitmentTreeUpdateOutput @@ -38,6 +39,14 @@ struct OsCarriedOutputs: member deployment_info : DeploymentInfoHeader* end +func os_carried_outputs_new( + messages_to_l1 : MessageToL1Header*, messages_to_l2 : MessageToL2Header*, + deployment_info : DeploymentInfoHeader*) -> (os_carried_outputs : OsCarriedOutputs*): + let (fp_val, pc_val) = get_fp_and_pc() + static_assert OsCarriedOutputs.SIZE == Args.SIZE + return (os_carried_outputs=cast(fp_val - 2 - OsCarriedOutputs.SIZE, OsCarriedOutputs*)) +end + struct BlockInfo: # Currently, the block timestamp is not validated. member block_timestamp : felt diff --git a/src/starkware/starknet/core/os/program_hash.json b/src/starkware/starknet/core/os/program_hash.json index bfcb0789..9bd7437f 100644 --- a/src/starkware/starknet/core/os/program_hash.json +++ b/src/starkware/starknet/core/os/program_hash.json @@ -1,3 +1,3 @@ { - "program_hash": "0x6a3ee04d874caf7a335343a305988e7ad67bfcaf536dc8a5d26189da788dbbc" + "program_hash": "0x26b17b932ce47266d0e6ae3d6bb17c9189a755b41e9e48b3899abc2aae1a298" } diff --git a/src/starkware/starknet/core/os/syscall_utils.py b/src/starkware/starknet/core/os/syscall_utils.py index 56aae14b..9408de67 100644 --- a/src/starkware/starknet/core/os/syscall_utils.py +++ b/src/starkware/starknet/core/os/syscall_utils.py @@ -1,15 +1,15 @@ import asyncio import contextlib import dataclasses -import functools from abc import ABC, abstractmethod -from typing import Callable, Iterator, List, Mapping, Type, TypeVar, Union, cast +from typing import Callable, Iterator, List, Mapping, Optional, Type, TypeVar, Union, cast +from starkware.cairo.common.cairo_function_runner import CairoFunctionRunner from starkware.cairo.common.structs import CairoStructFactory, CairoStructProxy from starkware.cairo.lang.compiler.ast.cairo_types import CairoType, TypeFelt, TypePointer from starkware.cairo.lang.compiler.identifier_definition import StructDefinition from starkware.cairo.lang.vm.memory_segments import MemorySegmentManager -from starkware.cairo.lang.vm.relocatable import RelocatableValue +from starkware.cairo.lang.vm.relocatable import MaybeRelocatable, RelocatableValue from starkware.python.utils import camel_to_snake_case, safe_zip from starkware.starknet.business_logic.internal_transaction_interface import ( InternalStateTransaction, @@ -30,7 +30,7 @@ BusinessLogicStarknetStorage, StarknetStorageInterface, ) -from starkware.starkware_utils.error_handling import StarkException +from starkware.starkware_utils.error_handling import StarkException, stark_assert TCallable = TypeVar("TCallable", bound=Callable) @@ -86,6 +86,9 @@ def __init__(self, general_config: StarknetGeneralConfig): "starkware.starknet.common.syscalls.GetContractAddress", "starkware.starknet.common.syscalls.GetContractAddressRequest", "starkware.starknet.common.syscalls.GetContractAddressResponse", + "starkware.starknet.common.syscalls.GetTxInfo", + "starkware.starknet.common.syscalls.GetTxInfoRequest", + "starkware.starknet.common.syscalls.GetTxInfoResponse", "starkware.starknet.common.syscalls.GetTxSignature", "starkware.starknet.common.syscalls.GetTxSignatureRequest", "starkware.starknet.common.syscalls.GetTxSignatureResponse", @@ -94,6 +97,7 @@ def __init__(self, general_config: StarknetGeneralConfig): "starkware.starknet.common.syscalls.StorageReadRequest", "starkware.starknet.common.syscalls.StorageReadResponse", "starkware.starknet.common.syscalls.StorageWrite", + "starkware.starknet.common.syscalls.TxInfo", ], ).structs @@ -149,6 +153,11 @@ def get_selector(syscall_name: str): syscall_request_struct=self.structs.GetContractAddressRequest, syscall_size=self.structs.GetContractAddress.size, ), + "get_tx_info": SysCallInfo( + selector=get_selector("get_tx_info"), + syscall_request_struct=self.structs.GetTxInfoRequest, + syscall_size=self.structs.GetTxInfo.size, + ), "send_message_to_l1": SysCallInfo( selector=get_selector("send_message_to_l1"), syscall_request_struct=self.structs.SendMessageToL1SysCall, @@ -174,40 +183,28 @@ def get_selector(syscall_name: str): # Public API. def call_contract(self, segments: MemorySegmentManager, syscall_ptr: RelocatableValue): - retdata = self._call_contract( - segments=segments, syscall_ptr=syscall_ptr, syscall_name="call_contract" - ) - self._write_call_contract_response( + self._call_contract_and_write_response( + syscall_name="call_contract", segments=segments, syscall_ptr=syscall_ptr, - retdata=retdata, ) def delegate_call(self, segments: MemorySegmentManager, syscall_ptr: RelocatableValue): - retdata = self._call_contract( - segments=segments, syscall_ptr=syscall_ptr, syscall_name="delegate_call" - ) - self._write_call_contract_response( + self._call_contract_and_write_response( + syscall_name="delegate_call", segments=segments, syscall_ptr=syscall_ptr, - retdata=retdata, ) def delegate_l1_handler(self, segments: MemorySegmentManager, syscall_ptr: RelocatableValue): - retdata = self._call_contract( - segments=segments, syscall_ptr=syscall_ptr, syscall_name="delegate_l1_handler" - ) - self._write_call_contract_response( + self._call_contract_and_write_response( + syscall_name="delegate_l1_handler", segments=segments, syscall_ptr=syscall_ptr, - retdata=retdata, ) def emit_event(self, segments: MemorySegmentManager, syscall_ptr: RelocatableValue): - """ - Handles the emit_event system call. - """ - self._emit_event(segments=segments, syscall_ptr=syscall_ptr) + return def get_caller_address(self, segments: MemorySegmentManager, syscall_ptr: RelocatableValue): """ @@ -270,8 +267,24 @@ def get_sequencer_address(self, segments: MemorySegmentManager, syscall_ptr: Rel syscall_ptr=syscall_ptr, ) + def get_tx_info(self, segments: MemorySegmentManager, syscall_ptr: RelocatableValue): + """ + Handles the get_tx_info system call. + """ + self._read_and_validate_syscall_request( + syscall_name="get_tx_info", segments=segments, syscall_ptr=syscall_ptr + ) + + response = self.structs.GetTxInfoResponse(tx_info=self._get_tx_info_ptr(segments=segments)) + self._write_syscall_response( + syscall_name="GetTxInfo", + response=response, + segments=segments, + syscall_ptr=syscall_ptr, + ) + def send_message_to_l1(self, segments: MemorySegmentManager, syscall_ptr: RelocatableValue): - self._send_message_to_l1(segments=segments, syscall_ptr=syscall_ptr) + return def get_block_timestamp(self, segments: MemorySegmentManager, syscall_ptr: RelocatableValue): """ @@ -292,9 +305,16 @@ def get_block_timestamp(self, segments: MemorySegmentManager, syscall_ptr: Reloc ) def get_tx_signature(self, segments: MemorySegmentManager, syscall_ptr: RelocatableValue): - signature = self._get_tx_signature(segments=segments, syscall_ptr=syscall_ptr) + """ + Handles the get_tx_signature system call. + """ + self._read_and_validate_syscall_request( + syscall_name="get_tx_signature", segments=segments, syscall_ptr=syscall_ptr + ) + tx_info_ptr = self._get_tx_info_ptr(segments=segments) + tx_info = self.structs.TxInfo.from_ptr(memory=segments.memory, addr=tx_info_ptr) response = self.structs.GetTxSignatureResponse( - signature_len=len(signature), signature=signature + signature_len=tx_info.signature_len, signature=tx_info.signature ) self._write_syscall_response( @@ -304,6 +324,12 @@ def get_tx_signature(self, segments: MemorySegmentManager, syscall_ptr: Relocata syscall_ptr=syscall_ptr, ) + @abstractmethod + def _get_tx_info_ptr(self, segments: MemorySegmentManager): + """ + Returns a pointer to the TxInfo struct. + """ + def storage_read(self, segments: MemorySegmentManager, syscall_ptr: RelocatableValue): """ Handles the storage_read system call. @@ -384,15 +410,18 @@ def _call_contract( syscall_name can be "call_contract", "delegate_call" or "delegate_l1_handler". """ - def _write_call_contract_response( + def _call_contract_and_write_response( self, + syscall_name: str, segments: MemorySegmentManager, syscall_ptr: RelocatableValue, - retdata: List[int], ): """ - Fills the CallContractResponse struct. + Executes the contract call and fills the CallContractResponse struct. """ + retdata = self._call_contract( + segments=segments, syscall_ptr=syscall_ptr, syscall_name=syscall_name + ) response = self.structs.CallContractResponse( retdata_size=len(retdata), retdata=self._allocate_segment(segments=segments), @@ -431,26 +460,6 @@ def _get_contract_address(self, segments: MemorySegmentManager, syscall_ptr: Rel Specific implementation of the get_contract_address system call. """ - def _emit_event(self, segments: MemorySegmentManager, syscall_ptr: RelocatableValue): - """ - Specific implementation of the emit_event system call. - """ - return - - def _send_message_to_l1(self, segments: MemorySegmentManager, syscall_ptr: RelocatableValue): - """ - Specific implementation of the send_message_to_l1 system call. - """ - return - - @abstractmethod - def _get_tx_signature( - self, segments: MemorySegmentManager, syscall_ptr: RelocatableValue - ) -> List[int]: - """ - Returns the signature information for the transaction. - """ - @abstractmethod def _storage_read(self, address: int) -> int: """ @@ -483,25 +492,6 @@ def get_runtime_type(cairo_type: CairoType) -> Union[Type[int], Type[Relocatable raise NotImplementedError(f"Unexpected type: {cairo_type.format()}.") -def count_syscall(func: TCallable) -> TCallable: - """ - Used to decorate system calls. - Increases the counter for the relevant system call by 1 for each time the call is - invoked. - """ - - @functools.wraps(func) - def increment_syscall_counter_wrapper(*args, **kwargs): - self: BusinessLogicSysCallHandler = args[0] - syscall_name = func.__name__ - previous_syscall_count = self.state.syscall_counter.get(syscall_name, 0) - self.state.syscall_counter[syscall_name] = previous_syscall_count + 1 - - return func(*args, **kwargs) - - return cast(TCallable, increment_syscall_counter_wrapper) - - class BusinessLogicSysCallHandler(SysCallHandlerBase): """ The SysCallHandler implementation that is used by the batcher. @@ -540,9 +530,16 @@ def __init__( # Kept for validations during the run. self.expected_syscall_ptr = initial_syscall_ptr + # A pointer to the Cairo TxInfo struct. + self.tx_info_ptr: Optional[RelocatableValue] = None + def _allocate_segment(self, segments: MemorySegmentManager) -> RelocatableValue: return segments.add() + def _count_syscall(self, syscall_name: str): + previous_syscall_count = self.state.syscall_counter.get(syscall_name, 0) + self.state.syscall_counter[syscall_name] = previous_syscall_count + 1 + def _read_and_validate_syscall_request( self, syscall_name: str, segments: MemorySegmentManager, syscall_ptr: RelocatableValue ) -> CairoStructProxy: @@ -550,6 +547,9 @@ def _read_and_validate_syscall_request( Returns the system call request written in the syscall segment, starting at syscall_ptr. Performs validations on the request. """ + # Update syscall count. + self._count_syscall(syscall_name=syscall_name) + request = self._read_syscall_request( syscall_name=syscall_name, segments=segments, syscall_ptr=syscall_ptr ) @@ -580,7 +580,6 @@ def _read_and_validate_syscall_request( return request - @count_syscall def _call_contract( self, segments: MemorySegmentManager, syscall_ptr: RelocatableValue, syscall_name: str ) -> List[int]: @@ -616,16 +615,17 @@ def _call_contract( entry_point_selector=cast(int, request.function_selector), entry_point_type=entry_point_type, calldata=calldata, - signature=[], + signature=self.signature, hash_value=0, caller_address=caller_address, + nonce=None, ) with self.contract_call_execution_context( tx=tx, called_contract_address=tx.contract_address ): # Execute contract call. - execution_info = tx._synchronous_apply_specific_state_updates( + execution_info = tx.execute_contract_function( state=self.state, general_config=self.general_config, loop=self.loop, @@ -698,10 +698,9 @@ def _update_starknet_storage(self): contract_storage_updates = self.state.contract_states[self.contract_address].storage_updates self.starknet_storage.reset_state(storage_updates=contract_storage_updates) - @count_syscall - def _emit_event(self, segments: MemorySegmentManager, syscall_ptr: RelocatableValue): + def emit_event(self, segments: MemorySegmentManager, syscall_ptr: RelocatableValue): """ - Specific implementation of the emit_event system call. + Handles the emit_event system call. """ request = self._read_and_validate_syscall_request( syscall_name="emit_event", segments=segments, syscall_ptr=syscall_ptr @@ -721,8 +720,20 @@ def _emit_event(self, segments: MemorySegmentManager, syscall_ptr: RelocatableVa ) ) - @count_syscall - def _send_message_to_l1(self, segments: MemorySegmentManager, syscall_ptr: RelocatableValue): + def _get_tx_info_ptr(self, segments: MemorySegmentManager) -> RelocatableValue: + if self.tx_info_ptr is None: + tx_info = self.structs.TxInfo( + version=0, + account_contract_address=self.tx_execution_context.account_contract_address, + max_fee=0, + signature_len=len(self.signature), + signature=segments.gen_arg(self.signature), + ) + self.tx_info_ptr = cast(RelocatableValue, segments.gen_arg(arg=tx_info)) + + return self.tx_info_ptr + + def send_message_to_l1(self, segments: MemorySegmentManager, syscall_ptr: RelocatableValue): request = self._read_and_validate_syscall_request( syscall_name="send_message_to_l1", segments=segments, syscall_ptr=syscall_ptr ) @@ -739,15 +750,12 @@ def _send_message_to_l1(self, segments: MemorySegmentManager, syscall_ptr: Reloc ) ) - @count_syscall def _get_block_number(self) -> int: return self.state.block_info.block_number - @count_syscall def _get_block_timestamp(self) -> int: return self.state.block_info.block_timestamp - @count_syscall def _get_caller_address( self, segments: MemorySegmentManager, syscall_ptr: RelocatableValue ) -> int: @@ -757,7 +765,6 @@ def _get_caller_address( return self.caller_address - @count_syscall def _get_contract_address( self, segments: MemorySegmentManager, syscall_ptr: RelocatableValue ) -> int: @@ -767,21 +774,9 @@ def _get_contract_address( return self.contract_address - @count_syscall - def _get_tx_signature( - self, segments: MemorySegmentManager, syscall_ptr: RelocatableValue - ) -> List[int]: - self._read_and_validate_syscall_request( - syscall_name="get_tx_signature", segments=segments, syscall_ptr=syscall_ptr - ) - - return self.signature - - @count_syscall def _storage_read(self, address: int) -> int: return self.starknet_storage.read(address=address) - @count_syscall def _storage_write(self, address: int, value: int): # Read the value before the write operation in order to log it in the read_values list. # This value is needed to create the DictAccess while executing the corresponding @@ -801,10 +796,54 @@ def _storage_write(self, address: int, value: int): previous_n_writings + 1 ) - @count_syscall def get_sequencer_address(self, segments: MemorySegmentManager, syscall_ptr: RelocatableValue): return super().get_sequencer_address(segments=segments, syscall_ptr=syscall_ptr) + def post_run_tx_info_related_logic(self, runner: CairoFunctionRunner): + """ + Validates that there were no out of bounds writes to the tx_info related segments and marks + tx_info as accessed. + """ + tx_info_ptr = self.tx_info_ptr + if tx_info_ptr is None: + # tx_info_ptr was never used. + return + + segments = runner.segments + + tx_info_size = self.structs.TxInfo.size + stark_assert( + segments.get_segment_used_size(segment_index=tx_info_ptr.segment_index) == tx_info_size, + code=StarknetErrorCode.SECURITY_ERROR, + message=f"Out of bounds write to tx_info segment.", + ) + + runner.mark_as_accessed(address=tx_info_ptr, size=tx_info_size) + + tx_info = self.structs.TxInfo.from_ptr(memory=segments.memory, addr=tx_info_ptr) + signature_ptr = tx_info.signature + stark_assert( + segments.get_segment_used_size(segment_index=signature_ptr.segment_index) + == len(self.signature), + code=StarknetErrorCode.SECURITY_ERROR, + message=f"Out of bounds write to signature segment.", + ) + + runner.mark_as_accessed(address=signature_ptr, size=len(self.signature)) + + def post_run(self, runner: CairoFunctionRunner, syscall_stop_ptr: MaybeRelocatable): + """ + Performs post run syscall related tasks. + """ + expected_stop_ptr = self.expected_syscall_ptr + stark_assert( + syscall_stop_ptr == expected_stop_ptr, + code=StarknetErrorCode.SECURITY_ERROR, + message=f"Bad syscall_stop_ptr, Expected {expected_stop_ptr}, got {syscall_stop_ptr}.", + ) + + self.post_run_tx_info_related_logic(runner=runner) + class OsSysCallHandler(SysCallHandlerBase): """ @@ -840,6 +879,12 @@ def __init__( self.block_info = block_info + # A pointer to the Cairo TxInfo struct. + # This pointer needs to match the TxInfo pointer that is going to be used during the system + # call validation by the StarkNet OS. + # Set during enter_tx. + self.tx_info_ptr: Optional[RelocatableValue] = None + def _read_and_validate_syscall_request( self, syscall_name: str, segments: MemorySegmentManager, syscall_ptr: RelocatableValue ) -> CairoStructProxy: @@ -885,10 +930,9 @@ def _get_contract_address( ) -> int: return self.call_stack[-1].to_address - def _get_tx_signature( - self, segments: MemorySegmentManager, syscall_ptr: RelocatableValue - ) -> List[int]: - return self.call_stack[-1].signature + def _get_tx_info_ptr(self, segments: MemorySegmentManager) -> RelocatableValue: + assert self.tx_info_ptr is not None + return self.tx_info_ptr def _storage_read(self, address: int) -> int: return next(self.execute_code_read_iterators[-1]) @@ -910,6 +954,22 @@ def execute_syscall_storage_write(self) -> int: """ return next(self.execute_syscall_read_iterators[-1]) + def start_tx(self, tx_info_ptr: RelocatableValue): + """ + Called when starting the execution of a transaction. + + 'tx_info_ptr' is a pointer to the TxInfo struct corresponding to said transaction. + """ + assert self.tx_info_ptr is None + self.tx_info_ptr = tx_info_ptr + + def end_tx(self): + """ + Called after the execution of the current transaction complete. + """ + assert self.tx_info_ptr is not None + self.tx_info_ptr = None + def enter_call(self): call_info = next(self._contract_calls_iterator) self._call_response_iterator = iter(call_info.internal_call_responses) diff --git a/src/starkware/starknet/services/api/gateway/transaction_hash.py b/src/starkware/starknet/core/os/transaction_hash.py similarity index 100% rename from src/starkware/starknet/services/api/gateway/transaction_hash.py rename to src/starkware/starknet/core/os/transaction_hash.py diff --git a/src/starkware/starknet/services/api/gateway/transaction_hash_test.py b/src/starkware/starknet/core/os/transaction_hash_test.py similarity index 96% rename from src/starkware/starknet/services/api/gateway/transaction_hash_test.py rename to src/starkware/starknet/core/os/transaction_hash_test.py index e51a49df..0030b430 100644 --- a/src/starkware/starknet/services/api/gateway/transaction_hash_test.py +++ b/src/starkware/starknet/core/os/transaction_hash_test.py @@ -4,7 +4,7 @@ from starkware.cairo.common.hash_state import compute_hash_on_elements from starkware.crypto.signature.fast_pedersen_hash import pedersen_hash -from starkware.starknet.services.api.gateway.transaction_hash import ( +from starkware.starknet.core.os.transaction_hash import ( TransactionHashPrefix, calculate_deploy_transaction_hash, calculate_transaction_hash_common, diff --git a/src/starkware/starknet/core/os/transactions.cairo b/src/starkware/starknet/core/os/transactions.cairo index 7816d7c0..479c9cca 100644 --- a/src/starkware/starknet/core/os/transactions.cairo +++ b/src/starkware/starknet/core/os/transactions.cairo @@ -1,6 +1,6 @@ from starkware.cairo.builtin_selection.select_builtins import select_builtins from starkware.cairo.builtin_selection.validate_builtins import validate_builtin, validate_builtins -from starkware.cairo.common.cairo_builtins import BitwiseBuiltin, HashBuiltin, SignatureBuiltin +from starkware.cairo.common.cairo_builtins import HashBuiltin from starkware.cairo.common.dict import dict_new, dict_read, dict_update, dict_write from starkware.cairo.common.dict_access import DictAccess from starkware.cairo.common.find_element import find_element, search_sorted @@ -12,15 +12,19 @@ from starkware.starknet.common.syscalls import ( CALL_CONTRACT_SELECTOR, DELEGATE_CALL_SELECTOR, DELEGATE_L1_HANDLER_SELECTOR, EMIT_EVENT_SELECTOR, GET_BLOCK_NUMBER_SELECTOR, GET_BLOCK_TIMESTAMP_SELECTOR, GET_CALLER_ADDRESS_SELECTOR, GET_CONTRACT_ADDRESS_SELECTOR, GET_SEQUENCER_ADDRESS_SELECTOR, - GET_TX_SIGNATURE_SELECTOR, SEND_MESSAGE_TO_L1_SELECTOR, STORAGE_READ_SELECTOR, - STORAGE_WRITE_SELECTOR, CallContract, CallContractResponse, EmitEvent, GetBlockNumber, - GetBlockNumberResponse, GetBlockTimestamp, GetBlockTimestampResponse, GetCallerAddress, - GetCallerAddressResponse, GetContractAddress, GetContractAddressResponse, GetSequencerAddress, - GetSequencerAddressResponse, GetTxSignature, SendMessageToL1SysCall, StorageRead, StorageWrite) + GET_TX_INFO_SELECTOR, GET_TX_SIGNATURE_SELECTOR, SEND_MESSAGE_TO_L1_SELECTOR, + STORAGE_READ_SELECTOR, STORAGE_WRITE_SELECTOR, CallContract, CallContractResponse, EmitEvent, + GetBlockNumber, GetBlockNumberResponse, GetBlockTimestamp, GetBlockTimestampResponse, + GetCallerAddress, GetCallerAddressResponse, GetContractAddress, GetContractAddressResponse, + GetSequencerAddress, GetSequencerAddressResponse, GetTxInfo, GetTxInfoResponse, GetTxSignature, + GetTxSignatureResponse, SendMessageToL1SysCall, StorageRead, StorageWrite, TxInfo) +from starkware.starknet.core.os.builtins import ( + BuiltinEncodings, BuiltinParams, BuiltinPointers, get_builtin_params) from starkware.starknet.core.os.contracts import ( ContractDefinition, ContractDefinitionFact, ContractEntryPoint, load_contract_definition_facts) from starkware.starknet.core.os.output import ( - BlockInfo, DeploymentInfoHeader, MessageToL1Header, MessageToL2Header, OsCarriedOutputs) + BlockInfo, DeploymentInfoHeader, MessageToL1Header, MessageToL2Header, OsCarriedOutputs, + os_carried_outputs_new) from starkware.starknet.core.os.state import StateEntry const UNINITIALIZED_CONTRACT_HASH = 0 @@ -32,9 +36,9 @@ const ORIGIN_ADDRESS = 0 # Used to implement an empty constructor. const NOP_ENTRY_POINT_OFFSET = -1 -const TX_TYPE_EXTERNAL = 0 -const TX_TYPE_L1_HANDLER = 1 -const TX_TYPE_CONSTRUCTOR = 2 +const ENTRY_POINT_TYPE_EXTERNAL = 0 +const ENTRY_POINT_TYPE_L1_HANDLER = 1 +const ENTRY_POINT_TYPE_CONSTRUCTOR = 2 # get_selector_from_name('constructor'). const CONSTRUCTOR_SELECTOR = ( @@ -42,11 +46,11 @@ const CONSTRUCTOR_SELECTOR = ( const DEFAULT_ENTRY_POINT_SELECTOR = 0 -# An internal representation of an Invoke transaction to execute. -struct Transaction: - member tx_type : felt +# Represents the execution context during the execution of contract code. +struct ExecutionContext: + member entry_point_type : felt member caller_address : felt - # The address of the contract executing this transaction. + # The execution is done in the context of the contract at 'contract_address'. # This address controls the storage being used, messages sent to L1, calling contracts, etc. member contract_address : felt # The address that holds the code to execute. @@ -55,6 +59,8 @@ struct Transaction: member selector : felt member calldata_size : felt member calldata : felt* + # Information about the transaction that triggered the execution. + member original_tx_info : TxInfo* end # A dictionary from address to StateEntry. @@ -63,35 +69,8 @@ struct StateChanges: member changes_end : DictAccess* end -struct BuiltinPointers: - member pedersen : HashBuiltin* - member range_check : felt - member ecdsa : felt - member bitwise : felt -end - -# A struct containing the ASCII encoding of each builtin. -struct BuiltinEncodings: - member pedersen : felt - member range_check : felt - member ecdsa : felt - member bitwise : felt -end - -# A struct containing the instance size of each builtin. -struct BuiltinInstanceSizes: - member pedersen : felt - member range_check : felt - member ecdsa : felt - member bitwise : felt -end - -struct BuiltinParams: - member builtin_encodings : BuiltinEncodings* - member builtin_instance_sizes : BuiltinInstanceSizes* -end - -struct ExecuteTransactionContext: +# Context that remains fixed throughout the block. +struct BlockContext: member builtin_params : BuiltinParams* member n_contract_definition_facts : felt member contract_definition_facts : ContractDefinitionFact* @@ -99,28 +78,6 @@ struct ExecuteTransactionContext: member block_info : BlockInfo* end -func get_builtin_params() -> (builtin_params : BuiltinParams*): - alloc_locals - let (local __fp__, _) = get_fp_and_pc() - - local builtin_encodings : BuiltinEncodings = BuiltinEncodings( - pedersen='pedersen', - range_check='range_check', - ecdsa='ecdsa', - bitwise='bitwise') - - local builtin_instance_sizes : BuiltinInstanceSizes = BuiltinInstanceSizes( - pedersen=HashBuiltin.SIZE, - range_check=1, - ecdsa=SignatureBuiltin.SIZE, - bitwise=BitwiseBuiltin.SIZE) - - local builtin_params : BuiltinParams = BuiltinParams( - builtin_encodings=&builtin_encodings, - builtin_instance_sizes=&builtin_instance_sizes) - return (builtin_params=&builtin_params) -end - # Executes the transactions in the hint variable os_input.transactions. # # Returns: @@ -135,7 +92,7 @@ end # the returned range_check_ptr is smaller then reserved_range_checks_end. func execute_transactions{ pedersen_ptr : HashBuiltin*, range_check_ptr, ecdsa_ptr, bitwise_ptr, - outputs : OsCarriedOutputs}(block_info : BlockInfo*) -> ( + outputs : OsCarriedOutputs*}(block_info : BlockInfo*) -> ( reserved_range_checks_end, state_changes : StateChanges): alloc_locals local n_txs @@ -156,17 +113,15 @@ func execute_transactions{ let (n_contract_definition_facts, contract_definition_facts) = load_contract_definition_facts() let (local __fp__, _) = get_fp_and_pc() - tempvar temp_range_check - %{ ids.temp_range_check = segments.add_temp_segment() %} local local_builtin_ptrs : BuiltinPointers = BuiltinPointers( pedersen=pedersen_ptr, - range_check=temp_range_check, + range_check=nondet %{ segments.add_temp_segment() %}, ecdsa=ecdsa_ptr, bitwise=bitwise_ptr) let (builtin_params) = get_builtin_params() - local execute_tx_context : ExecuteTransactionContext = ExecuteTransactionContext( + local block_context : BlockContext = BlockContext( builtin_params=builtin_params, n_contract_definition_facts=n_contract_definition_facts, contract_definition_facts=contract_definition_facts, @@ -186,7 +141,7 @@ func execute_transactions{ let global_state_changes_start = global_state_changes execute_transactions_inner{ builtin_ptrs=builtin_ptrs, global_state_changes=global_state_changes}( - execute_tx_context=&execute_tx_context, n_txs=n_txs) + block_context=&block_context, n_txs=n_txs) %{ vm_exit_scope() %} let reserved_range_checks_end = range_check_ptr @@ -206,7 +161,7 @@ end # Inner function for execute_transactions. # Arguments: -# execute_tx_context - a read-only context used for transaction execution. +# block_context - a read-only context used for transaction execution. # n_txs - the number of transactions to execute. # # Implicit arguments: @@ -217,7 +172,7 @@ end # They are accounted for in builtin_ptrs. func execute_transactions_inner{ range_check_ptr, builtin_ptrs : BuiltinPointers*, global_state_changes : DictAccess*, - outputs : OsCarriedOutputs}(execute_tx_context : ExecuteTransactionContext*, n_txs): + outputs : OsCarriedOutputs*}(block_context : BlockContext*, n_txs): if n_txs == 0: return () end @@ -233,15 +188,15 @@ func execute_transactions_inner{ jmp deploy_transaction if [ap] != 0; ap++ # Handle invoke_transaction. - execute_externally_called_invoke_transaction(execute_tx_context=execute_tx_context) + execute_externally_called_invoke_transaction(block_context=block_context) - return execute_transactions_inner(execute_tx_context=execute_tx_context, n_txs=n_txs - 1) + return execute_transactions_inner(block_context=block_context, n_txs=n_txs - 1) deploy_transaction: # Handle deploy_transaction. - execute_deploy_transaction(execute_tx_context=execute_tx_context) + execute_deploy_transaction(block_context=block_context) - return execute_transactions_inner(execute_tx_context=execute_tx_context, n_txs=n_txs - 1) + return execute_transactions_inner(block_context=block_context, n_txs=n_txs - 1) end # Executes an externally called invoke transaction. @@ -250,81 +205,101 @@ end # If the transaction is an L1 handler, it is appended to the list of consumed L1->L2 messages. # # Arguments: -# execute_tx_context - a read-only context used for transaction execution. +# block_context - a global context that is fixed throughout the block. func execute_externally_called_invoke_transaction{ range_check_ptr, builtin_ptrs : BuiltinPointers*, global_state_changes : DictAccess*, - outputs : OsCarriedOutputs}(execute_tx_context : ExecuteTransactionContext*): + outputs : OsCarriedOutputs*}(block_context : BlockContext*): alloc_locals - local tx : Transaction* + + # Loads the execution context based on the current transaction. + local execution_context : ExecutionContext* %{ from starkware.starknet.business_logic.internal_transaction import InternalInvokeFunction from starkware.starknet.services.api.contract_definition import EntryPointType if tx.entry_point_type is EntryPointType.L1_HANDLER: - tx_type = ids.TX_TYPE_L1_HANDLER + entry_point_type = ids.ENTRY_POINT_TYPE_L1_HANDLER assert tx.nonce is not None, "L1 handlers must include a nonce." elif tx.entry_point_type is EntryPointType.EXTERNAL: - tx_type = ids.TX_TYPE_EXTERNAL + entry_point_type = ids.ENTRY_POINT_TYPE_EXTERNAL else: raise NotImplementedError(f'Unexpected EntryPointType: {tx.entry_point_type}.') assert isinstance(tx, InternalInvokeFunction), \ f'Expected a transaction of type InternalInvokeFunction, got {tx}.' - ids.tx = segments.gen_arg( + + + original_tx_info = segments.add() + ids.execution_context = segments.gen_arg( arg=[ - tx_type, + entry_point_type, ids.ORIGIN_ADDRESS, tx.contract_address, tx.code_address, tx.entry_point_selector, len(tx.calldata), tx.calldata, + original_tx_info, ] ) %} + assert [execution_context.original_tx_info] = TxInfo( + version=0, + account_contract_address=execution_context.contract_address, + max_fee=0, + signature_len=nondet %{ len(tx.signature) %}, + signature=cast(nondet %{ segments.gen_arg(arg=tx.signature) %}, felt*), + ) + # External calls originate from ORIGIN_ADDRESS. - assert tx.caller_address = ORIGIN_ADDRESS + assert execution_context.caller_address = ORIGIN_ADDRESS - if tx.tx_type == TX_TYPE_L1_HANDLER: + if execution_context.entry_point_type == ENTRY_POINT_TYPE_L1_HANDLER: # Consume L1-to-L2 message. - consume_l1_to_l2_message(tx=tx, nonce=nondet %{ tx.nonce %}) + consume_l1_to_l2_message(execution_context=execution_context, nonce=nondet %{ tx.nonce %}) else: - # If tx.tx_type is not TX_TYPE_L1_HANDLER, it must be TX_TYPE_EXTERNAL. - assert tx.tx_type = TX_TYPE_EXTERNAL + # If execution_context.entry_point_type is not ENTRY_POINT_TYPE_L1_HANDLER, + # it must be ENTRY_POINT_TYPE_EXTERNAL. + assert execution_context.entry_point_type = ENTRY_POINT_TYPE_EXTERNAL tempvar outputs = outputs end # In external calls and l1 handlers, the code_address must match the contract_address. - assert tx.code_address = tx.contract_address + assert execution_context.code_address = execution_context.contract_address + + %{ syscall_handler.start_tx(tx_info_ptr=ids.execution_context.original_tx_info.address_) %} + execute_entry_point(block_context=block_context, execution_context=execution_context) - execute_invoke_transaction(execute_tx_context=execute_tx_context, tx=tx) + %{ syscall_handler.end_tx() %} return () end # Executes a syscall that calls another contract, or invokes a delegate call. func execute_contract_call{ range_check_ptr, builtin_ptrs : BuiltinPointers*, global_state_changes : DictAccess*, - outputs : OsCarriedOutputs}( - execute_tx_context : ExecuteTransactionContext*, contract_address : felt, - caller_address : felt, tx_type : felt, syscall_ptr : CallContract*): + outputs : OsCarriedOutputs*}( + block_context : BlockContext*, contract_address : felt, caller_address : felt, + entry_point_type : felt, original_tx_info : TxInfo*, syscall_ptr : CallContract*): alloc_locals let call_req = syscall_ptr.request - local tx : Transaction* - %{ ids.tx = segments.add() %} - assert [tx] = Transaction( - tx_type=tx_type, + local execution_context : ExecutionContext* + %{ ids.execution_context = segments.add() %} + assert [execution_context] = ExecutionContext( + entry_point_type=entry_point_type, caller_address=caller_address, contract_address=contract_address, code_address=call_req.contract_address, selector=call_req.function_selector, calldata_size=call_req.calldata_size, - calldata=call_req.calldata) + calldata=call_req.calldata, + original_tx_info=original_tx_info, + ) - let (retdata_size, retdata) = execute_invoke_transaction( - execute_tx_context=execute_tx_context, tx=tx) + let (retdata_size, retdata) = execute_entry_point( + block_context=block_context, execution_context=execution_context) let call_resp = syscall_ptr.response %{ @@ -410,199 +385,217 @@ func execute_storage_write{global_state_changes : DictAccess*}( return () end -# Executes a system call. +# Executes the system calls in syscall_ptr. # # Arguments: -# execute_tx_context - a read-only context used for transaction execution. -# calling_tx - The transaction for which we are executing the system calls. -# syscall_ptr a pointer to the syscall segment associated with the 'calling_tx'. +# block_context - a read-only context used for transaction execution. +# execution_context - The execution context in which the system calls need to be executed. +# syscall_ptr - a pointer to the syscall segment that needs to be executed. +# syscall_size - The size of the system call segment to be executed. func execute_syscalls{ range_check_ptr, builtin_ptrs : BuiltinPointers*, global_state_changes : DictAccess*, - outputs : OsCarriedOutputs}( - execute_tx_context : ExecuteTransactionContext*, calling_tx : Transaction*, syscall_size, + outputs : OsCarriedOutputs*}( + block_context : BlockContext*, execution_context : ExecutionContext*, syscall_size, syscall_ptr : felt*): if syscall_size == 0: return () end + if [syscall_ptr] == STORAGE_READ_SELECTOR: + execute_storage_read( + contract_address=execution_context.contract_address, + syscall_ptr=cast(syscall_ptr, StorageRead*)) + return execute_syscalls( + block_context=block_context, + execution_context=execution_context, + syscall_size=syscall_size - StorageRead.SIZE, + syscall_ptr=syscall_ptr + StorageRead.SIZE) + end + + if [syscall_ptr] == STORAGE_WRITE_SELECTOR: + execute_storage_write( + contract_address=execution_context.contract_address, + syscall_ptr=cast(syscall_ptr, StorageWrite*)) + return execute_syscalls( + block_context=block_context, + execution_context=execution_context, + syscall_size=syscall_size - StorageWrite.SIZE, + syscall_ptr=syscall_ptr + StorageWrite.SIZE) + end + + if [syscall_ptr] == EMIT_EVENT_SELECTOR: + # Skip as long as the block hash is not calculated by the OS. + return execute_syscalls( + block_context=block_context, + execution_context=execution_context, + syscall_size=syscall_size - EmitEvent.SIZE, + syscall_ptr=syscall_ptr + EmitEvent.SIZE) + end + if [syscall_ptr] == CALL_CONTRACT_SELECTOR: let call_contract_syscall = cast(syscall_ptr, CallContract*) execute_contract_call( - execute_tx_context=execute_tx_context, + block_context=block_context, contract_address=call_contract_syscall.request.contract_address, - caller_address=calling_tx.contract_address, - tx_type=TX_TYPE_EXTERNAL, + caller_address=execution_context.contract_address, + entry_point_type=ENTRY_POINT_TYPE_EXTERNAL, + original_tx_info=execution_context.original_tx_info, syscall_ptr=call_contract_syscall) return execute_syscalls( - execute_tx_context=execute_tx_context, - calling_tx=calling_tx, + block_context=block_context, + execution_context=execution_context, syscall_size=syscall_size - CallContract.SIZE, syscall_ptr=syscall_ptr + CallContract.SIZE) end if [syscall_ptr] == DELEGATE_CALL_SELECTOR: execute_contract_call( - execute_tx_context=execute_tx_context, - contract_address=calling_tx.contract_address, - caller_address=calling_tx.caller_address, - tx_type=TX_TYPE_EXTERNAL, + block_context=block_context, + contract_address=execution_context.contract_address, + caller_address=execution_context.caller_address, + entry_point_type=ENTRY_POINT_TYPE_EXTERNAL, + original_tx_info=execution_context.original_tx_info, syscall_ptr=cast(syscall_ptr, CallContract*)) return execute_syscalls( - execute_tx_context=execute_tx_context, - calling_tx=calling_tx, + block_context=block_context, + execution_context=execution_context, syscall_size=syscall_size - CallContract.SIZE, syscall_ptr=syscall_ptr + CallContract.SIZE) end if [syscall_ptr] == DELEGATE_L1_HANDLER_SELECTOR: execute_contract_call( - execute_tx_context=execute_tx_context, - contract_address=calling_tx.contract_address, - caller_address=calling_tx.caller_address, - tx_type=TX_TYPE_L1_HANDLER, + block_context=block_context, + contract_address=execution_context.contract_address, + caller_address=execution_context.caller_address, + entry_point_type=ENTRY_POINT_TYPE_L1_HANDLER, + original_tx_info=execution_context.original_tx_info, syscall_ptr=cast(syscall_ptr, CallContract*)) return execute_syscalls( - execute_tx_context=execute_tx_context, - calling_tx=calling_tx, + block_context=block_context, + execution_context=execution_context, syscall_size=syscall_size - CallContract.SIZE, syscall_ptr=syscall_ptr + CallContract.SIZE) end + if [syscall_ptr] == GET_TX_INFO_SELECTOR: + assert cast(syscall_ptr, GetTxInfo*).response = GetTxInfoResponse( + tx_info=execution_context.original_tx_info) + return execute_syscalls( + block_context=block_context, + execution_context=execution_context, + syscall_size=syscall_size - GetTxInfo.SIZE, + syscall_ptr=syscall_ptr + GetTxInfo.SIZE) + end + if [syscall_ptr] == GET_CALLER_ADDRESS_SELECTOR: assert [cast(syscall_ptr, GetCallerAddress*)].response = GetCallerAddressResponse( - caller_address=calling_tx.caller_address) + caller_address=execution_context.caller_address) return execute_syscalls( - execute_tx_context=execute_tx_context, - calling_tx=calling_tx, + block_context=block_context, + execution_context=execution_context, syscall_size=syscall_size - GetCallerAddress.SIZE, syscall_ptr=syscall_ptr + GetCallerAddress.SIZE) end if [syscall_ptr] == GET_SEQUENCER_ADDRESS_SELECTOR: assert [cast(syscall_ptr, GetSequencerAddress*)].response = GetSequencerAddressResponse( - sequencer_address=execute_tx_context.sequencer_address) + sequencer_address=block_context.sequencer_address) return execute_syscalls( - execute_tx_context=execute_tx_context, - calling_tx=calling_tx, + block_context=block_context, + execution_context=execution_context, syscall_size=syscall_size - GetSequencerAddress.SIZE, syscall_ptr=syscall_ptr + GetSequencerAddress.SIZE) end if [syscall_ptr] == GET_CONTRACT_ADDRESS_SELECTOR: assert [cast(syscall_ptr, GetContractAddress*)].response = GetContractAddressResponse( - contract_address=calling_tx.contract_address) + contract_address=execution_context.contract_address) return execute_syscalls( - execute_tx_context=execute_tx_context, - calling_tx=calling_tx, + block_context=block_context, + execution_context=execution_context, syscall_size=syscall_size - GetContractAddress.SIZE, syscall_ptr=syscall_ptr + GetContractAddress.SIZE) end if [syscall_ptr] == GET_BLOCK_TIMESTAMP_SELECTOR: assert [cast(syscall_ptr, GetBlockTimestamp*)].response = GetBlockTimestampResponse( - block_timestamp=execute_tx_context.block_info.block_timestamp) + block_timestamp=block_context.block_info.block_timestamp) return execute_syscalls( - execute_tx_context=execute_tx_context, - calling_tx=calling_tx, + block_context=block_context, + execution_context=execution_context, syscall_size=syscall_size - GetBlockTimestamp.SIZE, syscall_ptr=syscall_ptr + GetBlockTimestamp.SIZE) end if [syscall_ptr] == GET_BLOCK_NUMBER_SELECTOR: assert [cast(syscall_ptr, GetBlockNumber*)].response = GetBlockNumberResponse( - block_number=execute_tx_context.block_info.block_number) + block_number=block_context.block_info.block_number) return execute_syscalls( - execute_tx_context=execute_tx_context, - calling_tx=calling_tx, + block_context=block_context, + execution_context=execution_context, syscall_size=syscall_size - GetBlockNumber.SIZE, syscall_ptr=syscall_ptr + GetBlockNumber.SIZE) end if [syscall_ptr] == GET_TX_SIGNATURE_SELECTOR: - # Note that we don't enforce anything on the response. + tempvar original_tx_info : TxInfo* = execution_context.original_tx_info + assert [cast(syscall_ptr, GetTxSignature*)].response = GetTxSignatureResponse( + signature_len=original_tx_info.signature_len, + signature=original_tx_info.signature + ) return execute_syscalls( - execute_tx_context=execute_tx_context, - calling_tx=calling_tx, + block_context=block_context, + execution_context=execution_context, syscall_size=syscall_size - GetTxSignature.SIZE, syscall_ptr=syscall_ptr + GetTxSignature.SIZE) end - if [syscall_ptr] == STORAGE_READ_SELECTOR: - execute_storage_read( - contract_address=calling_tx.contract_address, - syscall_ptr=cast(syscall_ptr, StorageRead*)) - return execute_syscalls( - execute_tx_context=execute_tx_context, - calling_tx=calling_tx, - syscall_size=syscall_size - StorageRead.SIZE, - syscall_ptr=syscall_ptr + StorageRead.SIZE) - end - - if [syscall_ptr] == STORAGE_WRITE_SELECTOR: - execute_storage_write( - contract_address=calling_tx.contract_address, - syscall_ptr=cast(syscall_ptr, StorageWrite*)) - return execute_syscalls( - execute_tx_context=execute_tx_context, - calling_tx=calling_tx, - syscall_size=syscall_size - StorageWrite.SIZE, - syscall_ptr=syscall_ptr + StorageWrite.SIZE) - end - - if [syscall_ptr] == EMIT_EVENT_SELECTOR: - # Skip as long as the block hash is not calculated by the OS. - return execute_syscalls( - execute_tx_context=execute_tx_context, - calling_tx=calling_tx, - syscall_size=syscall_size - EmitEvent.SIZE, - syscall_ptr=syscall_ptr + EmitEvent.SIZE) - end - # Here the system call must be 'SendMessageToL1'. assert [syscall_ptr] = SEND_MESSAGE_TO_L1_SELECTOR let syscall = [cast(syscall_ptr, SendMessageToL1SysCall*)] assert [outputs.messages_to_l1] = MessageToL1Header( - from_address=calling_tx.contract_address, + from_address=execution_context.contract_address, to_address=syscall.to_address, payload_size=syscall.payload_size) memcpy( dst=outputs.messages_to_l1 + MessageToL1Header.SIZE, src=syscall.payload_ptr, len=syscall.payload_size) - let outputs = OsCarriedOutputs( + let (outputs) = os_carried_outputs_new( messages_to_l1=outputs.messages_to_l1 + MessageToL1Header.SIZE + outputs.messages_to_l1.payload_size, messages_to_l2=outputs.messages_to_l2, deployment_info=outputs.deployment_info) return execute_syscalls( - execute_tx_context=execute_tx_context, - calling_tx=calling_tx, + block_context=block_context, + execution_context=execution_context, syscall_size=syscall_size - SendMessageToL1SysCall.SIZE, syscall_ptr=syscall_ptr + SendMessageToL1SysCall.SIZE) end # Adds 'tx' with the given 'nonce' to 'outputs.messages_to_l2'. -func consume_l1_to_l2_message{outputs : OsCarriedOutputs}(tx : Transaction*, nonce : felt): - assert_not_zero(tx.calldata_size) - # The raw payload is the calldata without the from_address argument (which is the first). - tempvar raw_payload : felt* = tx.calldata + 1 - tempvar raw_payload_size = tx.calldata_size - 1 +func consume_l1_to_l2_message{outputs : OsCarriedOutputs*}( + execution_context : ExecutionContext*, nonce : felt): + assert_not_zero(execution_context.calldata_size) + # The payload is the calldata without the from_address argument (which is the first). + let payload : felt* = execution_context.calldata + 1 + tempvar payload_size = execution_context.calldata_size - 1 # Write the given transaction to the output. assert [outputs.messages_to_l2] = MessageToL2Header( - from_address=[tx.calldata], - to_address=tx.contract_address, + from_address=[execution_context.calldata], + to_address=execution_context.contract_address, nonce=nonce, - selector=tx.selector, - payload_size=raw_payload_size) + selector=execution_context.selector, + payload_size=payload_size) - # The payload consists of the selector and the raw payload. let message_payload = cast(outputs.messages_to_l2 + MessageToL2Header.SIZE, felt*) - memcpy(dst=message_payload, src=raw_payload, len=raw_payload_size) + memcpy(dst=message_payload, src=payload, len=payload_size) - let outputs = OsCarriedOutputs( + let (outputs) = os_carried_outputs_new( messages_to_l1=outputs.messages_to_l1, messages_to_l2=outputs.messages_to_l2 + MessageToL2Header.SIZE + outputs.messages_to_l2.payload_size, @@ -610,25 +603,26 @@ func consume_l1_to_l2_message{outputs : OsCarriedOutputs}(tx : Transaction*, non return () end -# Returns the entry point's offset in the program based on the contract_definition and the -# transaction. +# Returns the entry point's offset in the program based on 'contract_definition' and +# 'execution_context'. func get_entry_point_offset{range_check_ptr}( - contract_definition : ContractDefinition*, tx : Transaction*) -> ( + contract_definition : ContractDefinition*, execution_context : ExecutionContext*) -> ( entry_point_offset : felt): alloc_locals # Get the entry points corresponding to the transaction's type. local entry_points : ContractEntryPoint* local n_entry_points : felt - if tx.tx_type == TX_TYPE_L1_HANDLER: + tempvar entry_point_type = execution_context.entry_point_type + if entry_point_type == ENTRY_POINT_TYPE_L1_HANDLER: entry_points = contract_definition.l1_handlers n_entry_points = contract_definition.n_l1_handlers else: - if tx.tx_type == TX_TYPE_EXTERNAL: + if entry_point_type == ENTRY_POINT_TYPE_EXTERNAL: entry_points = contract_definition.external_functions n_entry_points = contract_definition.n_external_functions else: - assert tx.tx_type = TX_TYPE_CONSTRUCTOR + assert entry_point_type = ENTRY_POINT_TYPE_CONSTRUCTOR entry_points = contract_definition.constructors n_entry_points = contract_definition.n_constructors @@ -644,7 +638,7 @@ func get_entry_point_offset{range_check_ptr}( array_ptr=cast(entry_points, felt*), elm_size=ContractEntryPoint.SIZE, n_elms=n_entry_points, - key=tx.selector) + key=execution_context.selector) if success != 0: return (entry_point_offset=entry_point_desc.offset) end @@ -655,38 +649,40 @@ func get_entry_point_offset{range_check_ptr}( return (entry_point_offset=entry_points[0].offset) end -# Executes an invoke transaction and returns its return value. +# Executes an entry point in a contract. +# The contract entry point is selected based on execution_context.entry_point_type +# and execution_context.selector. # # Arguments: -# execute_tx_context - a read-only context used for transaction execution. -# tx - The transaction to execute. -func execute_invoke_transaction{ +# block_context - a global context that is fixed throughout the block. +# execution_context - The context for the current execution. +func execute_entry_point{ range_check_ptr, builtin_ptrs : BuiltinPointers*, global_state_changes : DictAccess*, - outputs : OsCarriedOutputs}( - execute_tx_context : ExecuteTransactionContext*, tx : Transaction*) -> ( + outputs : OsCarriedOutputs*}( + block_context : BlockContext*, execution_context : ExecutionContext*) -> ( retdata_size, retdata : felt*): alloc_locals let (local state_entry : StateEntry*) = dict_read{dict_ptr=global_state_changes}( - key=tx.code_address) + key=execution_context.code_address) local global_state_changes : DictAccess* = global_state_changes # The key must be at offset 0. static_assert ContractDefinitionFact.hash == 0 let (contract_definition_fact : ContractDefinitionFact*) = find_element( - array_ptr=execute_tx_context.contract_definition_facts, + array_ptr=block_context.contract_definition_facts, elm_size=ContractDefinitionFact.SIZE, - n_elms=execute_tx_context.n_contract_definition_facts, + n_elms=block_context.n_contract_definition_facts, key=state_entry.contract_hash) local contract_definition : ContractDefinition* = contract_definition_fact.contract_definition let (entry_point_offset) = get_entry_point_offset( - contract_definition=contract_definition, tx=tx) + contract_definition=contract_definition, execution_context=execution_context) %{ syscall_handler.enter_call() %} if entry_point_offset == NOP_ENTRY_POINT_OFFSET: # Assert that there is no call data in the case of NOP entry point. - assert tx.calldata_size = 0 + assert execution_context.calldata_size = 0 %{ syscall_handler.exit_call() %} return (retdata_size=0, retdata=cast(0, felt*)) end @@ -704,7 +700,7 @@ func execute_invoke_transaction{ assert [os_context] = cast(syscall_ptr, felt) let n_builtins = BuiltinEncodings.SIZE - local builtin_params : BuiltinParams* = execute_tx_context.builtin_params + local builtin_params : BuiltinParams* = block_context.builtin_params select_builtins( n_builtins=n_builtins, all_encodings=builtin_params.builtin_encodings, @@ -714,13 +710,13 @@ func execute_invoke_transaction{ selected_ptrs=os_context + 1) # Use tempvar to pass arguments to contract_entry_point(). - tempvar selector = tx.selector + tempvar selector = execution_context.selector tempvar context = os_context - tempvar calldata_size = tx.calldata_size - tempvar calldata = tx.calldata + tempvar calldata_size = execution_context.calldata_size + tempvar calldata = execution_context.calldata %{ vm_enter_scope({ - '__storage' : storage_by_address[ids.tx.contract_address], + '__storage' : storage_by_address[ids.execution_context.contract_address], 'syscall_handler' : syscall_handler, }) %} @@ -771,8 +767,8 @@ func execute_invoke_transaction{ let builtin_ptrs = return_builtin_ptrs execute_syscalls( - execute_tx_context=execute_tx_context, - calling_tx=tx, + block_context=block_context, + execution_context=execution_context, syscall_size=syscall_end - syscall_ptr, syscall_ptr=syscall_ptr) @@ -782,13 +778,15 @@ end func execute_deploy_transaction{ range_check_ptr, builtin_ptrs : BuiltinPointers*, global_state_changes : DictAccess*, - outputs : OsCarriedOutputs}(execute_tx_context : ExecuteTransactionContext*): + outputs : OsCarriedOutputs*}(block_context : BlockContext*): alloc_locals local contract_address local state_entry : StateEntry* local new_state_entry : StateEntry* %{ + from starkware.python.utils import from_bytes + ids.contract_address = tx.contract_address # Fetch a state_entry in this hint and validate it in the update at the end @@ -798,9 +796,7 @@ func execute_deploy_transaction{ ids.new_state_entry = segments.add() - from starkware.starknet.core.os.contract_hash import compute_contract_hash - - ids.new_state_entry.contract_hash = compute_contract_hash(tx.contract_definition) + ids.new_state_entry.contract_hash = from_bytes(tx.contract_hash) %} # Assert that we don't deploy to ORIGIN_ADDRESS. @@ -826,23 +822,36 @@ func execute_deploy_transaction{ assert [outputs.deployment_info] = DeploymentInfoHeader( contract_address=contract_address, contract_hash=new_state_entry.contract_hash, calldata_size=calldata_size) - let outputs = OsCarriedOutputs( + let (outputs) = os_carried_outputs_new( messages_to_l1=outputs.messages_to_l1, messages_to_l2=outputs.messages_to_l2, deployment_info=cast(calldata + calldata_size, DeploymentInfoHeader*)) - local tx : Transaction* - %{ ids.tx = segments.add() %} - assert [tx] = Transaction( - tx_type=TX_TYPE_CONSTRUCTOR, + # Invoke the contract constructor. + local execution_context : ExecutionContext* = cast( + nondet %{ segments.add() %}, ExecutionContext*) + assert [execution_context] = ExecutionContext( + entry_point_type=ENTRY_POINT_TYPE_CONSTRUCTOR, caller_address=ORIGIN_ADDRESS, contract_address=contract_address, code_address=contract_address, selector=CONSTRUCTOR_SELECTOR, calldata_size=calldata_size, - calldata=calldata) + calldata=calldata, + original_tx_info=cast(nondet %{ segments.add() %}, TxInfo*), + ) + + assert [execution_context.original_tx_info] = TxInfo( + version=0, + account_contract_address=ORIGIN_ADDRESS, + max_fee=0, + signature_len=0, + signature=cast(0, felt*), + ) - execute_invoke_transaction(execute_tx_context=execute_tx_context, tx=tx) + %{ syscall_handler.start_tx(tx_info_ptr=ids.execution_context.original_tx_info.address_) %} + execute_entry_point(block_context=block_context, execution_context=execution_context) + %{ syscall_handler.end_tx() %} return () end diff --git a/src/starkware/starknet/definitions/constants.py b/src/starkware/starknet/definitions/constants.py index dd17692d..da97b5a5 100644 --- a/src/starkware/starknet/definitions/constants.py +++ b/src/starkware/starknet/definitions/constants.py @@ -25,6 +25,8 @@ ENTRY_POINT_SELECTOR_LOWER_BOUND = 0 ENTRY_POINT_SELECTOR_UPPER_BOUND = FIELD_SIZE EVENT_COMMITMENT_TREE_HEIGHT = 64 +FEE_LOWER_BOUND = 0 +FEE_UPPER_BOUND = 2 ** 256 # Fee is a uint-256. # Default hash to fill the parent_hash field of the first block in the sequence. GENESIS_PARENT_BLOCK_HASH = 0 MAX_MESSAGE_TO_L1_LENGTH = 100 @@ -35,3 +37,17 @@ TRANSACTION_HASH_UPPER_BOUND = FIELD_SIZE ADDRESS_LOWER_BOUND = 0 ADDRESS_UPPER_BOUND = 2 ** ADDRESS_BITS + +# OS-related constants. +L1_TO_L2_MSG_HEADER_SIZE = 5 +L2_TO_L1_MSG_HEADER_SIZE = 3 + +# StarkNet solidity contract-related constants. +N_DEFAULT_TOPICS = 1 # Events have one default topic. +# Excluding the default topic. +LOG_MSG_TO_L1_N_TOPICS = 2 +CONSUMED_MSG_TO_L2_N_TOPICS = 3 +# The headers include the payload size, so we need to add +1 since arrays are encoded with two +# additional parameters (offset and length) in solidity. +LOG_MSG_TO_L1_ENCODED_DATA_SIZE = (L2_TO_L1_MSG_HEADER_SIZE + 1) - LOG_MSG_TO_L1_N_TOPICS +CONSUMED_MSG_TO_L2_ENCODED_DATA_SIZE = (L1_TO_L2_MSG_HEADER_SIZE + 1) - CONSUMED_MSG_TO_L2_N_TOPICS diff --git a/src/starkware/starknet/definitions/error_codes.py b/src/starkware/starknet/definitions/error_codes.py index c4befdf8..56d53e70 100644 --- a/src/starkware/starknet/definitions/error_codes.py +++ b/src/starkware/starknet/definitions/error_codes.py @@ -30,6 +30,7 @@ class StarknetErrorCode(ErrorCode): OUT_OF_RANGE_CONTRACT_STORAGE_KEY = auto() OUT_OF_RANGE_ENTRY_POINT_OFFSET = auto() OUT_OF_RANGE_ENTRY_POINT_SELECTOR = auto() + OUT_OF_RANGE_FEE = auto() OUT_OF_RANGE_SEQUENCER_ADDRESS = auto() OUT_OF_RANGE_TRANSACTION_HASH = auto() OUT_OF_RANGE_TRANSACTION_ID = auto() diff --git a/src/starkware/starknet/definitions/fields.py b/src/starkware/starknet/definitions/fields.py index a1cec8c9..785554ea 100644 --- a/src/starkware/starknet/definitions/fields.py +++ b/src/starkware/starknet/definitions/fields.py @@ -29,10 +29,10 @@ block_number_metadata = sequential_id_metadata(field_name="Block number", allow_previous_id=True) default_optional_block_number_metadata = sequential_id_metadata( - field_name="Block number", required=False, allow_none=True, load_default=None + field_name="Block number", required=False, load_default=None ) default_optional_transaction_index_metadata = sequential_id_metadata( - field_name="Transaction index", required=False, allow_none=True, load_default=None + field_name="Transaction index", required=False, load_default=None ) felt_list_metadata = dict( @@ -94,6 +94,10 @@ def validate_contract_hash(contract_hash: bytes): marshmallow_field=BytesAsHex(required=True, validate=validate_contract_hash), ) +non_required_contract_hash_metadata = dict( + marshmallow_field=BytesAsHex(required=False, validate=validate_contract_hash), +) + contract_storage_commitment_tree_height_metadata = dict( marshmallow_field=StrictRequiredInteger( validate=validate_positive("contract_storage_commitment_tree_height") diff --git a/src/starkware/starknet/definitions/general_config.py b/src/starkware/starknet/definitions/general_config.py index a26648f1..d9fe8663 100644 --- a/src/starkware/starknet/definitions/general_config.py +++ b/src/starkware/starknet/definitions/general_config.py @@ -1,12 +1,14 @@ from dataclasses import field from enum import Enum +from typing import Dict +import marshmallow.fields as mfields import marshmallow_dataclass from starkware.python.utils import from_bytes from starkware.starknet.definitions import constants, fields from starkware.starkware_utils.config_base import Config -from starkware.starkware_utils.field_validators import validate_non_negative +from starkware.starkware_utils.field_validators import validate_dict, validate_non_negative from starkware.starkware_utils.marshmallow_dataclass_fields import StrictRequiredInteger DOCKER_GENERAL_CONFIG_PATH = "/general_config.yml" @@ -26,6 +28,26 @@ class StarknetChainId(Enum): DEFAULT_MAX_STEPS = 10 ** 6 DEFAULT_CHAIN_ID = StarknetChainId.TESTNET +class CairoResource(Enum): + N_STEPS = "n_steps" + GAS_WEIGHT = "gas_weight" + PEDERSEN_BUILTIN = "pedersen_builtin" + RANGE_CHECK_BUILTIN = "range_check_builtin" + ECDSA_BUILTIN = "ecdsa_builtin" + BITWISE_BUILTIN = "bitwise_builtin" + OUTPUT_BUILTIN = "output_builtin" + EC_OP_BUILTIN = "ec_op_builtin" + + +DEFAULT_CAIRO_USAGE_RESOURCE_FEE_WEIGHTS = { + CairoResource.N_STEPS.value: 0.0, + CairoResource.GAS_WEIGHT.value: 0.0, + CairoResource.PEDERSEN_BUILTIN.value: 0.0, + CairoResource.RANGE_CHECK_BUILTIN.value: 0.0, + CairoResource.ECDSA_BUILTIN.value: 0.0, + CairoResource.BITWISE_BUILTIN.value: 0.0, +} + # Configuration schema definition. @@ -74,3 +96,20 @@ class StarknetGeneralConfig(Config): ), default=constants.EVENT_COMMITMENT_TREE_HEIGHT, ) + + cairo_usage_resource_fee_weights: Dict[str, float] = field( + metadata=dict( + marshmallow_field=mfields.Dict( + keys=mfields.String, + values=mfields.Float, + validate=validate_dict( + "Cairo usage resource fee weights", value_validator=validate_non_negative + ), + ), + description=( + "A mapping from a Cairo usage resource to its coefficient in this transaction " + "fee calculation." + ), + ), + default_factory=lambda: DEFAULT_CAIRO_USAGE_RESOURCE_FEE_WEIGHTS.copy(), + ) diff --git a/src/starkware/starknet/definitions/general_config.yml b/src/starkware/starknet/definitions/general_config.yml index 3b813ecb..d1a12221 100644 --- a/src/starkware/starknet/definitions/general_config.yml +++ b/src/starkware/starknet/definitions/general_config.yml @@ -2,6 +2,13 @@ sequencer_address: '0x0' contract_storage_commitment_tree_height: 251 global_state_commitment_tree_height: 251 invoke_tx_max_n_steps: 1000000 +cairo_usage_resource_fee_weights: + bitwise_builtin: 0.0 + ecdsa_builtin: 0.0 + gas_weight: 0.0 + n_steps: 0.0 + pedersen_builtin: 0.0 + range_check_builtin: 0.0 tx_commitment_tree_height: 64 event_commitment_tree_height: 64 diff --git a/src/starkware/starknet/security/latest_whitelist_test.py b/src/starkware/starknet/security/latest_whitelist_test.py index a1ff670d..2f7da33e 100644 --- a/src/starkware/starknet/security/latest_whitelist_test.py +++ b/src/starkware/starknet/security/latest_whitelist_test.py @@ -19,7 +19,7 @@ def run(fix: bool): filename = get_source_dir_path("src/starkware/starknet/security/whitelists/latest.json") whitelist = HintsWhitelist.from_program(program) if fix: - data = HintsWhitelist.Schema().dumps(whitelist, indent=4, sort_keys=True) + data = whitelist.dumps(indent=4, sort_keys=True) with open(filename, "w") as fp: fp.write(data) fp.write("\n") diff --git a/src/starkware/starknet/security/secure_hints.py b/src/starkware/starknet/security/secure_hints.py index 051463ae..9a873d07 100644 --- a/src/starkware/starknet/security/secure_hints.py +++ b/src/starkware/starknet/security/secure_hints.py @@ -96,7 +96,7 @@ def empty(cls): @classmethod def from_file(cls, filename: str) -> "HintsWhitelist": with open(filename, "r") as fp: - return cls.Schema().loads(fp.read()) + return cls.loads(data=fp.read()) @classmethod def from_dir(cls, dirname: str) -> "HintsWhitelist": diff --git a/src/starkware/starknet/security/secure_hints_test.py b/src/starkware/starknet/security/secure_hints_test.py index f9954e61..02ab8a94 100644 --- a/src/starkware/starknet/security/secure_hints_test.py +++ b/src/starkware/starknet/security/secure_hints_test.py @@ -73,8 +73,7 @@ def test_secure_hints_cases(): def test_secure_hints_serialization(): template_program = compile_cairo(ALLOWED_CODE, DEFAULT_PRIME) whitelist = HintsWhitelist.from_program(template_program) - data = HintsWhitelist.Schema().dumps(whitelist) - whitelist = HintsWhitelist.Schema().loads(data) + whitelist = HintsWhitelist.loads(data=whitelist.dumps()) for good_code in GOOD_CODES: program = compile_cairo(good_code, DEFAULT_PRIME) whitelist.verify_program_hint_secure(program) diff --git a/src/starkware/starknet/security/starknet_common.cairo b/src/starkware/starknet/security/starknet_common.cairo index e3b1a230..4c5753a9 100644 --- a/src/starkware/starknet/security/starknet_common.cairo +++ b/src/starkware/starknet/security/starknet_common.cairo @@ -24,4 +24,4 @@ from starkware.starknet.common.storage import normalize_address from starkware.starknet.common.syscalls import ( call_contract, delegate_call, delegate_l1_handler, emit_event, get_block_number, get_block_timestamp, get_caller_address, get_contract_address, get_sequencer_address, - get_tx_signature, storage_read, storage_write) + get_tx_info, get_tx_signature, storage_read, storage_write) diff --git a/src/starkware/starknet/security/whitelists/latest.json b/src/starkware/starknet/security/whitelists/latest.json index 19dd2dff..f3c5ee3d 100644 --- a/src/starkware/starknet/security/whitelists/latest.json +++ b/src/starkware/starknet/security/whitelists/latest.json @@ -539,6 +539,12 @@ "syscall_handler.get_sequencer_address(segments=segments, syscall_ptr=ids.syscall_ptr)" ] }, + { + "allowed_expressions": [], + "hint_lines": [ + "syscall_handler.get_tx_info(segments=segments, syscall_ptr=ids.syscall_ptr)" + ] + }, { "allowed_expressions": [], "hint_lines": [ diff --git a/src/starkware/starknet/services/api/CMakeLists.txt b/src/starkware/starknet/services/api/CMakeLists.txt index 56d7f4ec..0ed15bb2 100644 --- a/src/starkware/starknet/services/api/CMakeLists.txt +++ b/src/starkware/starknet/services/api/CMakeLists.txt @@ -14,6 +14,7 @@ python_lib(starknet_messages_lib starknet_definitions_lib starknet_internal_transaction_lib starkware_dataclasses_utils_lib + starkware_error_handling_lib ) python_lib(starknet_contract_definition_lib diff --git a/src/starkware/starknet/services/api/contract_definition.py b/src/starkware/starknet/services/api/contract_definition.py index e70f4d25..a4be70ad 100644 --- a/src/starkware/starknet/services/api/contract_definition.py +++ b/src/starkware/starknet/services/api/contract_definition.py @@ -94,3 +94,11 @@ def remove_debug_info(self) -> "ContractDefinition": """ altered_program = dataclasses.replace(self.program, debug_info=None) return dataclasses.replace(self, program=altered_program) + + @property + def n_entry_points(self) -> int: + """ + Returns the number of entry points (note that functions with multiple decorators are + counted more than once). + """ + return sum(len(eps) for eps in self.entry_points_by_type.values()) diff --git a/src/starkware/starknet/services/api/feeder_gateway/CMakeLists.txt b/src/starkware/starknet/services/api/feeder_gateway/CMakeLists.txt index 732a2fc3..a88aa987 100644 --- a/src/starkware/starknet/services/api/feeder_gateway/CMakeLists.txt +++ b/src/starkware/starknet/services/api/feeder_gateway/CMakeLists.txt @@ -7,6 +7,7 @@ python_lib(starknet_feeder_gateway_client_lib LIBS everest_feeder_gateway_client_lib starknet_definitions_lib + starknet_feeder_gateway_response_objects_lib starknet_transaction_lib starkware_dataclasses_utils_lib pip_typing_extensions @@ -27,3 +28,26 @@ python_lib(starknet_block_hash_lib starkware_storage_utils_lib starkware_utils_lib ) + +python_lib(starknet_feeder_gateway_response_objects_lib + PREFIX starkware/starknet/services/api/feeder_gateway + + FILES + response_objects.py + + LIBS + cairo_vm_lib + everest_definitions_lib + everest_feeder_gateway_response_objects_lib + everest_transaction_execution_objects_lib + starknet_contract_definition_lib + starknet_definitions_lib + starknet_internal_transaction_lib + starknet_transaction_execution_objects_lib + starkware_dataclasses_utils_lib + pip_marshmallow + pip_marshmallow_dataclass + pip_marshmallow_enum + pip_marshmallow_oneofschema + pip_typing_extensions +) diff --git a/src/starkware/starknet/services/api/feeder_gateway/feeder_gateway_client.py b/src/starkware/starknet/services/api/feeder_gateway/feeder_gateway_client.py index 20309229..1c46effe 100644 --- a/src/starkware/starknet/services/api/feeder_gateway/feeder_gateway_client.py +++ b/src/starkware/starknet/services/api/feeder_gateway/feeder_gateway_client.py @@ -5,6 +5,11 @@ from services.everest.api.feeder_gateway.feeder_gateway_client import EverestFeederGatewayClient from starkware.starknet.definitions import fields +from starkware.starknet.services.api.feeder_gateway.response_objects import ( + StarknetBlock, + TransactionInfo, + TransactionReceipt, +) from starkware.starknet.services.api.gateway.transaction import InvokeFunction from starkware.starkware_utils.validated_fields import RangeValidatedField @@ -42,7 +47,7 @@ async def get_block( self, block_hash: Optional[CastableToHash] = None, block_number: Optional[BlockIdentifier] = None, - ) -> JsonObject: + ) -> StarknetBlock: formatted_block_identifier = get_formatted_block_identifier( block_hash=block_hash, block_number=block_number ) @@ -50,6 +55,20 @@ async def get_block( send_method="GET", uri=f"/get_block?{formatted_block_identifier}", ) + return StarknetBlock.loads(raw_response) + + async def get_state_update( + self, + block_hash: Optional[CastableToHash] = None, + block_number: Optional[BlockIdentifier] = None, + ) -> JsonObject: + formatted_block_identifier = get_formatted_block_identifier( + block_hash=block_hash, block_number=block_number + ) + raw_response = await self._send_request( + send_method="GET", + uri=f"/get_state_update?{formatted_block_identifier}", + ) return json.loads(raw_response) async def get_code( @@ -67,6 +86,26 @@ async def get_code( ) return json.loads(raw_response) + async def get_full_contract( + self, + contract_address: int, + block_hash: Optional[CastableToHash] = None, + block_number: Optional[BlockIdentifier] = None, + ) -> JsonObject: + """ + Returns the contract definition deployed under the given address. + A plain JSON is returned, rather than the Python object, to save loading time. + """ + formatted_block_identifier = get_formatted_block_identifier( + block_hash=block_hash, block_number=block_number + ) + uri = ( + f"/get_full_contract?contractAddress={hex(contract_address)}&" + f"{formatted_block_identifier}" + ) + raw_response = await self._send_request(send_method="GET", uri=uri) + return json.loads(raw_response) + async def get_storage_at( self, contract_address: int, @@ -95,20 +134,20 @@ async def get_transaction_status( async def get_transaction( self, tx_hash: Optional[CastableToHash], tx_id: Optional[int] = None - ) -> JsonObject: + ) -> TransactionInfo: raw_response = await self._send_request( send_method="GET", uri=f"/get_transaction?{tx_identifier(tx_hash=tx_hash, tx_id=tx_id)}" ) - return json.loads(raw_response) + return TransactionInfo.loads(raw_response) async def get_transaction_receipt( self, tx_hash: Optional[CastableToHash], tx_id: Optional[int] = None - ) -> JsonObject: + ) -> TransactionReceipt: raw_response = await self._send_request( send_method="GET", uri=f"/get_transaction_receipt?{tx_identifier(tx_hash=tx_hash, tx_id=tx_id)}", ) - return json.loads(raw_response) + return TransactionReceipt.loads(raw_response) async def get_block_hash_by_id(self, block_id: int) -> str: raw_response = await self._send_request( diff --git a/src/starkware/starknet/services/api/feeder_gateway/response_objects.py b/src/starkware/starknet/services/api/feeder_gateway/response_objects.py new file mode 100644 index 00000000..f22684e0 --- /dev/null +++ b/src/starkware/starknet/services/api/feeder_gateway/response_objects.py @@ -0,0 +1,527 @@ +import dataclasses +from dataclasses import field +from enum import Enum, auto +from typing import ClassVar, Dict, Iterable, List, Optional, Tuple, Type, Union + +import marshmallow +import marshmallow.exceptions +import marshmallow.fields as mfields +import marshmallow.utils +import marshmallow_dataclass +from marshmallow_oneofschema import OneOfSchema +from typing_extensions import Literal + +from services.everest.api.feeder_gateway.response_objects import BaseResponseObject +from services.everest.business_logic.transaction_execution_objects import TransactionFailureReason +from services.everest.definitions import fields as everest_fields +from starkware.cairo.lang.vm.cairo_pie import ExecutionResources +from starkware.starknet.business_logic.internal_transaction import ( + InternalDeploy, + InternalInvokeFunction, + InternalTransaction, +) +from starkware.starknet.business_logic.transaction_execution_objects import Event +from starkware.starknet.definitions import fields +from starkware.starknet.definitions.transaction_type import TransactionType +from starkware.starknet.services.api.contract_definition import EntryPointType +from starkware.starkware_utils.marshmallow_dataclass_fields import VariadicLengthTupleField +from starkware.starkware_utils.validated_dataclass import ValidatedMarshmallowDataclass + +BlockIdentifier = Union[int, Literal["pending"]] +OptionalBlockIdentifier = Optional[BlockIdentifier] + + +class BlockStatus(Enum): + # A pending block; i.e., a block that is yet to be closed. + PENDING = 0 + # An aborted block (failed in the L2 pipeline). + ABORTED = auto() + # A reverted block (rejected on L1). + REVERTED = auto() + # A block that was created on L2, in contrast to PENDING, which is not yet closed. + ACCEPTED_ON_L2 = auto() + # A block accepted on L1. + ACCEPTED_ON_L1 = auto() + + +class TransactionStatus(Enum): + # The transaction has not been received yet (i.e., not written to storage). + NOT_RECEIVED = 0 + # The transaction was received by the sequencer. + RECEIVED = auto() + # The transaction passed the validation and entered the pending block. + PENDING = auto() + # The transaction failed validation and thus was skipped (applies both to a pending and an + # actual created block). + REJECTED = auto() + # The transaction passed the validation and entered an actual created block. + ACCEPTED_ON_L2 = auto() + # The transaction was accepted on-chain. + ACCEPTED_ON_L1 = auto() + + @property + def has_receipt(self) -> bool: + """ + Returns whether a transaction with that status has a receipt (i.e., has been executed + successfully). + """ + return self in ( + TransactionStatus.PENDING, + TransactionStatus.ACCEPTED_ON_L2, + TransactionStatus.ACCEPTED_ON_L1, + ) + + @classmethod + def from_block_status(cls, block_status: BlockStatus) -> "TransactionStatus": + """ + Returns a transaction status according to the status of a block containing it. + """ + if block_status in ( + BlockStatus.PENDING, + BlockStatus.ACCEPTED_ON_L2, + BlockStatus.ACCEPTED_ON_L1, + ): + # The statuses above are identical for a block and a transaction. + return TransactionStatus[block_status.name] + elif block_status in (BlockStatus.REVERTED, BlockStatus.ABORTED): + # The transaction passed Batcher validations, but the block containing it failed on + # L1 or L2. Hence, it is yet again waiting to be inserted to a new block. + return TransactionStatus.RECEIVED + + raise NotImplementedError(f"Handling block status {block_status.name} is not implemented.") + + def __ge__(self, other: object) -> bool: + if not isinstance(other, TransactionStatus): + return NotImplemented + + self_not_comparable, other_not_comparable = ( + status not in tx_status_order_relation.keys() for status in (self, other) + ) + if self_not_comparable or other_not_comparable: + raise NotImplementedError( + f"Comparison is not supported between status {self.name} and {other.name}." + ) + + return tx_status_order_relation[self] >= tx_status_order_relation[other] + + def __lt__(self, other: object) -> bool: + return not self >= other + + +# Dictionary that represents the TransactionStatus valid flows. +# [NOT_RECEIVED] -> [RECEIVED] -> [PENDING] -> [ACCEPTED_ON_L2] -> [ACCEPTED_ON_L1]. +# REJECTED is excluded from the relation since the status of a REJECTED transaction will not +# become ACCEPTED_ON_L2. +tx_status_order_relation: Dict[TransactionStatus, int] = { + TransactionStatus.NOT_RECEIVED: 0, + TransactionStatus.RECEIVED: 1, + TransactionStatus.PENDING: 2, + TransactionStatus.ACCEPTED_ON_L2: 3, + TransactionStatus.ACCEPTED_ON_L1: 4, +} + + +@marshmallow_dataclass.dataclass(frozen=True) +class TransactionInBlockInfo(BaseResponseObject): + """ + Represents the information regarding a StarkNet transaction that appears in a block. + """ + + # The status of a transaction, see TransactionStatus. + status: Optional[TransactionStatus] + # The reason for the transaction failure, if applicable. + transaction_failure_reason: Optional[TransactionFailureReason] + # The unique identifier of the block on the active chain containing the transaction. + block_hash: Optional[int] = field(metadata=fields.optional_block_hash_metadata) + # The sequence number of the block corresponding to block_hash, which is the number of blocks + # prior to it in the active chain. + block_number: Optional[int] = field(metadata=fields.default_optional_block_number_metadata) + # The index of the transaction within the block corresponding to block_hash. + transaction_index: Optional[int] = field( + metadata=fields.default_optional_transaction_index_metadata + ) + + def __post_init__(self): + super().__post_init__() + + # Validate NOT_RECEIVED/nonexistent status matches missing execution fields. + execution_fields = ( + self.block_hash, + self.block_number, + self.transaction_index, + self.transaction_failure_reason, + ) + if self.status is None or self.status in ( + TransactionStatus.NOT_RECEIVED, + TransactionStatus.RECEIVED, + ): + assert all(field is None for field in execution_fields), ( + "Transaction execution fields (block hash, block number, index in block, etc.) " + "must not appear in a transaction that is not yet in a block, " + "or when status is None. " + f"Status: {self.status}. Execution fields: block_hash: {self.block_hash}, " + f"block_number: {self.block_number}, transaction_index: {self.transaction_index}, " + f"transaction_failure_reason: {self.transaction_failure_reason}." + ) + + return + + # Validate REJECTED status matches existing failure reason field. + tx_rejected = self.status is TransactionStatus.REJECTED + has_failure_info = self.transaction_failure_reason is not None + assert ( + tx_rejected == has_failure_info + ), "A rejected transaction must contain failure information, and vice versa." + + # Validate PENDING status matches missing created block fields. + if self.status is TransactionStatus.PENDING: + assert ( + self.block_hash is None and self.block_number is None + ), "Block hash and block number must not appear in a pending transaction." + + return + + # Validate ACCEPTED_ON_L1/2 status matches existing missing created block fields. + minimal_remaining_status = TransactionStatus.ACCEPTED_ON_L2 + assert tx_rejected or self.status >= minimal_remaining_status, ( + f"Unexpected transaction status: {self.status}; expected status to be at least " + f"{minimal_remaining_status.name}." + ) + if not tx_rejected: + assert all( + field is not None + for field in (self.block_hash, self.block_number, self.transaction_index) + ), ( + "Block hash, block number and transaction index in block must appear in an " + "accepted transaction." + ) + + +@marshmallow_dataclass.dataclass(frozen=True) +class TransactionSpecificInfo(ValidatedMarshmallowDataclass): + tx_type: ClassVar[TransactionType] + + @classmethod + def from_internal(cls, internal_tx: InternalTransaction) -> "TransactionSpecificInfo": + if isinstance(internal_tx, InternalDeploy): + return DeploySpecificInfo.from_internal_deploy(internal_tx=internal_tx) + elif isinstance(internal_tx, InternalInvokeFunction): + return InvokeSpecificInfo.from_internal_invoke(internal_tx=internal_tx) + else: + raise NotImplementedError(f"No response object for {internal_tx}.") + + +@marshmallow_dataclass.dataclass(frozen=True) +class DeploySpecificInfo(TransactionSpecificInfo): + contract_address: int = field(metadata=fields.contract_address_metadata) + contract_address_salt: int = field(metadata=fields.contract_address_salt_metadata) + constructor_calldata: List[int] = field(metadata=fields.call_data_metadata) + transaction_hash: int = field(metadata=fields.transaction_hash_metadata) + tx_type: ClassVar[TransactionType] = TransactionType.DEPLOY + + @classmethod + def from_internal_deploy(cls, internal_tx: InternalDeploy) -> "DeploySpecificInfo": + return cls( + contract_address=internal_tx.contract_address, + contract_address_salt=internal_tx.contract_address_salt, + constructor_calldata=internal_tx.constructor_calldata, + transaction_hash=internal_tx.hash_value, + ) + + +@marshmallow_dataclass.dataclass(frozen=True) +class InvokeSpecificInfo(TransactionSpecificInfo): + contract_address: int = field(metadata=fields.contract_address_metadata) + entry_point_selector: int = field(metadata=fields.entry_point_selector_metadata) + entry_point_type: EntryPointType + calldata: List[int] = field(metadata=fields.call_data_metadata) + signature: List[int] = field(metadata=fields.signature_metadata) + transaction_hash: int = field(metadata=fields.transaction_hash_metadata) + tx_type: ClassVar[TransactionType] = TransactionType.INVOKE_FUNCTION + + @classmethod + def from_internal_invoke(cls, internal_tx: InternalInvokeFunction) -> "InvokeSpecificInfo": + return cls( + contract_address=internal_tx.contract_address, + entry_point_selector=internal_tx.entry_point_selector, + entry_point_type=internal_tx.entry_point_type, + calldata=internal_tx.calldata, + signature=internal_tx.signature, + transaction_hash=internal_tx.hash_value, + ) + + +class TransactionSpecificInfoSchema(OneOfSchema): + type_schemas: Dict[str, Type[marshmallow.Schema]] = { + TransactionType.DEPLOY.name: DeploySpecificInfo.Schema, + TransactionType.INVOKE_FUNCTION.name: InvokeSpecificInfo.Schema, + } + + def get_obj_type(self, obj: TransactionSpecificInfo) -> str: + return obj.tx_type.name + + +TransactionSpecificInfo.Schema = TransactionSpecificInfoSchema + + +@marshmallow_dataclass.dataclass(frozen=True) +class TransactionInfo(TransactionInBlockInfo): + """ + Represents the information regarding a StarkNet transaction. + """ + + transaction: Optional[TransactionSpecificInfo] + + @classmethod + def create( + cls, + status: Optional[TransactionStatus], + transaction: Optional[InternalTransaction] = None, + transaction_failure_reason: Optional[TransactionFailureReason] = None, + block_hash: Optional[int] = None, + block_number: Optional[int] = None, + transaction_index: Optional[int] = None, + ) -> "TransactionInfo": + return cls( + transaction=None + if transaction is None + else TransactionSpecificInfo.from_internal(internal_tx=transaction), + status=status, + transaction_failure_reason=transaction_failure_reason, + block_hash=block_hash, + block_number=block_number, + transaction_index=transaction_index, + ) + + def __post_init__(self): + super().__post_init__() + + if self.status is not None and self.status is not TransactionStatus.NOT_RECEIVED: + assert ( + self.transaction is not None + ), "A received transaction must be included in TransactionInfo object." + + +@dataclasses.dataclass(frozen=True) +class L1ToL2Message(BaseResponseObject): + """ + Represents a StarkNet L1-to-L2 message. + """ + + from_address: str = field(metadata=everest_fields.EthAddressField.metadata("from_address")) + to_address: int = field(metadata=fields.contract_address_metadata) + selector: int = field(metadata=fields.entry_point_selector_metadata) + payload: List[int] = field(metadata=fields.felt_list_metadata) + nonce: Optional[int] = field(metadata=fields.optional_nonce_metadata) + + +@dataclasses.dataclass(frozen=True) +class L2ToL1Message(BaseResponseObject): + """ + Represents a StarkNet L2-to-L1 message. + """ + + from_address: int = field(metadata=fields.contract_address_metadata) + to_address: str = field(metadata=everest_fields.EthAddressField.metadata("to_address")) + payload: List[int] = field(metadata=fields.felt_list_metadata) + + +@marshmallow_dataclass.dataclass(frozen=True) +class TransactionExecution(BaseResponseObject): + """ + Represents a receipt of an executed transaction. + """ + + # The index of the transaction within the block. + transaction_index: Optional[int] = field( + metadata=fields.default_optional_transaction_index_metadata + ) + # A unique identifier of the transaction. + transaction_hash: Optional[int] = field(metadata=fields.optional_transaction_hash_metadata) + # L1-to-L2 messages. + l1_to_l2_consumed_message: Optional[L1ToL2Message] + # L2-to-L1 messages. + l2_to_l1_messages: List[L2ToL1Message] + # Events emitted during the execution of the transaction. + events: List[Event] + # The resources needed by the transaction. + execution_resources: Optional[ExecutionResources] + + +@marshmallow_dataclass.dataclass(frozen=True) +class TransactionReceipt(TransactionExecution, TransactionInBlockInfo): + """ + Represents a receipt of a StarkNet transaction; + i.e., the information regarding its execution and the block it appears in. + """ + + def __post_init__(self): + super().__post_init__() + + if self.status is TransactionStatus.REJECTED and self.has_execution_info: + raise AssertionError("A rejected transaction cannot have execution info.") + + assert self.transaction_hash is not None, "A receipt must include a transaction_hash." + + @property + def has_execution_info(self) -> bool: + """ + Returns whether the transaction has execution info. + """ + return ( + self.l1_to_l2_consumed_message is not None + or self.execution_resources is not None + or len(self.l2_to_l1_messages) > 0 + or len(self.events) > 0 + ) + + @classmethod + def from_tx_info( + cls, + transaction_hash: int, + tx_info: TransactionInBlockInfo, + l1_to_l2_consumed_message: Optional[L1ToL2Message] = None, + l2_to_l1_messages: Optional[List[L2ToL1Message]] = None, + events: Optional[List[Event]] = None, + execution_resources: Optional[ExecutionResources] = None, + ) -> "TransactionReceipt": + return cls( + l1_to_l2_consumed_message=l1_to_l2_consumed_message, + l2_to_l1_messages=[] if l2_to_l1_messages is None else l2_to_l1_messages, + events=[] if events is None else events, + execution_resources=execution_resources, + transaction_hash=transaction_hash, + status=tx_info.status, + transaction_failure_reason=tx_info.transaction_failure_reason, + block_hash=tx_info.block_hash, + block_number=tx_info.block_number, + transaction_index=tx_info.transaction_index, + ) + + +@marshmallow_dataclass.dataclass(frozen=True) +class StorageEntry(BaseResponseObject): + """ + Represents a value stored in a single contract storage entry. + """ + + key: int = field(metadata=everest_fields.felt_metadata(name_in_error_message="key")) + value: int = field(metadata=everest_fields.felt_metadata(name_in_error_message="value")) + + +@marshmallow_dataclass.dataclass(frozen=True) +class DeployedContract(BaseResponseObject): + """ + Represents a newly deployed contract in a block state update. + """ + + address: int = field(metadata=fields.contract_address_metadata) + contract_hash: bytes = field(metadata=fields.contract_hash_metadata) + + +@marshmallow_dataclass.dataclass(frozen=True) +class StateDiff(BaseResponseObject): + """ + Represents the difference in the StarkNet state induced by applying a block's transactions. + """ + + storage_diffs: Dict[int, List[StorageEntry]] = field( + metadata=dict( + marshmallow_field=mfields.Dict( + keys=fields.ContractAddressField.get_marshmallow_field( + required=True, + load_default=marshmallow.utils.missing, + ), + values=mfields.List(mfields.Nested(StorageEntry.Schema)), + ) + ) + ) + + deployed_contracts: List[DeployedContract] + + +@marshmallow_dataclass.dataclass(frozen=True) +class BlockStateUpdate(BaseResponseObject): + """ + Represents a response block state update. + """ + + block_hash: Optional[int] = field(metadata=fields.optional_block_hash_metadata) + new_root: bytes = field(metadata=fields.state_root_metadata) + old_root: bytes = field(metadata=fields.state_root_metadata) + state_diff: StateDiff + + +@marshmallow_dataclass.dataclass(frozen=True) +class StarknetBlock(BaseResponseObject): + """ + Represents a response StarkNet block. + """ + + block_hash: Optional[int] = field(metadata=fields.optional_block_hash_metadata) + parent_block_hash: int = field(metadata=fields.block_hash_metadata) + block_number: Optional[int] = field(metadata=fields.default_optional_block_number_metadata) + state_root: Optional[bytes] = field(metadata=fields.optional_state_root_metadata) + status: Optional[BlockStatus] + transactions: Tuple[TransactionSpecificInfo, ...] = field( + metadata=dict( + marshmallow_field=VariadicLengthTupleField( + mfields.Nested(TransactionSpecificInfo.Schema) + ) + ) + ) + timestamp: int = field(metadata=fields.timestamp_metadata) + transaction_receipts: Tuple[TransactionExecution, ...] = field( + metadata=dict( + marshmallow_field=VariadicLengthTupleField(mfields.Nested(TransactionExecution.Schema)) + ) + ) + + @classmethod + def create( + cls, + block_hash: Optional[int], + parent_block_hash: int, + block_number: Optional[int], + state_root: Optional[bytes], + transactions: Iterable[InternalTransaction], + timestamp: int, + transaction_receipts: Tuple[TransactionExecution, ...], + status: Optional[BlockStatus], + ) -> "StarknetBlock": + return cls( + block_hash=block_hash, + parent_block_hash=parent_block_hash, + block_number=block_number, + state_root=state_root, + transactions=tuple( + TransactionSpecificInfo.from_internal(internal_tx=tx) for tx in transactions + ), + timestamp=timestamp, + transaction_receipts=transaction_receipts, + status=status, + ) + + def __post_init__(self): + super().__post_init__() + + tx_status_error_message = ( + "Transactions' status in block must match the status of the block." + ) + if self.status is None: + assert all( + tx_receipt.status is None for tx_receipt in self.transaction_receipts + ), tx_status_error_message + + return + + # Validate PENDING status matches missing created block fields. + created_block_fields = (self.block_hash, self.block_number, self.state_root) + if self.status is BlockStatus.PENDING: + assert all( + field is None for field in created_block_fields + ), "Block hash, block number, state_root must not appear in a pending block." + else: + assert all( + field is not None for field in created_block_fields + ), "Block hash, block number, state_root must appear in a created block." diff --git a/src/starkware/starknet/services/api/gateway/CMakeLists.txt b/src/starkware/starknet/services/api/gateway/CMakeLists.txt index 2e3fd124..f50635c6 100644 --- a/src/starkware/starknet/services/api/gateway/CMakeLists.txt +++ b/src/starkware/starknet/services/api/gateway/CMakeLists.txt @@ -1,17 +1,3 @@ -python_lib(starknet_transaction_hash_lib - PREFIX starkware/starknet/services/api/gateway - - FILES - transaction_hash.py - - LIBS - cairo_common_lib - cairo_vm_crypto_lib - starknet_contract_definition_lib - starknet_definitions_lib - starkware_python_utils_lib -) - python_lib(starknet_transaction_lib PREFIX starkware/starknet/services/api/gateway diff --git a/src/starkware/starknet/services/api/gateway/contract_address.py b/src/starkware/starknet/services/api/gateway/contract_address.py index a6f0627f..630e187b 100644 --- a/src/starkware/starknet/services/api/gateway/contract_address.py +++ b/src/starkware/starknet/services/api/gateway/contract_address.py @@ -27,6 +27,26 @@ def calculate_contract_address( contract_hash = compute_contract_hash( contract_definition=contract_definition, hash_func=hash_function ) + return calculate_contract_address_from_hash( + salt=salt, + contract_hash=contract_hash, + constructor_calldata=constructor_calldata, + caller_address=caller_address, + hash_function=hash_function, + ) + + +def calculate_contract_address_from_hash( + salt: int, + contract_hash: int, + constructor_calldata: Sequence[int], + caller_address: int, + hash_function: Callable[[int, int], int] = pedersen_hash, +) -> int: + """ + Same as calculate_contract_address(), except that it gets contract_hash instead of + contract_definition. + """ constructor_calldata_hash = compute_hash_on_elements( data=constructor_calldata, hash_func=hash_function ) diff --git a/src/starkware/starknet/services/api/gateway/transaction.py b/src/starkware/starknet/services/api/gateway/transaction.py index c3baf287..af409de3 100644 --- a/src/starkware/starknet/services/api/gateway/transaction.py +++ b/src/starkware/starknet/services/api/gateway/transaction.py @@ -11,17 +11,17 @@ from marshmallow_oneofschema import OneOfSchema from services.everest.api.gateway.transaction import EverestTransaction +from starkware.starknet.core.os.transaction_hash import ( + TransactionHashPrefix, + calculate_deploy_transaction_hash, + calculate_transaction_hash_common, +) from starkware.starknet.definitions import fields from starkware.starknet.definitions.error_codes import StarknetErrorCode from starkware.starknet.definitions.general_config import StarknetGeneralConfig from starkware.starknet.definitions.transaction_type import TransactionType from starkware.starknet.services.api.contract_definition import ContractDefinition from starkware.starknet.services.api.gateway.contract_address import calculate_contract_address -from starkware.starknet.services.api.gateway.transaction_hash import ( - TransactionHashPrefix, - calculate_deploy_transaction_hash, - calculate_transaction_hash_common, -) from starkware.starkware_utils.error_handling import wrap_with_stark_exception diff --git a/src/starkware/starknet/services/api/messages.py b/src/starkware/starknet/services/api/messages.py index dc48ae06..1215a631 100644 --- a/src/starkware/starknet/services/api/messages.py +++ b/src/starkware/starknet/services/api/messages.py @@ -1,31 +1,36 @@ import dataclasses +from abc import ABC, abstractmethod from dataclasses import field -from enum import Enum, auto from typing import List from services.everest.definitions import fields as everest_fields from starkware.cairo.bootloader.compute_fact import keccak_ints from starkware.starknet.business_logic.internal_transaction import InternalInvokeFunction from starkware.starknet.definitions import fields -from starkware.starknet.definitions.error_codes import StarknetErrorCode from starkware.starknet.services.api.contract_definition import EntryPointType -from starkware.starkware_utils.error_handling import stark_assert from starkware.starkware_utils.validated_dataclass import ValidatedDataclass -class MessageType(Enum): - L1_TO_L2 = 0 - L2_TO_L1 = auto() +class StarknetMessage(ABC, ValidatedDataclass): + """ + Abstract base class for StarkNet Messages. + """ + + @abstractmethod + def encode(self) -> List[int]: + """ + Encodes the message as it would appear in the output of the StarkNet OS. + """ @dataclasses.dataclass(frozen=True) -class StarknetMessageToL1(ValidatedDataclass): +class StarknetMessageToL1(StarknetMessage): """ A StarkNet Message from L2 to L1. """ from_address: int = field( - metadata=everest_fields.felt_metadata(name_in_error_message="from_address") + metadata=fields.ContractAddressField.metadata(field_name="from_address") ) to_address: int = field( metadata=everest_fields.EthAddressIntField.metadata(field_name="to_address") @@ -40,7 +45,7 @@ def get_hash(self) -> str: @dataclasses.dataclass(frozen=True) -class StarknetMessageToL2(ValidatedDataclass): +class StarknetMessageToL2(StarknetMessage): """ A StarkNet Message from L1 to L2. """ @@ -48,9 +53,7 @@ class StarknetMessageToL2(ValidatedDataclass): from_address: int = field( metadata=everest_fields.EthAddressIntField.metadata(field_name="from_address") ) - to_address: int = field( - metadata=everest_fields.felt_metadata(name_in_error_message="to_address") - ) + to_address: int = field(metadata=fields.ContractAddressField.metadata(field_name="to_address")) l1_handler_selector: int payload: List[int] = field(metadata=fields.felt_list_metadata) nonce: int = field(metadata=everest_fields.felt_metadata(name_in_error_message="nonce")) @@ -74,12 +77,6 @@ def get_message_hash_from_tx(tx: InternalInvokeFunction) -> str: tx.entry_point_type is EntryPointType.L1_HANDLER ), f"Transaction must be of type L1_HANDLER. Got: {tx.entry_point_type.name}." - stark_assert( - tx.nonce is not None, - code=StarknetErrorCode.UNEXPECTED_FAILURE, - message="L1 handlers must include a nonce.", - ) - assert tx.nonce is not None, "L1 handlers must include a nonce." return StarknetMessageToL2( diff --git a/src/starkware/starknet/testing/CMakeLists.txt b/src/starkware/starknet/testing/CMakeLists.txt index d473e74e..080a9ea9 100644 --- a/src/starkware/starknet/testing/CMakeLists.txt +++ b/src/starkware/starknet/testing/CMakeLists.txt @@ -54,8 +54,10 @@ python_lib(starknet_testing_lib starknet_internal_transaction_lib starknet_messages_lib starknet_mock_messaging_contracts_lib + starknet_transaction_execution_objects_lib starknet_transaction_lib starkware_dataclasses_utils_lib + starkware_eth_test_utils_lib starkware_python_utils_lib starkware_storage_lib pip_typeguard diff --git a/src/starkware/starknet/testing/conftest.py b/src/starkware/starknet/testing/conftest.py index 184b0a81..2d13cc8e 100644 --- a/src/starkware/starknet/testing/conftest.py +++ b/src/starkware/starknet/testing/conftest.py @@ -1,9 +1,11 @@ +from typing import Iterator + import pytest from starkware.eth.eth_test_utils import EthTestUtils @pytest.fixture(scope="session") -def eth_test_utils() -> EthTestUtils: +def eth_test_utils() -> Iterator[EthTestUtils]: with EthTestUtils.context_manager() as val: yield val diff --git a/src/starkware/starknet/testing/starknet.py b/src/starkware/starknet/testing/starknet.py index a5c642a3..e71f92a6 100644 --- a/src/starkware/starknet/testing/starknet.py +++ b/src/starkware/starknet/testing/starknet.py @@ -1,5 +1,6 @@ from typing import List, Optional, Union +from starkware.starknet.business_logic.transaction_execution_objects import TransactionExecutionInfo from starkware.starknet.compiler.compile import compile_starknet_files from starkware.starknet.definitions.general_config import StarknetGeneralConfig from starkware.starknet.services.api.contract_definition import ContractDefinition, EntryPointType @@ -85,7 +86,7 @@ async def send_message_to_l2( selector: Union[int, str], payload: List[int], nonce: Optional[int] = None, - ): + ) -> TransactionExecutionInfo: """ Mocks the L1 contract function sendMessageToL2. @@ -96,7 +97,7 @@ async def send_message_to_l2( nonce = self.l1_to_l2_nonce self.l1_to_l2_nonce += 1 - await self.state.invoke_raw( + return await self.state.invoke_raw( contract_address=to_address, selector=selector, calldata=[from_address, *payload], diff --git a/src/starkware/starknet/testing/state.py b/src/starkware/starknet/testing/state.py index 9846c321..404acc1b 100644 --- a/src/starkware/starknet/testing/state.py +++ b/src/starkware/starknet/testing/state.py @@ -1,5 +1,5 @@ import copy -from typing import Dict, List, Optional, Tuple, Union, cast +from typing import Dict, List, Optional, Tuple, Union from starkware.cairo.lang.vm.crypto import pedersen_hash_func from starkware.starknet.business_logic.internal_transaction import ( @@ -15,7 +15,6 @@ from starkware.starknet.definitions.general_config import StarknetGeneralConfig from starkware.starknet.public.abi import get_selector_from_name from starkware.starknet.services.api.contract_definition import ContractDefinition, EntryPointType -from starkware.starknet.services.api.gateway.transaction import Deploy from starkware.starknet.services.api.messages import StarknetMessageToL1 from starkware.storage.dict_storage import DictStorage from starkware.storage.storage import FactFetchingContext @@ -65,7 +64,7 @@ async def empty(cls, general_config: Optional[StarknetGeneralConfig] = None) -> general_config = StarknetGeneralConfig() ffc = FactFetchingContext(storage=DictStorage(), hash_func=pedersen_hash_func) - state = await CarriedState.create_empty_for_test( + state = await CarriedState.empty_for_testing( shared_state=None, ffc=ffc, general_config=general_config ) @@ -92,16 +91,12 @@ async def deploy( contract_address_salt = int(contract_address_salt, 16) assert isinstance(contract_address_salt, int) - external_tx = Deploy( - contract_address_salt=contract_address_salt, + tx = await InternalDeploy.create_for_testing( + ffc=self.state.ffc, contract_definition=contract_definition, + contract_address_salt=contract_address_salt, constructor_calldata=constructor_calldata, - ) - tx = cast( - InternalDeploy, - InternalDeploy.from_external( - external_tx=external_tx, general_config=self.general_config - ), + chain_id=self.general_config.chain_id.value, ) with self.state.copy_and_apply() as state_copy: diff --git a/src/starkware/starknet/testing/test.cairo b/src/starkware/starknet/testing/test.cairo index 86d593d3..cede49ee 100644 --- a/src/starkware/starknet/testing/test.cairo +++ b/src/starkware/starknet/testing/test.cairo @@ -5,7 +5,7 @@ from starkware.cairo.common.cairo_builtins import HashBuiltin from starkware.cairo.common.registers import get_fp_and_pc from starkware.starknet.common.messages import send_message_to_l1 from starkware.starknet.common.syscalls import ( - get_caller_address, get_tx_signature, storage_read, storage_write) + get_caller_address, get_tx_info, storage_read, storage_write) @contract_interface namespace MyContract: @@ -46,8 +46,8 @@ end @external func get_signature{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*}() -> ( res_len : felt, res : felt*): - let (sig_len, sig) = get_tx_signature() - return (res_len=sig_len, res=sig) + let (tx_info) = get_tx_info() + return (res_len=tx_info.signature_len, res=tx_info.signature) end @external diff --git a/src/starkware/starknet/wallets/CMakeLists.txt b/src/starkware/starknet/wallets/CMakeLists.txt index 0c545519..084ddd1b 100644 --- a/src/starkware/starknet/wallets/CMakeLists.txt +++ b/src/starkware/starknet/wallets/CMakeLists.txt @@ -18,7 +18,12 @@ python_lib(starknet_standard_wallets_lib open_zeppelin.py LIBS + cairo_common_lib open_zeppelin_contracts_lib + starknet_abi_lib + starknet_definitions_lib + starknet_transaction_lib starknet_wallets_lib starkware_crypto_lib + starkware_error_handling_lib ) diff --git a/src/starkware/starkware_utils/CMakeLists_common.txt b/src/starkware/starkware_utils/CMakeLists_common.txt index 5b3edbd6..55a7648f 100644 --- a/src/starkware/starkware_utils/CMakeLists_common.txt +++ b/src/starkware/starkware_utils/CMakeLists_common.txt @@ -33,6 +33,7 @@ python_lib(starkware_dataclasses_utils_lib FILES field_validators.py marshmallow_dataclass_fields.py + serializable_dataclass.py validated_dataclass.py validated_fields.py diff --git a/src/starkware/starkware_utils/commitment_tree/binary_fact_tree_node.py b/src/starkware/starkware_utils/commitment_tree/binary_fact_tree_node.py index 4f131465..6f6d8dcd 100644 --- a/src/starkware/starkware_utils/commitment_tree/binary_fact_tree_node.py +++ b/src/starkware/starkware_utils/commitment_tree/binary_fact_tree_node.py @@ -1,14 +1,17 @@ import asyncio from abc import ABC, abstractmethod -from typing import Collection, Dict, Optional, Tuple, Type, TypeVar +from typing import AsyncIterator, Collection, Dict, List, Optional, Tuple, Type, TypeVar from starkware.python.utils import from_bytes from starkware.starkware_utils.commitment_tree.binary_fact_tree import BinaryFactDict, TFact from starkware.starkware_utils.commitment_tree.inner_node_fact import InnerNodeFact +from starkware.starkware_utils.commitment_tree.merkle_tree.traverse_tree import traverse_tree from starkware.storage.storage import FactFetchingContext TInnerNodeFact = TypeVar("TInnerNodeFact", bound=InnerNodeFact) TBinaryFactTreeNode = TypeVar("TBinaryFactTreeNode", bound="BinaryFactTreeNode") +_BinaryFactTreeNodePair = Tuple[TBinaryFactTreeNode, TBinaryFactTreeNode] +_BinaryFactTreeDiff = Tuple[int, _BinaryFactTreeNodePair] class BinaryFactTreeNode(ABC): @@ -50,22 +53,6 @@ def _leaf_hash(self) -> bytes: def create_leaf(cls: Type[TBinaryFactTreeNode], hash_value: bytes) -> TBinaryFactTreeNode: pass - @classmethod - @abstractmethod - async def combine( - cls: Type[TBinaryFactTreeNode], - ffc: FactFetchingContext, - left: "BinaryFactTreeNode", - right: "BinaryFactTreeNode", - facts: Optional[BinaryFactDict] = None, - ) -> TBinaryFactTreeNode: - """ - Gets two BinaryFactTreeNode objects left and right representing children nodes, and builds - their parent node. Returns a new BinaryFactTreeNode. - - If facts argument is not None, this dictionary is filled with facts read from the DB. - """ - @abstractmethod async def get_children( self, ffc: FactFetchingContext, facts: Optional[BinaryFactDict] = None @@ -113,7 +100,7 @@ def unify_leaves( return {} if self.is_leaf: - assert set(indices) == {0}, f"Merkle tree indices out of range: {indices}." + assert set(indices) == {0}, f"Commitment tree indices out of range: {indices}." leaf = await fact_cls.get_or_fail(storage=ffc.storage, suffix=self.leaf_hash) return {0: leaf} @@ -145,6 +132,60 @@ def unify_leaves( return unify_leaves(left_leaves=left_leaves, right_leaves=right_leaves) + async def get_diff_between_trees( + self, + other: TBinaryFactTreeNode, + ffc: FactFetchingContext, + fact_cls: Type[TFact], + facts: Optional[BinaryFactDict] = None, + ) -> List[Tuple[int, TFact, TFact]]: + """ + Returns a list of (key, old_fact, new_fact) that are different + between this tree and another. + + The height of the two trees must be equel. + + If the 'facts' argument is not None, this dictionary is filled with facts read from the DB. + """ + assert self.get_height_in_tree() == other.get_height_in_tree(), ( + f"Tree heights must be equal. Got: {other.get_height_in_tree()} for 'other'; " + f"expected: {self.get_height_in_tree()}." + ) + result: List[Tuple[int, TFact, TFact]] = [] + + async def get_children_callback( + node: _BinaryFactTreeDiff, + ) -> AsyncIterator[_BinaryFactTreeDiff]: + path, (previous, current) = node + if previous.is_leaf: + result.append( + ( + path, + await fact_cls.get_or_fail(suffix=previous.leaf_hash, storage=ffc.storage), + await fact_cls.get_or_fail(suffix=current.leaf_hash, storage=ffc.storage), + ) + ) + return + + previous_left, previous_right = await previous.get_children(ffc=ffc, facts=facts) + current_left, current_right = await current.get_children(ffc=ffc, facts=facts) + + if previous_left != current_left: + # Shift left for the left child. + yield (path << 1, (previous_left, current_left)) + + if previous_right != current_right: + # Shift left and turn on the LSB bit for the right child. + yield ((path << 1) + 1, (previous_right, current_right)) + + await traverse_tree( + get_children_callback=get_children_callback, + root=(0, (self, other)), + n_workers=ffc.n_workers, + ) + + return result + async def read_node_fact( ffc: FactFetchingContext, diff --git a/src/starkware/starkware_utils/commitment_tree/calculation.py b/src/starkware/starkware_utils/commitment_tree/calculation.py index a9aeb35d..a1e90f90 100644 --- a/src/starkware/starkware_utils/commitment_tree/calculation.py +++ b/src/starkware/starkware_utils/commitment_tree/calculation.py @@ -10,13 +10,13 @@ TBinaryFactTreeNode, TInnerNodeFact, ) -from starkware.starkware_utils.validated_dataclass import ValidatedDataclass from starkware.storage.storage import FactFetchingContext, HashFunctionType T = TypeVar("T") TCalculationNode = TypeVar("TCalculationNode", bound="CalculationNode") NodeFactDict = Dict[bytes, TInnerNodeFact] + class Calculation(Generic[T], ABC): """ A calculation that can produce a result of type T. The calculation is dependent on the results @@ -140,6 +140,33 @@ async def full_calculate_with_executor( return result +# NOTE: We avoid using ValidatedDataclass here for performance. +@dataclasses.dataclass(frozen=True) +class ConstantCalculation(Calculation[T]): + """ + A calculation that contains a value and simply produces it. It doesn't depend on any other + calculations. + """ + + value: T + + def calculate( + self, + dependency_results: list, + hash_func: HashFunctionType, + fact_nodes: NodeFactDict, + ) -> T: + assert len(dependency_results) == 0, "ConstantCalculation has no dependencies." + return self.value + + def get_dependency_calculations(self) -> List[Calculation[T]]: + return [] + + +# A calculation that produces a hash result. +HashCalculation = Calculation[bytes] + + class CalculationNode(Calculation[TBinaryFactTreeNode], ABC): """ A calculation that produces a BinaryFactTreeNode. The calculation can be created from either a @@ -168,31 +195,3 @@ def create(cls: Type[TCalculationNode], node: TBinaryFactTreeNode) -> TCalculati Creates a Calculation object from a node. It will produce the node and will have no dependencies. """ - - -class HashCalculation(Calculation[bytes]): - """ - A calculation that produces a hash result. - """ - - -@dataclasses.dataclass(frozen=True) -class ConstantCalculation(HashCalculation, ValidatedDataclass): - """ - A calculation that contains a hash and simply produces it. It doesn't depend on any other - calculations. It constitutes a leaf calculation so that other calculations can depend on it. - """ - - hash_value: bytes - - def calculate( - self, - dependency_results: list, - hash_func: HashFunctionType, - fact_nodes: NodeFactDict, - ) -> bytes: - assert len(dependency_results) == 0, "ConstantCalculation has no dependencies." - return self.hash_value - - def get_dependency_calculations(self) -> List[Calculation[bytes]]: - return [] diff --git a/src/starkware/starkware_utils/commitment_tree/patricia_tree/nodes.py b/src/starkware/starkware_utils/commitment_tree/patricia_tree/nodes.py index c09bbebb..50c87866 100644 --- a/src/starkware/starkware_utils/commitment_tree/patricia_tree/nodes.py +++ b/src/starkware/starkware_utils/commitment_tree/patricia_tree/nodes.py @@ -4,11 +4,11 @@ from starkware.python.utils import blockify, from_bytes, to_bytes from starkware.starkware_utils.commitment_tree.inner_node_fact import InnerNodeFact -from starkware.starkware_utils.validated_dataclass import ValidatedDataclass from starkware.storage.storage import HASH_BYTES, HashFunctionType -class PatriciaNodeFact(InnerNodeFact, ValidatedDataclass): +# NOTE: We avoid using ValidatedDataclass here for performance. +class PatriciaNodeFact(InnerNodeFact): """ Base abstract class of Patricia-Merkle tree nodes. """ @@ -68,6 +68,7 @@ def verify_path_value(path: int, length: int): ), f"Edge path must be at most of length {length}; got: {bin(path)}." +# NOTE: We avoid using ValidatedDataclass here for performance. @dataclasses.dataclass(frozen=True) class BinaryNodeFact(PatriciaNodeFact): """ @@ -81,8 +82,6 @@ class BinaryNodeFact(PatriciaNodeFact): PREIMAGE_LENGTH: ClassVar[int] = 2 * HASH_BYTES def __post_init__(self): - super().__post_init__() - legal_binary_node = ( self.left_node != EmptyNodeFact.EMPTY_NODE_HASH and self.right_node != EmptyNodeFact.EMPTY_NODE_HASH @@ -110,6 +109,7 @@ def to_tuple(self) -> Tuple[int, ...]: return from_bytes(self.left_node), from_bytes(self.right_node) +# NOTE: We avoid using ValidatedDataclass here for performance. @dataclasses.dataclass(frozen=True) class EdgeNodeFact(PatriciaNodeFact): """ @@ -132,8 +132,6 @@ class EdgeNodeFact(PatriciaNodeFact): PREIMAGE_LENGTH: ClassVar[int] = 2 * HASH_BYTES + 1 def __post_init__(self): - super().__post_init__() - assert ( self.edge_length > 0 ), f"The length of an edge node must be positive; got: {self.edge_length}." diff --git a/src/starkware/starkware_utils/commitment_tree/patricia_tree/virtual_calculation_node.py b/src/starkware/starkware_utils/commitment_tree/patricia_tree/virtual_calculation_node.py index 6c0829cd..1409959d 100644 --- a/src/starkware/starkware_utils/commitment_tree/patricia_tree/virtual_calculation_node.py +++ b/src/starkware/starkware_utils/commitment_tree/patricia_tree/virtual_calculation_node.py @@ -20,7 +20,6 @@ from starkware.starkware_utils.commitment_tree.patricia_tree.virtual_patricia_node import ( VirtualPatriciaNode, ) -from starkware.starkware_utils.validated_dataclass import ValidatedDataclass from starkware.storage.storage import FactFetchingContext, HashFunctionType @@ -67,8 +66,9 @@ def calculate( return fact_hash +# NOTE: We avoid using ValidatedDataclass here for performance. @dataclasses.dataclass(frozen=True) -class VirtualCalculationNode(CalculationNode[VirtualPatriciaNode], ValidatedDataclass): +class VirtualCalculationNode(CalculationNode[VirtualPatriciaNode]): """ Represents a virtual calculation node. It consists a bottom calculation and a virtual edge that can be empty (and then the calculation just represents the bottom calculation). @@ -91,16 +91,15 @@ def __post_init__(self): Note that many of the functions in this class rely on the invariants checked in this function, and on the fact they are made at initialization time (the object is immutable). """ - super().__post_init__() - verify_path_value(path=self.path, length=self.length) @classmethod def create(cls, node: VirtualPatriciaNode): if node.is_empty: return cls.empty_node(height=node.height) + return cls( - bottom_calculation=ConstantCalculation(hash_value=node.bottom_node), + bottom_calculation=ConstantCalculation(value=node.bottom_node), path=node.path, length=node.length, height=node.height, @@ -117,7 +116,11 @@ def empty_node(cls, height: int) -> "VirtualCalculationNode": @property def is_empty(self) -> bool: - return self.bottom_calculation == ConstantCalculation(EmptyNodeFact.EMPTY_NODE_HASH) + # NOTE: we compare directly the values (instead of comparing objects) for performance. + if isinstance(self.bottom_calculation, ConstantCalculation): + return self.bottom_calculation.value == EmptyNodeFact.EMPTY_NODE_HASH + + return False @property def is_virtual_edge(self) -> bool: @@ -248,13 +251,13 @@ async def _decommit( bottom_fact = await read_node_fact( ffc=ffc, inner_node_fact_cls=PatriciaNodeFact, # type: ignore - fact_hash=self.bottom_calculation.hash_value, + fact_hash=self.bottom_calculation.value, facts=facts, ) if isinstance(bottom_fact, EdgeNodeFact): # Moving the edge of the fact into the virtual edge. return VirtualCalculationNode( - bottom_calculation=ConstantCalculation(hash_value=bottom_fact.bottom_node), + bottom_calculation=ConstantCalculation(value=bottom_fact.bottom_node), path=bottom_fact.edge_path, length=bottom_fact.edge_length, height=self.height, diff --git a/src/starkware/starkware_utils/commitment_tree/patricia_tree/virtual_patricia_node.py b/src/starkware/starkware_utils/commitment_tree/patricia_tree/virtual_patricia_node.py index f5f75816..d226dc36 100644 --- a/src/starkware/starkware_utils/commitment_tree/patricia_tree/virtual_patricia_node.py +++ b/src/starkware/starkware_utils/commitment_tree/patricia_tree/virtual_patricia_node.py @@ -1,4 +1,3 @@ -import asyncio import dataclasses from typing import Optional, Tuple @@ -15,12 +14,12 @@ PatriciaNodeFact, verify_path_value, ) -from starkware.starkware_utils.validated_dataclass import ValidatedDataclass from starkware.storage.storage import FactFetchingContext +# NOTE: We avoid using ValidatedDataclass here for performance. @dataclasses.dataclass(frozen=True) -class VirtualPatriciaNode(BinaryFactTreeNode, ValidatedDataclass): +class VirtualPatriciaNode(BinaryFactTreeNode): """ Represents a virtual Patricia node. Virtual node instances are used to build and traverse through a Patricia tree. @@ -43,8 +42,6 @@ def __post_init__(self): Note that many of the functions in this class rely on the invariants checked in this function, and on the fact they are made at initialization time (the object is immutable). """ - super().__post_init__() - verify_path_value(path=self.path, length=self.length) @classmethod @@ -101,66 +98,6 @@ async def commit(self, ffc: FactFetchingContext, facts: Optional[BinaryFactDict] ) return await write_node_fact(ffc=ffc, inner_node_fact=edge_node_fact, facts=facts) - async def decommit( - self, ffc: FactFetchingContext, facts: Optional[BinaryFactDict] - ) -> "VirtualPatriciaNode": - """ - Returns the canonical representation of the information embedded in self. - Returns (bottom, path, length) for an edge node of form (hash, 0, 0), which is the - canonical form. - """ - if self.is_leaf or self.is_empty or self.is_virtual_edge: - # Node is already decommitted (of canonical form); no work to be done. - return self - - # Need to read fact from storage to understand if (hash, 0, 0) represents a binary node, - # or a committed edge node. - # Note that a fact that was written in a previous combine while building this tree will - # appear in cache (in case the FFC's storage is cached). - root_node_fact = await self.read_bottom_node_fact(ffc=ffc, facts=facts) - - if isinstance(root_node_fact, BinaryNodeFact): - return self - if isinstance(root_node_fact, EdgeNodeFact): - return VirtualPatriciaNode( - bottom_node=root_node_fact.bottom_node, - path=root_node_fact.edge_path, - length=root_node_fact.edge_length, - height=self.height, - ) - - raise NotImplementedError(f"Unexpected node fact type: {type(root_node_fact).__name__}.") - - @classmethod - async def combine( - cls, - ffc: FactFetchingContext, - left: "BinaryFactTreeNode", - right: "BinaryFactTreeNode", - facts: Optional[BinaryFactDict] = None, - ) -> "VirtualPatriciaNode": - """ - Gets two VirtualPatriciaNode objects left and right representing children nodes, and builds - their parent node. Returns a new VirtualPatriciaNode. - - If facts argument is not None, this dictionary is filled with facts read from the DB. - """ - # Downcast arguments. - assert isinstance(left, VirtualPatriciaNode) and isinstance(right, VirtualPatriciaNode) - - assert ( - right.height == left.height - ), f"Only trees of same height can be combined; got: {right.height} and {left.height}." - - parent_height = right.height + 1 - if left.is_empty and right.is_empty: - return VirtualPatriciaNode.empty_node(height=parent_height) - - if not left.is_empty and not right.is_empty: - return await cls._combine_to_binary_node(ffc=ffc, left=left, right=right, facts=facts) - - return await cls._combine_to_virtual_edge_node(ffc=ffc, left=left, right=right, facts=facts) - async def get_children( self, ffc: FactFetchingContext, facts: Optional[BinaryFactDict] = None ) -> Tuple["VirtualPatriciaNode", "VirtualPatriciaNode"]: @@ -200,65 +137,6 @@ async def get_children( self.from_hash(hash_value=fact.right_node, height=children_height), ) - # Internal utils. - - @classmethod - async def _combine_to_binary_node( - cls, - ffc: FactFetchingContext, - left: "VirtualPatriciaNode", - right: "VirtualPatriciaNode", - facts: Optional[BinaryFactDict], - ) -> "VirtualPatriciaNode": - """ - Combines two non-empty nodes to form a binary node. - Writes the constructed node fact to the DB, as well as (up to) two other facts for the - children if they were not previously committed. - """ - left_node_hash, right_node_hash = await asyncio.gather( - *(node.commit(ffc=ffc, facts=facts) for node in (left, right)) - ) - parent_node_fact = BinaryNodeFact(left_node=left_node_hash, right_node=right_node_hash) - parent_fact_hash = await write_node_fact( - ffc=ffc, inner_node_fact=parent_node_fact, facts=facts - ) - - return VirtualPatriciaNode( - bottom_node=parent_fact_hash, path=0, length=0, height=right.height + 1 - ) - - @classmethod - async def _combine_to_virtual_edge_node( - cls, - ffc: FactFetchingContext, - left: "VirtualPatriciaNode", - right: "VirtualPatriciaNode", - facts: Optional[BinaryFactDict], - ) -> "VirtualPatriciaNode": - """ - Combines an empty node and a non-empty node to form a virtual edge node. - If the non-empty node is not known to be of canonical form, reads its fact from the DB - in order to make it such (or make sure it is). - """ - assert ( - left.is_empty != right.is_empty - ), "_combine_to_virtual_edge_node() must be called on one empty and one non-empty nodes." - - non_empty_child = right if left.is_empty else left - non_empty_child = await non_empty_child.decommit(ffc=ffc, facts=facts) - - parent_path = non_empty_child.path - if left.is_empty: - # Turn on the MSB bit if the non-empty child is on the right. - parent_path += 1 << non_empty_child.length - - return VirtualPatriciaNode( - bottom_node=non_empty_child.bottom_node, - path=parent_path, - length=non_empty_child.length + 1, - height=non_empty_child.height + 1, - ) - def _get_virtual_edge_node_children( self, ) -> Tuple["VirtualPatriciaNode", "VirtualPatriciaNode"]: diff --git a/src/starkware/starkware_utils/commitment_tree/patricia_tree/virtual_patricia_node_test.py b/src/starkware/starkware_utils/commitment_tree/patricia_tree/virtual_patricia_node_test.py index 02fcd7f0..03bb708a 100644 --- a/src/starkware/starkware_utils/commitment_tree/patricia_tree/virtual_patricia_node_test.py +++ b/src/starkware/starkware_utils/commitment_tree/patricia_tree/virtual_patricia_node_test.py @@ -10,6 +10,7 @@ BinaryNodeFact, EdgeNodeFact, ) +from starkware.starkware_utils.commitment_tree.patricia_tree.patricia_tree import PatriciaTree from starkware.starkware_utils.commitment_tree.patricia_tree.virtual_calculation_node import ( VirtualCalculationNode, ) @@ -188,3 +189,45 @@ async def test_update_and_get_leaves(ffc: FactFetchingContext): sorted_by_index_leaf_values = [updated_leaves[leaf_id].value for leaf_id in leaves_range] expected_root_hash = await tree.commit(ffc=ffc, facts=None) # Root is an edge node. verify_root(leaves=sorted_by_index_leaf_values, expected_root_hash=expected_root_hash) + + +@pytest.mark.asyncio +async def test_binary_fact_tree_node_create_diff(ffc: FactFetchingContext): + # All tree values ​​are zero. + empty_tree = await PatriciaTree.empty_tree(ffc=ffc, height=251, leaf_fact=LeafFact(value=0)) + virtual_empty_tree_node = VirtualPatriciaNode.from_hash( + hash_value=empty_tree.root, height=empty_tree.height + ) + + # All tree values ​​are zero except for the fifth leaf, which has a value of 8. + one_change_tree = await empty_tree.update(ffc=ffc, modifications=[(5, LeafFact(value=8))]) + virtual_one_change_node = VirtualPatriciaNode.from_hash( + hash_value=one_change_tree.root, height=empty_tree.height + ) + + # All tree values ​​are zero except for the fifth leaf, which has a value of 8. + # and the 58th leaf, which is 81. + two_change_tree = await one_change_tree.update( + ffc=ffc, modifications=[(58, LeafFact(value=81))] + ) + virtual_two_change_node = VirtualPatriciaNode.from_hash( + hash_value=two_change_tree.root, height=empty_tree.height + ) + + # The difference between the tree whose values are all zero and the tree that has + # all values zero except two values is exactly the 2 values. + diff_result = await virtual_empty_tree_node.get_diff_between_trees( + other=virtual_two_change_node, ffc=ffc, fact_cls=LeafFact + ) + assert diff_result == [ + (5, LeafFact(value=0), LeafFact(value=8)), + (58, LeafFact(value=0), LeafFact(value=81)), + ] + + # The difference between the tree whose values are zero except for the fifth leaf + # and the tree whose values are all zero except for the fifth leaf (there they are equal) + # and for the 58th leaf is exactly the 58th leaf. + diff_result = await virtual_one_change_node.get_diff_between_trees( + other=virtual_two_change_node, ffc=ffc, fact_cls=LeafFact + ) + assert diff_result == [(58, LeafFact(value=0), LeafFact(value=81))] diff --git a/src/starkware/starkware_utils/commitment_tree/update_tree.py b/src/starkware/starkware_utils/commitment_tree/update_tree.py index 9c779a99..81a9e0ab 100644 --- a/src/starkware/starkware_utils/commitment_tree/update_tree.py +++ b/src/starkware/starkware_utils/commitment_tree/update_tree.py @@ -1,7 +1,7 @@ from typing import Any, AsyncIterator, Collection, Dict, NamedTuple, Optional, Tuple, Type, Union from starkware.python.utils import from_bytes, gather_in_chunks -from starkware.starkware_utils.commitment_tree.binary_fact_tree import BinaryFactDict, TFact +from starkware.starkware_utils.commitment_tree.binary_fact_tree import BinaryFactDict from starkware.starkware_utils.commitment_tree.binary_fact_tree_node import ( BinaryFactTreeNode, TBinaryFactTreeNode, @@ -11,6 +11,8 @@ from starkware.starkware_utils.executor import executor_ctx_var from starkware.storage.storage import Fact, FactFetchingContext +# Should be Tuple["UpdateTree", "UpdateTree"], but recursive types are not supported in mypy: +# https://github.com/python/mypy/issues/731. UpdateTree = Optional[Union[Tuple[Any, Any], Fact]] NodeType = NamedTuple( "NodeType", [("index", int), ("tree", BinaryFactTreeNode), ("update", UpdateTree)] @@ -40,7 +42,7 @@ async def update_tree( """ # A map from node index to the updated subtree, in which inner nodes are not hashed yet, but # rather consist of their children. - # This map is populated when we traverse a node and know it's value in the updated tree. + # This map is populated when we traverse a node and know its value in the updated tree. # This happens when either of these happens: # 1. The node has no updates => value remains the same. # 2. Node is a leaf update, and we just updated the leaf value. @@ -67,6 +69,23 @@ async def update_necessary(node_index: int): del updated_nodes[2 * node_index] del updated_nodes[2 * node_index + 1] + async def update_if_possible(node_index: int, binary_fact_tree_node: BinaryFactTreeNode): + updated_nodes[node_index] = calculation_node_cls.create( + node=binary_fact_tree_node, + ) + await update_necessary(node_index=node_index) + + async def set_fact( + new_fact: UpdateTree, node_index: int, binary_fact_tree_node: BinaryFactTreeNode + ): + assert isinstance(new_fact, Fact) + + leaf_hash = await new_fact.set_fact(ffc=ffc) + updated_nodes[node_index] = calculation_node_cls.create( + node=binary_fact_tree_node.create_leaf(hash_value=leaf_hash) + ) + await update_necessary(node_index=node_index) + async def traverse_node(node: NodeType) -> AsyncIterator[NodeType]: """ Callback function for traverse_tree(). @@ -77,23 +96,19 @@ async def traverse_node(node: NodeType) -> AsyncIterator[NodeType]: node_index, binary_fact_tree_node, update_subtree = node if update_subtree is None: - # No update to subtree. - updated_nodes[node_index] = calculation_node_cls.create( - node=binary_fact_tree_node, + # No updates to subtree. + await update_if_possible( + node_index=node_index, binary_fact_tree_node=binary_fact_tree_node ) - await update_necessary(node_index=node_index) return if binary_fact_tree_node.is_leaf: # Leaf update. - new_fact = update_subtree - assert isinstance(new_fact, Fact) - - leaf_hash = await new_fact.set_fact(ffc=ffc) - updated_nodes[node_index] = calculation_node_cls.create( - node=binary_fact_tree_node.create_leaf(hash_value=leaf_hash) + await set_fact( + new_fact=update_subtree, + node_index=node_index, + binary_fact_tree_node=binary_fact_tree_node, ) - await update_necessary(node_index=node_index) return # Inner node with updates. @@ -131,14 +146,14 @@ async def build_updated_calculation() -> CalculationNode: return root_node -def build_update_tree(height: int, modifications: Collection[Tuple[int, TFact]]) -> UpdateTree: +def build_update_tree(height: int, modifications: Collection[Tuple[int, Fact]]) -> UpdateTree: """ Constructs a tree from leaf updates. This is not a full binary tree. It is just the subtree induced by the modification leaves. Returns a tree. A tree is either: - * None - * a pair of trees - * A leaf + * None (if no modifications exist in its subtree). + * A leaf (if a single modification is given at height 0; i.e., a leaf). + * A pair of trees. """ # Bottom layer. This will prefer the last modification to an index. if len(modifications) == 0: diff --git a/src/starkware/starkware_utils/serializable.py b/src/starkware/starkware_utils/serializable.py index b72df70e..d2161473 100644 --- a/src/starkware/starkware_utils/serializable.py +++ b/src/starkware/starkware_utils/serializable.py @@ -1,7 +1,7 @@ import inspect from abc import ABC, abstractmethod from json import JSONDecoder, JSONEncoder -from typing import ClassVar, Dict, Type, TypeVar +from typing import ClassVar, Dict, Optional, Type, TypeVar from starkware.python.utils import camel_to_snake_case @@ -87,7 +87,7 @@ def __init_subclass__(cls, **kwargs): StringSerializable._classes[cls._serialize_name] = cls @abstractmethod - def dumps(self) -> str: + def dumps(self, indent: Optional[int] = None, sort_keys: bool = False) -> str: pass @classmethod diff --git a/src/starkware/starkware_utils/serializable_dataclass.py b/src/starkware/starkware_utils/serializable_dataclass.py new file mode 100644 index 00000000..bfdf85d6 --- /dev/null +++ b/src/starkware/starkware_utils/serializable_dataclass.py @@ -0,0 +1,32 @@ +from typing import ClassVar, Optional, Type, TypeVar + +import marshmallow + +from starkware.starkware_utils.serializable import StringSerializable + +TSerializableDataclass = TypeVar("TSerializableDataclass", bound="SerializableMarshmallowDataclass") + + +class SerializableMarshmallowDataclass(StringSerializable): + """ + Base class to classes decorated with marshmallow_dataclass.dataclass, implementing the + Serializable interface. + """ + + Schema: ClassVar[Type[marshmallow.Schema]] + + def dump(self) -> dict: + return self.Schema().dump(obj=self) + + @classmethod + def load(cls: Type[TSerializableDataclass], data: dict) -> TSerializableDataclass: + return cls.Schema().load(data=data) + + def dumps(self, indent: Optional[int] = None, sort_keys: Optional[bool] = None) -> str: + sort_keys = False if sort_keys is None else sort_keys + # An indent level of 0 will only insert newlines; 'None' is the most compact representation. + return self.Schema().dumps(obj=self, indent=indent, sort_keys=sort_keys) + + @classmethod + def loads(cls: Type[TSerializableDataclass], data: str) -> TSerializableDataclass: + return cls.Schema().loads(json_data=data) diff --git a/src/starkware/starkware_utils/validated_dataclass.py b/src/starkware/starkware_utils/validated_dataclass.py index 1a225e8a..7b4101c6 100644 --- a/src/starkware/starkware_utils/validated_dataclass.py +++ b/src/starkware/starkware_utils/validated_dataclass.py @@ -1,44 +1,20 @@ import dataclasses import inspect import random -from typing import Any, ClassVar, Dict, Optional, Sequence, Tuple, Type, TypeVar +from typing import Any, Dict, Optional, Sequence, Tuple, Type, TypeVar import marshmallow import marshmallow.fields as mfields import marshmallow_dataclass import typeguard -from starkware.starkware_utils.serializable import StringSerializable +from starkware.starkware_utils.serializable_dataclass import SerializableMarshmallowDataclass from starkware.starkware_utils.validated_fields import Field TValidatedDataclass = TypeVar("TValidatedDataclass", bound="ValidatedDataclass") -TSerializableDataclass = TypeVar("TSerializableDataclass", bound="SerializableMarshmallowDataclass") T = TypeVar("T") -class SerializableMarshmallowDataclass(StringSerializable): - """ - Base class to classes decorated with marshmallow_dataclass.dataclass, implementing the - Serializable interface. - """ - - Schema: ClassVar[Type[marshmallow.Schema]] - - def dump(self) -> dict: - return self.Schema().dump(obj=self) - - @classmethod - def load(cls: Type[TSerializableDataclass], data: dict) -> TSerializableDataclass: - return cls.Schema().load(data=data) - - def dumps(self) -> str: - return self.Schema().dumps(obj=self) - - @classmethod - def loads(cls: Type[TSerializableDataclass], data: str) -> TSerializableDataclass: - return cls.Schema().loads(json_data=data) - - class ValidatedDataclass: """ A class containing a type- and value-level validation. diff --git a/src/starkware/starkware_utils/validated_fields.py b/src/starkware/starkware_utils/validated_fields.py index 7d152617..9742f7b7 100644 --- a/src/starkware/starkware_utils/validated_fields.py +++ b/src/starkware/starkware_utils/validated_fields.py @@ -1,4 +1,3 @@ -import copy import dataclasses import random from abc import ABC, abstractmethod @@ -91,20 +90,27 @@ def format_invalid_value_error_message(self, value: T, name: Optional[str] = Non # Serialization. @abstractmethod - def get_marshmallow_field(self) -> mfields.Field: + def get_marshmallow_field(self, required: bool, load_default: Any) -> mfields.Field: """ Returns a marshmallow field that serializes and deserializes values of this field. """ # Metadata. - def metadata(self, field_name: Optional[str] = None): + def metadata( + self, + field_name: Optional[str] = None, + required: bool = True, + load_default: Any = marshmallow.utils.missing, + ) -> Dict[str, Any]: """ Creates the metadata associated with this field. If provided, then use the given field_name for messages, and otherwise (if it is None) use the default name. """ return dict( - marshmallow_field=self.get_marshmallow_field(), + marshmallow_field=self.get_marshmallow_field( + required=required, load_default=load_default + ), validated_field=self, name_in_messages=self.name if field_name is None else field_name, ) @@ -129,10 +135,6 @@ def __init__(self, field: Field[T], none_probability: float): super().__init__(name=field.name, error_code=field.error_code) self.field = field self.none_probability = max(0, min(1, none_probability)) - self.mfield: mfields.Field = copy.copy(self.field.get_marshmallow_field()) # type: ignore - self.mfield.allow_none = True - self.mfield.load_default = None # type: ignore[attr-defined] - self.mfield.required = False def format(self, value: Optional[T]) -> str: if value is None: @@ -160,8 +162,9 @@ def format_invalid_value_error_message( return f"{name} is valid (None)." return self.field.format_invalid_value_error_message(value=value, name=name) - def get_marshmallow_field(self) -> mfields.Field: - return self.mfield + def get_marshmallow_field(self, required: bool, load_default: Any) -> mfields.Field: + # Field is created with allow_none=True if load_default is None. + return self.field.get_marshmallow_field(required=False, load_default=None) # Mypy has a problem with dataclasses that contain unimplemented abstract methods. @@ -191,13 +194,13 @@ def _format_value(self, value: int) -> str: return str(value) return self.formatter(value) - def get_marshmallow_field(self) -> mfields.Field: + def get_marshmallow_field(self, required: bool, load_default: Any) -> mfields.Field: if self.formatter == hex: - return IntAsHex(required=True) + return IntAsHex(required=required, load_default=load_default) if self.formatter == str: - return IntAsStr(required=True) + return IntAsStr(required=required, load_default=load_default) if self.formatter is None: - return mfields.Integer(required=True) + return mfields.Integer(required=required, load_default=load_default) raise NotImplementedError( f"{self.name}: The given formatter {self.formatter.__name__} " "does not have a suitable metadata." @@ -281,8 +284,8 @@ def format_invalid_value_error_message(self, value: bytes, name: Optional[str] = return f"{name} {value_repr} length is not {self.length} bytes, instead it is {len(value)}" # Serialization. - def get_marshmallow_field(self) -> mfields.Field: - return BytesAsHex(required=True) + def get_marshmallow_field(self, required: bool, load_default: Any) -> mfields.Field: + return BytesAsHex(required=required, load_default=load_default) def format(self, value: bytes) -> str: return value.hex() @@ -351,18 +354,16 @@ def sequential_id_metadata( field_name: str, required: bool = True, allow_previous_id: bool = False, - allow_none: bool = False, load_default: Any = marshmallow.utils.missing, ) -> Dict[str, Any]: + load_default_value = load_default() if callable(load_default) else load_default validator = validate_in_range( - field_name=field_name, min_value=-1 if allow_previous_id else 0, allow_none=allow_none + field_name=field_name, + min_value=-1 if allow_previous_id else 0, + allow_none=load_default_value is None, ) return dict( marshmallow_field=mfields.Integer( - strict=True, - required=required, - allow_none=allow_none, - validate=validator, - load_default=load_default, + strict=True, required=required, validate=validator, load_default=load_default ) ) diff --git a/src/starkware/storage/imm_storage.py b/src/starkware/storage/imm_storage.py index 06af8c20..a162718d 100644 --- a/src/starkware/storage/imm_storage.py +++ b/src/starkware/storage/imm_storage.py @@ -40,7 +40,11 @@ async def del_value(self, key: bytes): self.write_tasks.append(asyncio.create_task(self.storage.del_value(key))) async def wait_for_all(self): - logger.debug("Performing remaining writing tasks to storage...") + n_write_tasks = len(self.write_tasks) + if n_write_tasks == 0: + return + + logger.debug(f"Performing {n_write_tasks} writing tasks to storage...") logging_chunk_size = 2 ** 8 for n_handled_tasks, task in enumerate(self.write_tasks): diff --git a/src/starkware/storage/storage.py b/src/starkware/storage/storage.py index 07baafa5..138695aa 100644 --- a/src/starkware/storage/storage.py +++ b/src/starkware/storage/storage.py @@ -138,7 +138,7 @@ async def get(cls: Type[TDBObject], storage: Storage, suffix: bytes) -> Optional if result is None: return None - return await asyncio.get_event_loop().run_in_executor(None, cls.deserialize, result) + return cls.deserialize(result) @classmethod async def get_or_fail(cls: Type[TDBObject], storage: Storage, suffix: bytes) -> TDBObject: @@ -150,7 +150,7 @@ async def get_or_fail(cls: Type[TDBObject], storage: Storage, suffix: bytes) -> result = await storage.get_value(key=db_key) assert result is not None, f"Key {db_key!r} does not appear in storage." - return await asyncio.get_event_loop().run_in_executor(None, cls.deserialize, result) + return cls.deserialize(result) async def set(self, storage: Storage, suffix: bytes): serialized = await asyncio.get_event_loop().run_in_executor(None, self.serialize)