diff --git a/.gitignore b/.gitignore index 0d8ac0aa..f719f659 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,3 @@ /build/ __pycache__/ -cairo-starkware-*.zip +cairo-lang-*.zip diff --git a/Dockerfile b/Dockerfile index 97ce9ed0..ca593aa2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -20,3 +20,5 @@ WORKDIR /app/src/starkware/cairo/lang/ide/vscode-cairo RUN npm install -g vsce RUN npm install RUN vsce package + +WORKDIR /app/ diff --git a/README.md b/README.md index f1608773..33a027f1 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ We recommend starting from [Setting up the environment](https://cairo-lang.org/d # Installation instructions You should be able to download the python package zip file directly from -[github](https://github.com/starkware-libs/cairo-lang/releases/tag/v0.0.1) +[github](https://github.com/starkware-libs/cairo-lang/releases/tag/v0.0.2) and install it using ``pip``. See [Setting up the environment](https://cairo-lang.org/docs/quickstart.html). @@ -55,7 +55,7 @@ Once the docker image is built, you can fetch the python package zip file using: ```bash > container_id=$(docker create cairo) -> docker cp ${container_id}:/app/cairo-starkware-0.0.1.zip . +> docker cp ${container_id}:/app/cairo-lang-0.0.2.zip . > docker rm -v ${container_id} ``` diff --git a/build.sh b/build.sh index 0a83f0db..228a7fe8 100755 --- a/build.sh +++ b/build.sh @@ -11,5 +11,6 @@ VENV_SITE_DIR=build/Release/src/starkware/cairo/lang/cairo_lang_venv-site cp src/starkware/cairo/lang/setup.py ${VENV_SITE_DIR} cp src/starkware/cairo/lang/MANIFEST.in ${VENV_SITE_DIR} cp scripts/requirements-gen.txt ${VENV_SITE_DIR}/requirements.txt +cp README.md ${VENV_SITE_DIR} ( cd ${VENV_SITE_DIR}; python3 setup.py sdist --format=zip ) -cp ${VENV_SITE_DIR}/dist/cairo-starkware-0.0.1.zip . +cp ${VENV_SITE_DIR}/dist/cairo-lang-$(cat src/starkware/cairo/lang/VERSION).zip . diff --git a/src/starkware/cairo/apps/starkex2_0/execute_batch.cairo b/src/starkware/cairo/apps/starkex2_0/execute_batch.cairo index f52ed8b2..525ce929 100644 --- a/src/starkware/cairo/apps/starkex2_0/execute_batch.cairo +++ b/src/starkware/cairo/apps/starkex2_0/execute_batch.cairo @@ -3,13 +3,13 @@ # * The license can be found in: licenses/CairoProgramLicense.txt * # *********************************************************************** -from starkware.cairo.apps.starkex2_0.common.cairo_builtins import HashBuiltin -from starkware.cairo.apps.starkex2_0.common.cairo_builtins import SignatureBuiltin +from starkware.cairo.apps.starkex2_0.common.cairo_builtins import HashBuiltin, SignatureBuiltin from starkware.cairo.apps.starkex2_0.common.dict import DictAccess from starkware.cairo.apps.starkex2_0.dex_context import DexContext -from starkware.cairo.apps.starkex2_0.execute_false_full_withdrawal import execute_false_full_withdrawal -from starkware.cairo.apps.starkex2_0.execute_modification import ModificationOutput -from starkware.cairo.apps.starkex2_0.execute_modification import execute_modification +from starkware.cairo.apps.starkex2_0.execute_false_full_withdrawal import ( + execute_false_full_withdrawal) +from starkware.cairo.apps.starkex2_0.execute_modification import ( + ModificationOutput, execute_modification) from starkware.cairo.apps.starkex2_0.execute_settlement import execute_settlement from starkware.cairo.apps.starkex2_0.execute_transfer import execute_transfer diff --git a/src/starkware/cairo/apps/starkex2_0/execute_limit_order.cairo b/src/starkware/cairo/apps/starkex2_0/execute_limit_order.cairo index 1f787e06..69189fdc 100644 --- a/src/starkware/cairo/apps/starkex2_0/execute_limit_order.cairo +++ b/src/starkware/cairo/apps/starkex2_0/execute_limit_order.cairo @@ -3,14 +3,11 @@ # * The license can be found in: licenses/CairoProgramLicense.txt * # *********************************************************************** -from starkware.cairo.apps.starkex2_0.common.cairo_builtins import HashBuiltin -from starkware.cairo.apps.starkex2_0.common.cairo_builtins import SignatureBuiltin +from starkware.cairo.apps.starkex2_0.common.cairo_builtins import HashBuiltin, SignatureBuiltin from starkware.cairo.apps.starkex2_0.common.dict import DictAccess from starkware.cairo.apps.starkex2_0.common.merkle_update import merkle_update -from starkware.cairo.apps.starkex2_0.dex_constants import BALANCE_BOUND -from starkware.cairo.apps.starkex2_0.dex_constants import EXPIRATION_TIMESTAMP_BOUND -from starkware.cairo.apps.starkex2_0.dex_constants import NONCE_BOUND -from starkware.cairo.apps.starkex2_0.dex_constants import PackedOrderMsg +from starkware.cairo.apps.starkex2_0.dex_constants import ( + BALANCE_BOUND, EXPIRATION_TIMESTAMP_BOUND, NONCE_BOUND, PackedOrderMsg) from starkware.cairo.apps.starkex2_0.dex_context import DexContext from starkware.cairo.apps.starkex2_0.vault_update import vault_update_diff from starkware.cairo.apps.starkex2_0.verify_order_signature import verify_order_signature diff --git a/src/starkware/cairo/apps/starkex2_0/execute_settlement.cairo b/src/starkware/cairo/apps/starkex2_0/execute_settlement.cairo index a42c471d..21cc3f61 100644 --- a/src/starkware/cairo/apps/starkex2_0/execute_settlement.cairo +++ b/src/starkware/cairo/apps/starkex2_0/execute_settlement.cairo @@ -3,8 +3,7 @@ # * The license can be found in: licenses/CairoProgramLicense.txt * # *********************************************************************** -from starkware.cairo.apps.starkex2_0.common.cairo_builtins import HashBuiltin -from starkware.cairo.apps.starkex2_0.common.cairo_builtins import SignatureBuiltin +from starkware.cairo.apps.starkex2_0.common.cairo_builtins import HashBuiltin, SignatureBuiltin from starkware.cairo.apps.starkex2_0.common.dict import DictAccess from starkware.cairo.apps.starkex2_0.dex_constants import BALANCE_BOUND from starkware.cairo.apps.starkex2_0.dex_context import DexContext diff --git a/src/starkware/cairo/apps/starkex2_0/execute_transfer.cairo b/src/starkware/cairo/apps/starkex2_0/execute_transfer.cairo index ff42e4d7..0944484a 100644 --- a/src/starkware/cairo/apps/starkex2_0/execute_transfer.cairo +++ b/src/starkware/cairo/apps/starkex2_0/execute_transfer.cairo @@ -3,14 +3,11 @@ # * The license can be found in: licenses/CairoProgramLicense.txt * # *********************************************************************** -from starkware.cairo.apps.starkex2_0.common.cairo_builtins import HashBuiltin -from starkware.cairo.apps.starkex2_0.common.cairo_builtins import SignatureBuiltin +from starkware.cairo.apps.starkex2_0.common.cairo_builtins import HashBuiltin, SignatureBuiltin from starkware.cairo.apps.starkex2_0.common.dict import DictAccess from starkware.cairo.apps.starkex2_0.common.merkle_update import merkle_update -from starkware.cairo.apps.starkex2_0.dex_constants import BALANCE_BOUND -from starkware.cairo.apps.starkex2_0.dex_constants import EXPIRATION_TIMESTAMP_BOUND -from starkware.cairo.apps.starkex2_0.dex_constants import NONCE_BOUND -from starkware.cairo.apps.starkex2_0.dex_constants import PackedOrderMsg +from starkware.cairo.apps.starkex2_0.dex_constants import ( + BALANCE_BOUND, EXPIRATION_TIMESTAMP_BOUND, NONCE_BOUND, PackedOrderMsg) from starkware.cairo.apps.starkex2_0.dex_context import DexContext from starkware.cairo.apps.starkex2_0.vault_update import vault_update_diff from starkware.cairo.apps.starkex2_0.verify_order_signature import verify_order_signature diff --git a/src/starkware/cairo/apps/starkex2_0/hash_vault_ptr_dict.cairo b/src/starkware/cairo/apps/starkex2_0/hash_vault_ptr_dict.cairo index 0a7f38e9..4cb7af73 100644 --- a/src/starkware/cairo/apps/starkex2_0/hash_vault_ptr_dict.cairo +++ b/src/starkware/cairo/apps/starkex2_0/hash_vault_ptr_dict.cairo @@ -5,8 +5,7 @@ from starkware.cairo.apps.starkex2_0.common.cairo_builtins import HashBuiltin from starkware.cairo.apps.starkex2_0.common.dict import DictAccess -from starkware.cairo.apps.starkex2_0.vault_update import compute_vault_hash -from starkware.cairo.apps.starkex2_0.vault_update import VaultState +from starkware.cairo.apps.starkex2_0.vault_update import VaultState, compute_vault_hash # Gets a single pointer to a vault state and outputs the hash of that vault. func hash_vault_state_ptr(hash_ptr : HashBuiltin*, vault_state_ptr : VaultState*) -> ( diff --git a/src/starkware/cairo/apps/starkex2_0/main.cairo b/src/starkware/cairo/apps/starkex2_0/main.cairo index 8d35f077..a3e5c868 100644 --- a/src/starkware/cairo/apps/starkex2_0/main.cairo +++ b/src/starkware/cairo/apps/starkex2_0/main.cairo @@ -5,10 +5,8 @@ %builtins output pedersen range_check ecdsa -from starkware.cairo.apps.starkex2_0.common.cairo_builtins import HashBuiltin -from starkware.cairo.apps.starkex2_0.common.cairo_builtins import SignatureBuiltin -from starkware.cairo.apps.starkex2_0.common.dict import DictAccess -from starkware.cairo.apps.starkex2_0.common.dict import squash_dict +from starkware.cairo.apps.starkex2_0.common.cairo_builtins import HashBuiltin, SignatureBuiltin +from starkware.cairo.apps.starkex2_0.common.dict import DictAccess, squash_dict from starkware.cairo.apps.starkex2_0.common.merkle_multi_update import merkle_multi_update from starkware.cairo.apps.starkex2_0.dex_context import make_dex_context from starkware.cairo.apps.starkex2_0.execute_batch import execute_batch diff --git a/src/starkware/cairo/apps/starkex2_0/vault_update.cairo b/src/starkware/cairo/apps/starkex2_0/vault_update.cairo index 63332f23..d4d2f0db 100644 --- a/src/starkware/cairo/apps/starkex2_0/vault_update.cairo +++ b/src/starkware/cairo/apps/starkex2_0/vault_update.cairo @@ -6,8 +6,7 @@ from starkware.cairo.apps.starkex2_0.common.cairo_builtins import HashBuiltin from starkware.cairo.apps.starkex2_0.common.dict import DictAccess from starkware.cairo.apps.starkex2_0.common.merkle_update import merkle_update -from starkware.cairo.apps.starkex2_0.dex_constants import BALANCE_BOUND -from starkware.cairo.apps.starkex2_0.dex_constants import ZERO_VAULT_HASH +from starkware.cairo.apps.starkex2_0.dex_constants import BALANCE_BOUND, ZERO_VAULT_HASH struct VaultState: member stark_key = 0 diff --git a/src/starkware/cairo/apps/starkex2_0/verify_order_id.cairo b/src/starkware/cairo/apps/starkex2_0/verify_order_id.cairo index b1fc8b12..d7b14298 100644 --- a/src/starkware/cairo/apps/starkex2_0/verify_order_id.cairo +++ b/src/starkware/cairo/apps/starkex2_0/verify_order_id.cairo @@ -3,7 +3,8 @@ # * The license can be found in: licenses/CairoProgramLicense.txt * # *********************************************************************** -from starkware.cairo.apps.starkex2_0.dex_constants import HASH_MESSAGE_BOUND as DEX_HASH_MESSAGE_BOUND +from starkware.cairo.apps.starkex2_0.dex_constants import ( + HASH_MESSAGE_BOUND as DEX_HASH_MESSAGE_BOUND) from starkware.cairo.apps.starkex2_0.dex_constants import ORDER_ID_BOUND as DEX_ORDER_ID_BOUND from starkware.cairo.apps.starkex2_0.dex_constants import RANGE_CHECK_BOUND as DEX_RANGE_CHECK_BOUND diff --git a/src/starkware/cairo/apps/starkex2_0/verify_order_signature.cairo b/src/starkware/cairo/apps/starkex2_0/verify_order_signature.cairo index 522a1cd5..f8b54123 100644 --- a/src/starkware/cairo/apps/starkex2_0/verify_order_signature.cairo +++ b/src/starkware/cairo/apps/starkex2_0/verify_order_signature.cairo @@ -3,8 +3,7 @@ # * The license can be found in: licenses/CairoProgramLicense.txt * # *********************************************************************** -from starkware.cairo.apps.starkex2_0.common.cairo_builtins import HashBuiltin -from starkware.cairo.apps.starkex2_0.common.cairo_builtins import SignatureBuiltin +from starkware.cairo.apps.starkex2_0.common.cairo_builtins import HashBuiltin, SignatureBuiltin from starkware.cairo.apps.starkex2_0.dex_constants import PackedOrderMsg from starkware.cairo.apps.starkex2_0.verify_order_id import verify_order_id diff --git a/src/starkware/cairo/bootloader/CMakeLists.txt b/src/starkware/cairo/bootloader/CMakeLists.txt index 2ab3a0e9..84656249 100644 --- a/src/starkware/cairo/bootloader/CMakeLists.txt +++ b/src/starkware/cairo/bootloader/CMakeLists.txt @@ -7,6 +7,7 @@ python_lib(cairo_hash_program_lib LIBS cairo_common_lib cairo_compile_lib + cairo_version_lib cairo_vm_crypto_lib ) diff --git a/src/starkware/cairo/bootloader/hash_program.py b/src/starkware/cairo/bootloader/hash_program.py index 357aedf1..24e1bfcb 100644 --- a/src/starkware/cairo/bootloader/hash_program.py +++ b/src/starkware/cairo/bootloader/hash_program.py @@ -3,6 +3,7 @@ from starkware.cairo.common.hash_chain import compute_hash_chain from starkware.cairo.lang.compiler.program import Program, ProgramBase +from starkware.cairo.lang.version import __version__ from starkware.cairo.lang.vm.crypto import get_crypto_lib_context_manager @@ -21,6 +22,7 @@ def compute_program_hash_chain(program: ProgramBase, bootloader_version=0): def main(): parser = argparse.ArgumentParser( description='A tool to compute the hash of a cairo program') + parser.add_argument('-v', '--version', action='version', version=f'%(prog)s {__version__}') parser.add_argument( '--program', type=argparse.FileType('r'), required=True, help='The name of the program json file.') diff --git a/src/starkware/cairo/common/CMakeLists.txt b/src/starkware/cairo/common/CMakeLists.txt index 20bb485e..5829b9bb 100644 --- a/src/starkware/cairo/common/CMakeLists.txt +++ b/src/starkware/cairo/common/CMakeLists.txt @@ -4,22 +4,28 @@ python_lib(cairo_common_lib alloc.cairo cairo_builtins.cairo dict.cairo + dict_access.cairo dict.py find_element.cairo - hash.cairo hash_chain.cairo hash_chain.py hash_state.cairo - math.cairo + hash.cairo math_utils.py + math.cairo memcpy.cairo merkle_multi_update.cairo merkle_update.cairo registers.cairo serialize.cairo signature.cairo + small_merkle_tree.cairo + small_merkle_tree.py + squash_dict.cairo ${CAIRO_COMMON_LIB_ADDITIONAL_FILES} LIBS + cairo_vm_crypto_lib + starkware_merkle_tree_lib ${CAIRO_COMMON_LIB_ADDITIONAL_LIBS} ) diff --git a/src/starkware/cairo/common/dict.cairo b/src/starkware/cairo/common/dict.cairo index 78349bf8..4804a4e8 100644 --- a/src/starkware/cairo/common/dict.cairo +++ b/src/starkware/cairo/common/dict.cairo @@ -1,248 +1,29 @@ -struct DictAccess: - member key = 0 - member prev_value = 1 - member new_value = 2 - const SIZE = 3 -end - -# Inner tail-recursive function for squash_dict. -# -# Arguments: -# range_check_ptr - range check builtin pointer. -# dict_accesses - a pointer to the beginning of an array of DictAccess instances. -# dict_accesses_end_minus1 - a pointer to the end of said array, minus 1. -# min_key - minimum allowed key. Used to enforce monotonicity of keys. -# remaining_accesses - remaining number of accesses that need to be accounted for. Starts with -# the total number of entries in dict_accesses array, and slowly decreases until it reaches 0. -# squashed_dict - a pointer to an output array, which will be filled with -# DictAccess instances sorted by key with the first and last value for each key. -# -# Hints: -# keys - a descending list of the keys for which we have accesses. Destroyed in the process. -# access_indices - A map from key to a descending list of indices in the dict_accesses array that -# access this key. Destroyed in the process. -# -# Returns: -# range_check_ptr - updated range check builtin pointer. -# squashed_dict - end pointer to squashed_dict. -func squash_dict_inner( - range_check_ptr, dict_accesses : DictAccess*, dict_accesses_end_minus1 : felt*, min_key, - remaining_accesses, squashed_dict : DictAccess*) -> ( - range_check_ptr, squashed_dict : DictAccess*): - # Exit recursion when done. - if remaining_accesses == 0: - %{ assert len(keys) == 0 %} - return (range_check_ptr=range_check_ptr, squashed_dict=squashed_dict) - end - - # Locals. - struct Locals: - member key = 0 - member should_skip_loop = 1 - member first_value = 2 - const SIZE = 3 - end - let locals = cast(fp, Locals*) - let key = locals.key - let dict_diff : DictAccess* = squashed_dict - ap += Locals.SIZE - - # Guess key and check that key >= min_key. - %{ ids.locals.key = key = keys.pop() %} - [ap] = key - min_key - [ap] = [range_check_ptr]; ap++ - - # Loop to verify chronological accesses to the key. - # These values are not needed from previous iteration. - struct LoopTemps: - member index_delta_minus1 = 0 - member index_delta = 1 - member ptr_delta = 2 - member should_continue = 3 - const SIZE = 4 - end - # These values are needed from previous iteration. - struct LoopLocals: - member value = 0 - member access_ptr : DictAccess* = 1 - member range_check_ptr = 2 - const SIZE = 3 - end - - # Prepare first iteration. - %{ - current_access_indices = sorted(access_indices[key])[::-1] - current_access_index = current_access_indices.pop() - memory[ids.range_check_ptr + 1] = current_access_index - %} - # Check that first access_index >= 0. - tempvar current_access_index = [range_check_ptr + 1] - tempvar ptr_delta = current_access_index * DictAccess.SIZE - - let first_loop_locals = cast(ap, LoopLocals*) - first_loop_locals.access_ptr = dict_accesses + ptr_delta; ap++ - let first_access : DictAccess* = first_loop_locals.access_ptr - first_loop_locals.value = first_access.new_value; ap++ - first_loop_locals.range_check_ptr = range_check_ptr + 2; ap++ - - # Verify first key. - key = first_access.key - - # Write key and first value to dict_diff. - key = dict_diff.key - # Use a local variable, instead of a tempvar, to avoid increasing ap. - locals.first_value = first_access.prev_value - locals.first_value = dict_diff.prev_value - - # Skip loop non-deterministically if necessary. - %{ memory[fp + ids.Locals.should_skip_loop] = 0 if current_access_indices else 1 %} - jmp skip_loop if [fp + Locals.should_skip_loop] != 0 - - loop: - let prev_loop_locals = cast(ap - LoopLocals.SIZE, LoopLocals*) - let loop_temps = cast(ap, LoopTemps*) - let loop_locals = cast(ap + LoopTemps.SIZE, LoopLocals*) - - # Check access_index. - %{ - new_access_index = current_access_indices.pop() - ids.loop_temps.index_delta_minus1 = new_access_index - current_access_index - 1 - current_access_index = new_access_index - %} - # Check that new access_index > prev access_index. - loop_temps.index_delta_minus1 = [prev_loop_locals.range_check_ptr]; ap++ - loop_temps.index_delta = loop_temps.index_delta_minus1 + 1; ap++ - loop_temps.ptr_delta = loop_temps.index_delta * DictAccess.SIZE; ap++ - loop_locals.access_ptr = prev_loop_locals.access_ptr + loop_temps.ptr_delta; ap++ - - # Check valid transition. - let access : DictAccess* = loop_locals.access_ptr - prev_loop_locals.value = access.prev_value - loop_locals.value = access.new_value; ap++ - - # Verify key. - key = access.key - - # Next range_check_ptr. - loop_locals.range_check_ptr = prev_loop_locals.range_check_ptr + 1; ap++ - - %{ ids.loop_temps.should_continue = 1 if current_access_indices else 0 %} - jmp loop if loop_temps.should_continue != 0; ap++ - - skip_loop: - let last_loop_locals = cast(ap - LoopLocals.SIZE, LoopLocals*) - - # Check if address is out of bounds. - %{ assert len(current_access_indices) == 0 %} - [ap] = dict_accesses_end_minus1 - cast(last_loop_locals.access_ptr, felt) - [ap] = [last_loop_locals.range_check_ptr]; ap++ - tempvar range_check_diff = last_loop_locals.range_check_ptr - range_check_ptr - tempvar n_used_accesses = range_check_diff - 1 - %{ assert ids.n_used_accesses == len(access_indices[key]) %} - - # Write last value to dict_diff. - last_loop_locals.value = dict_diff.new_value - - # Call squashed_dict_inner recursively. - squash_dict_inner( - range_check_ptr=last_loop_locals.range_check_ptr + 1, - dict_accesses=dict_accesses, - dict_accesses_end_minus1=dict_accesses_end_minus1, - min_key=key + 1, - remaining_accesses=remaining_accesses - n_used_accesses, - squashed_dict=squashed_dict + DictAccess.SIZE) - return (...) -end +from starkware.cairo.common.dict_access import DictAccess +from starkware.cairo.common.squash_dict import squash_dict -# Verifies that dict_accesses lists valid chronological accesses (and updates) -# to a mutable dictionary and outputs a squashed dict with one DictAccess instance per key -# (value before and value after) which summarizes all the changes to that key. -# -# All keys are assumed to be in the range of the range check builtin (usually 2**128). -# -# Example: -# Input: {(key1, 0, 2), (key1, 2, 7), (key2, 4, 1), (key1, 7, 5), (key2, 1, 2)} -# Output: {(key1, 0, 5), (key2, 4, 2)} -# -# Arguments: -# range_check_ptr - range check builtin pointer. -# dict_accesses - a pointer to the beginning of an array of DictAccess instances. The format of each -# entry is a triplet (key, prev_value, new_value). -# dict_accesses_end - a pointer to the end of said array. -# squashed_dict - a pointer to an output array, which will be filled with -# DictAccess instances sorted by key with the first and last value for each key. -# -# Returns: -# range_check_ptr - updated range check builtin pointer. -# squashed_dict - end pointer to squashed_dict. -func squash_dict( - range_check_ptr, dict_accesses : DictAccess*, dict_accesses_end : DictAccess*, - squashed_dict : DictAccess*) -> (range_check_ptr, squashed_dict : DictAccess*): - let ptr_diff = [fp] - %{ vm_enter_scope() %} - ptr_diff = dict_accesses_end - dict_accesses; ap++ - - if ptr_diff == 0: - # Access array is empty, nothing to check. - %{ vm_exit_scope() %} - return (range_check_ptr=range_check_ptr, squashed_dict=squashed_dict) - end - - tempvar n_accesses = ptr_diff / DictAccess.SIZE +# Creates a new dict. +func dict_new() -> (res : DictAccess*): %{ - assert ids.ptr_diff % ids.DictAccess.SIZE == 0, \ - 'Accesses array size must be divisible by DictAccess.SIZE' - # A map from key to the list of indices accessing it. - access_indices = {} - for i in range(ids.n_accesses): - key = memory[ids.dict_accesses.address_ + ids.DictAccess.SIZE * i] - access_indices.setdefault(key, []).append(i) - # Descending list of keys. - keys = sorted(access_indices.keys())[::-1] - %} - - # Call inner. - squash_dict_inner( - range_check_ptr=range_check_ptr, - dict_accesses=dict_accesses, - dict_accesses_end_minus1=dict_accesses_end - 1, - min_key=0, - remaining_accesses=n_accesses, - squashed_dict=squashed_dict) - %{ vm_exit_scope() %} - return (...) -end + if '__dict_manager' not in globals(): + from starkware.cairo.common.dict import DictManager + __dict_manager = DictManager() -# Initializes the dict manager. Should be called exactly once at the beginning of a program that -# uses dicts. -func initialize_dict_manager() -> (): - %{ - from starkware.cairo.common.dict import DictManager - assert 'dict_manager' not in globals(), \ - 'initialize_dict_manager() must be called exactly once.' - dict_manager = DictManager() + memory[ap] = __dict_manager.new_dict(segments, initial_dict) + del initial_dict %} - return () -end - -# Creates a new dict. -# Note that a dict_manager must be passed in the hints. -# Allocate one using initialize_dict_manager(). -func dict_new() -> (res): - %{ memory[ap] = dict_manager.new_dict(segments, initial_dict) %} ap += 1 return (...) end # Updates a value in a dict. prev_value must be specified. A standalone read with no write should be # performed by writing the same value. -# It is possible to get prev_value from dict_manager using the hint: -# %{ ids.val = dict_manager.get_dict(ids.dict_ptr)[ids.key] %} +# It is possible to get prev_value from __dict_manager using the hint: +# %{ ids.val = __dict_manager.get_dict(ids.dict_ptr)[ids.key] %} func dict_update(dict_ptr : DictAccess*, key : felt, prev_value : felt, new_value : felt) -> ( dict_ptr : DictAccess*): %{ # Verify dict pointer and prev value. - dict_tracker = dict_manager.get_tracker(ids.dict_ptr) + dict_tracker = __dict_manager.get_tracker(ids.dict_ptr) current_value = dict_tracker.data[ids.key] assert current_value == ids.prev_value, \ f'Wrong previous value in dict. Got {ids.prev_value}, expected {current_value}.' @@ -256,3 +37,48 @@ func dict_update(dict_ptr : DictAccess*, key : felt, prev_value : felt, new_valu dict_ptr.new_value = new_value return (dict_ptr=dict_ptr + DictAccess.SIZE) end + +# Returns a new dictionary with one DictAccess instance per key +# (value before and value after) which summarizes all the changes to that key. +# +# All keys are assumed to be in the range of the range check builtin (usually 2**128). +# +# Example: +# Input: {(key1, 0, 2), (key1, 2, 7), (key2, 4, 1), (key1, 7, 5), (key2, 1, 2)} +# Output: {(key1, 0, 5), (key2, 4, 2)} +# +# This is a wrapper of squash_dict for dictionaries created by dict_new(). +func dict_squash( + range_check_ptr, dict_accesses_start : DictAccess*, dict_accesses_end : DictAccess*) -> ( + range_check_ptr, squashed_dict_start : DictAccess*, squashed_dict_end : DictAccess*): + alloc_locals + + %{ + # Prepare arguments for dict_new. In particular, the same dictionary values should be copied + # to the new (squashed) dictionary. + vm_enter_scope({ + # Make __dict_manager accessible. + '__dict_manager': __dict_manager, + # Create a copy of the dict, in case it changes in the future. + 'initial_dict': dict(__dict_manager.get_dict(ids.dict_accesses_end)), + }) + %} + let (local squashed_dict_start : DictAccess*) = dict_new() + %{ vm_exit_scope() %} + + let (range_check_ptr, squashed_dict_end) = squash_dict( + range_check_ptr=range_check_ptr, + dict_accesses=dict_accesses_start, + dict_accesses_end=dict_accesses_end, + squashed_dict=squashed_dict_start) + + %{ + # Update the DictTracker's current_ptr to point to the end of the squashed dict. + __dict_manager.get_tracker(ids.squashed_dict_start).current_ptr = \ + ids.squashed_dict_end.address_ + %} + return ( + range_check_ptr=range_check_ptr, + squashed_dict_start=squashed_dict_start, + squashed_dict_end=squashed_dict_end) +end diff --git a/src/starkware/cairo/common/dict.py b/src/starkware/cairo/common/dict.py index 280a85b8..fcd07b8b 100644 --- a/src/starkware/cairo/common/dict.py +++ b/src/starkware/cairo/common/dict.py @@ -46,7 +46,9 @@ def get_tracker(self, dict_ptr): """ if isinstance(dict_ptr, VmConstsReference): dict_ptr = dict_ptr.address_ - dict_tracker = self.trackers[dict_ptr.segment_index] + dict_tracker = self.trackers.get(dict_ptr.segment_index) + if dict_tracker is None: + raise ValueError(f'Dictionary pointer {dict_ptr} was not created using dict_new().') assert dict_tracker.current_ptr == dict_ptr, 'Wrong dict pointer supplied. ' \ f'Got {dict_ptr}, expected {dict_tracker.current_ptr}.' return dict_tracker diff --git a/src/starkware/cairo/common/dict_access.cairo b/src/starkware/cairo/common/dict_access.cairo new file mode 100644 index 00000000..9711790f --- /dev/null +++ b/src/starkware/cairo/common/dict_access.cairo @@ -0,0 +1,6 @@ +struct DictAccess: + member key = 0 + member prev_value = 1 + member new_value = 2 + const SIZE = 3 +end diff --git a/src/starkware/cairo/common/find_element.cairo b/src/starkware/cairo/common/find_element.cairo index b61dbf13..16feee75 100644 --- a/src/starkware/cairo/common/find_element.cairo +++ b/src/starkware/cairo/common/find_element.cairo @@ -1,5 +1,4 @@ -from starkware.cairo.common.math import assert_nn_le -from starkware.cairo.common.math import assert_le +from starkware.cairo.common.math import assert_le, assert_nn_le # Finds an element in the array whose first field is key and returns a pointer # to this element. diff --git a/src/starkware/cairo/common/hash_chain.cairo b/src/starkware/cairo/common/hash_chain.cairo index 73e2ba99..879df112 100644 --- a/src/starkware/cairo/common/hash_chain.cairo +++ b/src/starkware/cairo/common/hash_chain.cairo @@ -5,7 +5,8 @@ from starkware.cairo.common.cairo_builtins import HashBuiltin # For example, for the 3-element sequence [x, y, z] the hash is: # h(3, h(x, h(y, z))) # If data_length = 0, the function does not return (takes more than field prime steps). -func hash_chain(pedersen_ptr, data_ptr : felt*) -> (pedersen_ptr, hash): +func hash_chain(pedersen_ptr : HashBuiltin*, data_ptr : felt*) -> ( + pedersen_ptr : HashBuiltin*, hash : felt): struct LoopLocals: member data_ptr : felt* = 0 member pedersen_ptr : HashBuiltin* = 1 diff --git a/src/starkware/cairo/common/math.cairo b/src/starkware/cairo/common/math.cairo index 32fe7cb9..53ae886a 100644 --- a/src/starkware/cairo/common/math.cairo +++ b/src/starkware/cairo/common/math.cairo @@ -229,7 +229,7 @@ end # 0 < div <= PRIME / (rc_bound) # bound <= rc_bound / 2. # Prover assumption: -bound <= value / div < bound. - +# # The values of div and bound are restricted to make sure there is no overflow. # q * div + r < (q + 1) * div <= rc_bound / 2 * (PRIME / rc_bound) # q * div + r >= q * div >= -rc_bound / 2 * (PRIME / rc_bound) diff --git a/src/starkware/cairo/common/merkle_multi_update.cairo b/src/starkware/cairo/common/merkle_multi_update.cairo index 4154a741..c1639713 100644 --- a/src/starkware/cairo/common/merkle_multi_update.cairo +++ b/src/starkware/cairo/common/merkle_multi_update.cairo @@ -1,5 +1,5 @@ from starkware.cairo.common.cairo_builtins import HashBuiltin -from starkware.cairo.common.dict import DictAccess +from starkware.cairo.common.dict_access import DictAccess # Helper function for merkle_multi_update(). func merkle_multi_update_inner( @@ -129,9 +129,9 @@ end # Arguments: # hash_ptr - hash builtin pointer. # update_ptr - a list of DictAccess instances sorted by key (e.g., the result of squash_dict). -# height - height of merkle tree. -# prev_root - root value before the multi update. -# new_root - root value after the multi update. +# height - the height of the merkle tree. +# prev_root - the value of the root before the update. +# new_root - the value of the root after the update. # # Hint arguments: # preimage - a dictionary from the hash value of a merkle node to the pair of children values. @@ -158,7 +158,7 @@ func merkle_multi_update( end %{ - from starkware.starkware_utils.merkle_tree.merkle_tree import build_update_tree + from starkware.python.merkle_tree import build_update_tree # Build modifications list. modifications = [] diff --git a/src/starkware/cairo/common/merkle_update.cairo b/src/starkware/cairo/common/merkle_update.cairo index 732a1e15..e04befe6 100644 --- a/src/starkware/cairo/common/merkle_update.cairo +++ b/src/starkware/cairo/common/merkle_update.cairo @@ -6,7 +6,8 @@ from starkware.cairo.common.cairo_builtins import HashBuiltin # In particular, given a secret authentication path (of the siblings of the nodes in the path from # the root to the leaf), this function computes the roots twice - once with prev_leaf and once with # new_leaf, where the verifier is guaranteed that the same authentication path is used. -func merkle_update(hash_ptr, height, prev_leaf, new_leaf, index) -> (prev_root, new_root, hash_ptr): +func merkle_update(hash_ptr : HashBuiltin*, height, prev_leaf, new_leaf, index) -> ( + prev_root, new_root, hash_ptr : HashBuiltin*): if height == 0: # Assert that index is 0. index = 0 @@ -18,6 +19,9 @@ func merkle_update(hash_ptr, height, prev_leaf, new_leaf, index) -> (prev_root, return (prev_root=prev_leaf, new_root=new_leaf, hash_ptr=hash_ptr) end + let prev_node_hash = hash_ptr + let new_node_hash = hash_ptr + HashBuiltin.SIZE + %{ memory[ap] = ids.index % 2 %} jmp update_right if [ap] != 0; ap++ @@ -25,22 +29,22 @@ func merkle_update(hash_ptr, height, prev_leaf, new_leaf, index) -> (prev_root, %{ # Hash hints. sibling = auth_path.pop() - memory[ids.hash_ptr + 0 * ids.HashBuiltin.SIZE + ids.HashBuiltin.y] = sibling - memory[ids.hash_ptr + 1 * ids.HashBuiltin.SIZE + ids.HashBuiltin.y] = sibling + ids.prev_node_hash.y = sibling + ids.new_node_hash.y = sibling %} - prev_leaf = [hash_ptr + 0 * HashBuiltin.SIZE + HashBuiltin.x] - new_leaf = [hash_ptr + 1 * HashBuiltin.SIZE + HashBuiltin.x] + prev_leaf = prev_node_hash.x + new_leaf = new_node_hash.x # Make sure the same authentication path is used. let right_sibling = ap - [right_sibling] = [hash_ptr + 0 * HashBuiltin.SIZE + HashBuiltin.y] - [right_sibling] = [hash_ptr + 1 * HashBuiltin.SIZE + HashBuiltin.y]; ap++ + [right_sibling] = prev_node_hash.y + [right_sibling] = new_node_hash.y; ap++ # Call merkle_update recursively. [ap] = hash_ptr + 2 * HashBuiltin.SIZE; ap++ # hash_ptr. [ap] = height - 1; ap++ # height. - [ap] = [hash_ptr + 0 * HashBuiltin.SIZE + HashBuiltin.result]; ap++ # prev_leaf. - [ap] = [hash_ptr + 1 * HashBuiltin.SIZE + HashBuiltin.result]; ap++ # new_leaf. + [ap] = prev_node_hash.result; ap++ # prev_leaf. + [ap] = new_node_hash.result; ap++ # new_leaf. let update_left_index = ap %{ memory[ap] = ids.index // 2 %} @@ -52,16 +56,16 @@ func merkle_update(hash_ptr, height, prev_leaf, new_leaf, index) -> (prev_root, %{ # Hash hints. sibling = auth_path.pop() - memory[ids.hash_ptr + 0 * ids.HashBuiltin.SIZE + ids.HashBuiltin.x] = sibling - memory[ids.hash_ptr + 1 * ids.HashBuiltin.SIZE + ids.HashBuiltin.x] = sibling + ids.prev_node_hash.x = sibling + ids.new_node_hash.x = sibling %} - prev_leaf = [hash_ptr + 0 * HashBuiltin.SIZE + HashBuiltin.y] - new_leaf = [hash_ptr + 1 * HashBuiltin.SIZE + HashBuiltin.y] + prev_leaf = prev_node_hash.y + new_leaf = new_node_hash.y # Make sure the same authentication path is used. let left_sibling = ap - [left_sibling] = [hash_ptr + 0 * HashBuiltin.SIZE + HashBuiltin.x] - [left_sibling] = [hash_ptr + 1 * HashBuiltin.SIZE + HashBuiltin.x]; ap++ + [left_sibling] = prev_node_hash.x + [left_sibling] = new_node_hash.x; ap++ # Compute index - 1. tempvar index_minus_one = index - 1 @@ -69,8 +73,8 @@ func merkle_update(hash_ptr, height, prev_leaf, new_leaf, index) -> (prev_root, # Call merkle_update recursively. [ap] = hash_ptr + 2 * HashBuiltin.SIZE; ap++ # hash_ptr. [ap] = height - 1; ap++ # height. - [ap] = [hash_ptr + 0 * HashBuiltin.SIZE + HashBuiltin.result]; ap++ # prev_leaf. - [ap] = [hash_ptr + 1 * HashBuiltin.SIZE + HashBuiltin.result]; ap++ # new_leaf. + [ap] = prev_node_hash.result; ap++ # prev_leaf. + [ap] = new_node_hash.result; ap++ # new_leaf. let update_right_index = ap %{ memory[ap] = ids.index // 2 %} diff --git a/src/starkware/cairo/common/small_merkle_tree.cairo b/src/starkware/cairo/common/small_merkle_tree.cairo new file mode 100644 index 00000000..2e47c082 --- /dev/null +++ b/src/starkware/cairo/common/small_merkle_tree.cairo @@ -0,0 +1,102 @@ +from starkware.cairo.common.cairo_builtins import HashBuiltin +from starkware.cairo.common.dict import DictAccess +from starkware.cairo.common.merkle_multi_update import merkle_multi_update + +# Performs an efficient update of multiple leaves in a Merkle tree, based on the given squashed +# dict, assuming the merkle tree is small enough to be loaded to the memory. +# +# This function computes the Merkle authentication paths internally and +# does not require any hint arguments, therefore it's usually easier to use. +# The input dict must be created using the higher-level dict functions (see dict.cairo), which add +# information about all the non-default leaves in the hints (not just the leaves that were changed). +# +# Usage example: +# %{ initial_dict = {1: 2, 3: 4, 5: 6} %} +# let (dict_ptr_start) = dict_new() +# let dict_ptr = dict_ptr_start +# let (dict_ptr) = dict_update(dict_ptr=dict_ptr, key=1, prev_value=2, new_value=20) +# let (range_check_ptr, squashed_dict_start, squashed_dict_end) = dict_squash( +# range_check_ptr=range_check_ptr, +# dict_accesses_start=dict_ptr_start, +# dict_accesses_end=dict_ptr) +# const HEIGHT = 3 +# let (hash_ptr, prev_root, new_root) = small_merkle_tree( +# hash_ptr, squashed_dict_start, squashed_dict_end, HEIGHT) +# +# In this example prev_root is the Merkle root of [0, 2, 0, 4, 0, 6, 0, 0], and new_root +# is the Merkle root of [0, 20, 0, 4, 0, 6, 0, 0]. +# Note that from the point of view of the verifier, all it knows is that leaf 1 changed from 2 to +# 20 -- it doesn't know anything about the other leaves (except that they haven't changed). +# +# Arguments: +# hash_ptr - hash builtin pointer. +# squashed_dict, squashed_dict_end - a list of DictAccess instances sorted by key +# (e.g., the result of dict_squash). +# height - the height of the merkle tree. +# +# Returns: +# hash_ptr - updated hash builtin pointer. +# prev_root - the value of the root before the update. +# new_root - the value of the root after the update. +# +# Assumptions: The keys in the squashed_dict are unique and sorted. +# +# Prover assumptions: +# * squashed_dict was created using the higher-level API dict_squash() (rather than squash_dict()). +# * This function can be used for (relatively) small Merkle trees whose leaves can be loaded +# to the memory. +func small_merkle_tree( + hash_ptr : HashBuiltin*, squashed_dict_start : DictAccess*, + squashed_dict_end : DictAccess*, height : felt) -> ( + hash_ptr : HashBuiltin*, prev_root : felt, new_root : felt): + %{ vm_enter_scope({'__dict_manager': __dict_manager}) %} + alloc_locals + # Allocate memory cells for the roots. + local prev_root + local new_root + %{ + # Compute the roots and the preimage dictionary. + from starkware.cairo.common.small_merkle_tree import get_preimage_dictionary + from starkware.python.math_utils import safe_div + + new_dict = __dict_manager.get_dict(ids.squashed_dict_end.address_) + + DICT_ACCESS_SIZE = ids.DictAccess.SIZE + squashed_dict_start = ids.squashed_dict_start.address_ + squashed_dict_size = ids.squashed_dict_end.address_ - squashed_dict_start + assert squashed_dict_size >= 0 and squashed_dict_size % DICT_ACCESS_SIZE == 0, \ + f'squashed_dict size must be non-negative and divisible by DictAccess.SIZE. ' \ + f'Found: {squashed_dict_size}.' + squashed_dict_length = safe_div(squashed_dict_size, DICT_ACCESS_SIZE) + + # Compute the modifications backwards: from the new values to the previous values. + modifications = [] + for i in range(squashed_dict_length): + key = memory[squashed_dict_start + i * DICT_ACCESS_SIZE + ids.DictAccess.key] + prev_value = memory[ + squashed_dict_start + i * DICT_ACCESS_SIZE + ids.DictAccess.prev_value] + new_value = memory[ + squashed_dict_start + i * DICT_ACCESS_SIZE + ids.DictAccess.new_value] + assert new_dict[key] == new_value, \ + f'Inconsistent dictionary values. Expected new value: {new_dict[key]}, ' \ + f'found: {new_value}' + modifications.append((key, prev_value)) + + ids.new_root, ids.prev_root, preimage = get_preimage_dictionary( + initial_leaves=new_dict.items(), + modifications=modifications, + tree_height=ids.height, + default_leaf=0) + %} + + # Call merkle_multi_update() to verify the two roots. + let (hash_ptr) = merkle_multi_update( + hash_ptr=hash_ptr, + update_ptr=squashed_dict_start, + n_updates=(squashed_dict_end - squashed_dict_start) / DictAccess.SIZE, + height=height, + prev_root=prev_root, + new_root=new_root) + %{ vm_exit_scope() %} + return (hash_ptr=hash_ptr, prev_root=prev_root, new_root=new_root) +end diff --git a/src/starkware/cairo/common/small_merkle_tree.py b/src/starkware/cairo/common/small_merkle_tree.py new file mode 100644 index 00000000..183976c6 --- /dev/null +++ b/src/starkware/cairo/common/small_merkle_tree.py @@ -0,0 +1,59 @@ +from typing import Collection, Dict, Tuple + +from starkware.cairo.lang.vm.crypto import pedersen_hash + + +class MerkleTree: + def __init__(self, tree_height: int, default_leaf: int): + self.tree_height = tree_height + self.default_leaf = default_leaf + # A map from node indices to their values. + self.node_values: Dict[int, int] = {} + # A map from node hash to its two children. + self.preimage: Dict[int, Tuple[int, int]] = {} + + def compute_merkle_root(self, modifications: Collection[Tuple[int, int]]): + """ + Applies the given modifications (a list of (leaf index, value)) to the tree and returns + the Merkle root. + """ + default_node = self.default_leaf + indices = set() + leaves_offset = 2 ** self.tree_height + for index, value in modifications: + node_index = leaves_offset + index + self.node_values[node_index] = value + indices.add(node_index // 2) + for _ in range(self.tree_height): + new_indices = set() + while len(indices) > 0: + index = indices.pop() + left = self.node_values.get(2 * index, default_node) + right = self.node_values.get(2 * index + 1, default_node) + self.node_values[index] = node_hash = pedersen_hash(left, right) + self.preimage[node_hash] = (left, right) + new_indices.add(index // 2) + default_node = pedersen_hash(default_node, default_node) + indices = new_indices + assert indices == {0} + return self.node_values[1] + + +def get_preimage_dictionary( + initial_leaves: Collection[Tuple[int, int]], modifications: Collection[Tuple[int, int]], + tree_height: int, default_leaf: int) -> Tuple[int, int, Dict[int, Tuple[int, int]]]: + """ + Given a set of initial leaves and a set of modifications + (both are maps from leaf index to value, where all the leaves in `modifications` appear + in `initial_leaves`). + Constructs two merkle trees, before and after the modifications. + Returns (root_before, root_after, preimage) where preimage is a dictionary from a node to + its two children. + """ + + merkle_tree = MerkleTree(tree_height=tree_height, default_leaf=default_leaf) + + root_before = merkle_tree.compute_merkle_root(modifications=initial_leaves) + root_after = merkle_tree.compute_merkle_root(modifications=modifications) + + return root_before, root_after, merkle_tree.preimage diff --git a/src/starkware/cairo/common/small_merkle_tree_test.py b/src/starkware/cairo/common/small_merkle_tree_test.py new file mode 100644 index 00000000..38384731 --- /dev/null +++ b/src/starkware/cairo/common/small_merkle_tree_test.py @@ -0,0 +1,39 @@ +import os + +from starkware.cairo.common.dict import DictManager +from starkware.cairo.common.test_utils import CairoFunctionRunner +from starkware.cairo.lang.builtins.hash.hash_builtin_runner import CELLS_PER_HASH +from starkware.cairo.lang.compiler.cairo_compile import compile_cairo_files +from starkware.native_crypto.native_crypto import pedersen_hash + +CAIRO_FILE = os.path.join(os.path.dirname(__file__), 'small_merkle_tree.cairo') +PRIME = 2**251 + 17 * 2**192 + 1 +MERKLE_HEIGHT = 2 + + +def test_cairo_merkle_multi_update(): + program = compile_cairo_files([CAIRO_FILE], prime=PRIME, debug_info=True) + runner = CairoFunctionRunner(program) + + dict_manager = DictManager() + squashed_dict_start = dict_manager.new_dict( + segments=runner.segments, initial_dict={1: 10, 2: 20, 3: 30}) + + # Change the value at 1 from 10 to 11 and at 3 from 30 to 31. + squashed_dict = [1, 10, 11, 3, 30, 31] + squashed_dict_end = runner.segments.write_arg(ptr=squashed_dict_start, arg=squashed_dict) + dict_tracker = dict_manager.get_tracker(squashed_dict_start) + dict_tracker.current_ptr = squashed_dict_end + dict_tracker.data[1] = 11 + dict_tracker.data[3] = 31 + + runner.run( + 'small_merkle_tree', runner.hash_builtin.base, squashed_dict_start, squashed_dict_end, + MERKLE_HEIGHT, hint_locals=dict(__dict_manager=dict_manager)) + hash_ptr, prev_root, new_root = runner.get_return_values(3) + N_MERKLE_TREES = 2 + N_HASHES_PER_TREE = 3 + assert hash_ptr == \ + runner.hash_builtin.base + N_MERKLE_TREES * N_HASHES_PER_TREE * CELLS_PER_HASH + assert prev_root == pedersen_hash(pedersen_hash(0, 10), pedersen_hash(20, 30)) + assert new_root == pedersen_hash(pedersen_hash(0, 11), pedersen_hash(20, 31)) diff --git a/src/starkware/cairo/common/squash_dict.cairo b/src/starkware/cairo/common/squash_dict.cairo new file mode 100644 index 00000000..7c81766b --- /dev/null +++ b/src/starkware/cairo/common/squash_dict.cairo @@ -0,0 +1,209 @@ +from starkware.cairo.common.dict_access import DictAccess + +# Inner tail-recursive function for squash_dict. +# +# Arguments: +# range_check_ptr - range check builtin pointer. +# dict_accesses - a pointer to the beginning of an array of DictAccess instances. +# dict_accesses_end_minus1 - a pointer to the end of said array, minus 1. +# min_key - minimum allowed key. Used to enforce monotonicity of keys. +# remaining_accesses - remaining number of accesses that need to be accounted for. Starts with +# the total number of entries in dict_accesses array, and slowly decreases until it reaches 0. +# squashed_dict - a pointer to an output array, which will be filled with +# DictAccess instances sorted by key with the first and last value for each key. +# +# Hints: +# keys - a descending list of the keys for which we have accesses. Destroyed in the process. +# access_indices - A map from key to a descending list of indices in the dict_accesses array that +# access this key. Destroyed in the process. +# +# Returns: +# range_check_ptr - updated range check builtin pointer. +# squashed_dict - end pointer to squashed_dict. +func squash_dict_inner( + range_check_ptr, dict_accesses : DictAccess*, dict_accesses_end_minus1 : felt*, min_key, + remaining_accesses, squashed_dict : DictAccess*) -> ( + range_check_ptr, squashed_dict : DictAccess*): + # Exit recursion when done. + if remaining_accesses == 0: + %{ assert len(keys) == 0 %} + return (range_check_ptr=range_check_ptr, squashed_dict=squashed_dict) + end + + # Locals. + struct Locals: + member key = 0 + member should_skip_loop = 1 + member first_value = 2 + const SIZE = 3 + end + let locals = cast(fp, Locals*) + let key = locals.key + let dict_diff : DictAccess* = squashed_dict + ap += Locals.SIZE + + # Guess key and check that key >= min_key. + %{ ids.locals.key = key = keys.pop() %} + [ap] = key - min_key + [ap] = [range_check_ptr]; ap++ + + # Loop to verify chronological accesses to the key. + # These values are not needed from previous iteration. + struct LoopTemps: + member index_delta_minus1 = 0 + member index_delta = 1 + member ptr_delta = 2 + member should_continue = 3 + const SIZE = 4 + end + # These values are needed from previous iteration. + struct LoopLocals: + member value = 0 + member access_ptr : DictAccess* = 1 + member range_check_ptr = 2 + const SIZE = 3 + end + + # Prepare first iteration. + %{ + current_access_indices = sorted(access_indices[key])[::-1] + current_access_index = current_access_indices.pop() + memory[ids.range_check_ptr + 1] = current_access_index + %} + # Check that first access_index >= 0. + tempvar current_access_index = [range_check_ptr + 1] + tempvar ptr_delta = current_access_index * DictAccess.SIZE + + let first_loop_locals = cast(ap, LoopLocals*) + first_loop_locals.access_ptr = dict_accesses + ptr_delta; ap++ + let first_access : DictAccess* = first_loop_locals.access_ptr + first_loop_locals.value = first_access.new_value; ap++ + first_loop_locals.range_check_ptr = range_check_ptr + 2; ap++ + + # Verify first key. + key = first_access.key + + # Write key and first value to dict_diff. + key = dict_diff.key + # Use a local variable, instead of a tempvar, to avoid increasing ap. + locals.first_value = first_access.prev_value + locals.first_value = dict_diff.prev_value + + # Skip loop non-deterministically if necessary. + %{ memory[fp + ids.Locals.should_skip_loop] = 0 if current_access_indices else 1 %} + jmp skip_loop if [fp + Locals.should_skip_loop] != 0 + + loop: + let prev_loop_locals = cast(ap - LoopLocals.SIZE, LoopLocals*) + let loop_temps = cast(ap, LoopTemps*) + let loop_locals = cast(ap + LoopTemps.SIZE, LoopLocals*) + + # Check access_index. + %{ + new_access_index = current_access_indices.pop() + ids.loop_temps.index_delta_minus1 = new_access_index - current_access_index - 1 + current_access_index = new_access_index + %} + # Check that new access_index > prev access_index. + loop_temps.index_delta_minus1 = [prev_loop_locals.range_check_ptr]; ap++ + loop_temps.index_delta = loop_temps.index_delta_minus1 + 1; ap++ + loop_temps.ptr_delta = loop_temps.index_delta * DictAccess.SIZE; ap++ + loop_locals.access_ptr = prev_loop_locals.access_ptr + loop_temps.ptr_delta; ap++ + + # Check valid transition. + let access : DictAccess* = loop_locals.access_ptr + prev_loop_locals.value = access.prev_value + loop_locals.value = access.new_value; ap++ + + # Verify key. + key = access.key + + # Next range_check_ptr. + loop_locals.range_check_ptr = prev_loop_locals.range_check_ptr + 1; ap++ + + %{ ids.loop_temps.should_continue = 1 if current_access_indices else 0 %} + jmp loop if loop_temps.should_continue != 0; ap++ + + skip_loop: + let last_loop_locals = cast(ap - LoopLocals.SIZE, LoopLocals*) + + # Check if address is out of bounds. + %{ assert len(current_access_indices) == 0 %} + [ap] = dict_accesses_end_minus1 - cast(last_loop_locals.access_ptr, felt) + [ap] = [last_loop_locals.range_check_ptr]; ap++ + tempvar range_check_diff = last_loop_locals.range_check_ptr - range_check_ptr + tempvar n_used_accesses = range_check_diff - 1 + %{ assert ids.n_used_accesses == len(access_indices[key]) %} + + # Write last value to dict_diff. + last_loop_locals.value = dict_diff.new_value + + # Call squashed_dict_inner recursively. + squash_dict_inner( + range_check_ptr=last_loop_locals.range_check_ptr + 1, + dict_accesses=dict_accesses, + dict_accesses_end_minus1=dict_accesses_end_minus1, + min_key=key + 1, + remaining_accesses=remaining_accesses - n_used_accesses, + squashed_dict=squashed_dict + DictAccess.SIZE) + return (...) +end + +# Verifies that dict_accesses lists valid chronological accesses (and updates) +# to a mutable dictionary and outputs a squashed dict with one DictAccess instance per key +# (value before and value after) which summarizes all the changes to that key. +# +# All keys are assumed to be in the range of the range check builtin (usually 2**128). +# +# Example: +# Input: {(key1, 0, 2), (key1, 2, 7), (key2, 4, 1), (key1, 7, 5), (key2, 1, 2)} +# Output: {(key1, 0, 5), (key2, 4, 2)} +# +# Arguments: +# range_check_ptr - range check builtin pointer. +# dict_accesses - a pointer to the beginning of an array of DictAccess instances. The format of each +# entry is a triplet (key, prev_value, new_value). +# dict_accesses_end - a pointer to the end of said array. +# squashed_dict - a pointer to an output array, which will be filled with +# DictAccess instances sorted by key with the first and last value for each key. +# +# Returns: +# range_check_ptr - updated range check builtin pointer. +# squashed_dict - end pointer to squashed_dict. +func squash_dict( + range_check_ptr, dict_accesses : DictAccess*, dict_accesses_end : DictAccess*, + squashed_dict : DictAccess*) -> (range_check_ptr, squashed_dict : DictAccess*): + let ptr_diff = [fp] + %{ vm_enter_scope() %} + ptr_diff = dict_accesses_end - dict_accesses; ap++ + + if ptr_diff == 0: + # Access array is empty, nothing to check. + %{ vm_exit_scope() %} + return (range_check_ptr=range_check_ptr, squashed_dict=squashed_dict) + end + + tempvar n_accesses = ptr_diff / DictAccess.SIZE + %{ + assert ids.ptr_diff % ids.DictAccess.SIZE == 0, \ + 'Accesses array size must be divisible by DictAccess.SIZE' + # A map from key to the list of indices accessing it. + access_indices = {} + for i in range(ids.n_accesses): + key = memory[ids.dict_accesses.address_ + ids.DictAccess.SIZE * i] + access_indices.setdefault(key, []).append(i) + # Descending list of keys. + keys = sorted(access_indices.keys())[::-1] + %} + + # Call inner. + squash_dict_inner( + range_check_ptr=range_check_ptr, + dict_accesses=dict_accesses, + dict_accesses_end_minus1=dict_accesses_end - 1, + min_key=0, + remaining_accesses=n_accesses, + squashed_dict=squashed_dict) + %{ vm_exit_scope() %} + return (...) +end diff --git a/src/starkware/cairo/lang/CMakeLists.txt b/src/starkware/cairo/lang/CMakeLists.txt index 442b4de6..ba6e7be8 100644 --- a/src/starkware/cairo/lang/CMakeLists.txt +++ b/src/starkware/cairo/lang/CMakeLists.txt @@ -4,11 +4,20 @@ add_subdirectory(scripts) add_subdirectory(tracer) add_subdirectory(vm) +python_lib(cairo_version_lib + PREFIX starkware/cairo/lang + + FILES + VERSION + version.py +) + python_venv(cairo_lang_venv PYTHON python3.7 LIBS cairo_common_lib cairo_compile_lib + cairo_hash_program_lib cairo_run_lib cairo_script_lib ${CAIRO_LANG_VENV_ADDITIONAL_LIBS} diff --git a/src/starkware/cairo/lang/VERSION b/src/starkware/cairo/lang/VERSION new file mode 100644 index 00000000..4e379d2b --- /dev/null +++ b/src/starkware/cairo/lang/VERSION @@ -0,0 +1 @@ +0.0.2 diff --git a/src/starkware/cairo/lang/compiler/CMakeLists.txt b/src/starkware/cairo/lang/compiler/CMakeLists.txt index b7ddb338..d0ebab82 100644 --- a/src/starkware/cairo/lang/compiler/CMakeLists.txt +++ b/src/starkware/cairo/lang/compiler/CMakeLists.txt @@ -10,6 +10,7 @@ python_lib(cairo_compile_lib ast/code_elements.py ast/expr.py ast/formatting_utils.py + ast/imports.py ast/instructions.py ast/module.py ast/node.py @@ -55,6 +56,7 @@ python_lib(cairo_compile_lib type_system_visitor.py LIBS + cairo_version_lib starkware_expression_string_lib starkware_python_utils_lib pip_marshmallow_dataclass diff --git a/src/starkware/cairo/lang/compiler/ast/code_elements.py b/src/starkware/cairo/lang/compiler/ast/code_elements.py index b3b8d4c8..ee427fec 100644 --- a/src/starkware/cairo/lang/compiler/ast/code_elements.py +++ b/src/starkware/cairo/lang/compiler/ast/code_elements.py @@ -8,8 +8,10 @@ from starkware.cairo.lang.compiler.ast.formatting_utils import ( INDENTATION, LocationField, ParticleFormattingConfig, create_particle_sublist, particles_in_lines) +from starkware.cairo.lang.compiler.ast.imports import ImportItem from starkware.cairo.lang.compiler.ast.instructions import InstructionAst from starkware.cairo.lang.compiler.ast.node import AstNode +from starkware.cairo.lang.compiler.ast.notes import NoteListField, Notes from starkware.cairo.lang.compiler.ast.rvalue import Rvalue, RvalueCall, RvalueFuncCall from starkware.cairo.lang.compiler.ast.types import TypedIdentifier from starkware.cairo.lang.compiler.error_handling import Location @@ -463,20 +465,31 @@ def get_children(self) -> Sequence[Optional[AstNode]]: @dataclasses.dataclass class CodeElementImport(CodeElement): path: ExprIdentifier - orig_identifier: ExprIdentifier - local_name: Optional[ExprIdentifier] = None + import_items: List[ImportItem] + notes: List[Notes] = NoteListField # type: ignore location: Optional[Location] = LocationField def format(self, allowed_line_length): - return f'from {self.path.format()} import {self.orig_identifier.format()}' + \ - (f' as {self.local_name.format()}' if self.local_name else '') + for note in self.notes: + note.assert_no_comments() - @property - def identifier(self): - return self.local_name if self.local_name is not None else self.orig_identifier + items = [item.format() for item in self.import_items] + prefix = f'from {self.path.format()} import ' + one_liner = prefix + ', '.join(items) + + if len(one_liner) <= allowed_line_length: + return one_liner + + particles = [f'{prefix}(', create_particle_sublist(items, ')')] + return particles_in_lines( + particles=particles, + config=ParticleFormattingConfig( + allowed_line_length=allowed_line_length, + line_indent=INDENTATION, + one_per_line=False)) def get_children(self) -> Sequence[Optional[AstNode]]: - return [self.path, self.orig_identifier, self.local_name] + return [self.path, *self.import_items] @dataclasses.dataclass @@ -509,6 +522,7 @@ def remove_redundant_empty_lines( Redundant empty lines are empty lines which are after: 1. Empty lines. 2. Labels. + or at the end of the list. """ new_code_elements = [] skip_empty_lines = True @@ -523,6 +537,10 @@ def remove_redundant_empty_lines( else: skip_empty_lines = False new_code_elements.append(code_elm) + + while len(new_code_elements) > 0 and is_empty_line(new_code_elements[-1]): + new_code_elements.pop() + return new_code_elements diff --git a/src/starkware/cairo/lang/compiler/ast/imports.py b/src/starkware/cairo/lang/compiler/ast/imports.py new file mode 100644 index 00000000..6df43631 --- /dev/null +++ b/src/starkware/cairo/lang/compiler/ast/imports.py @@ -0,0 +1,25 @@ +import dataclasses +from typing import Optional, Sequence + +from starkware.cairo.lang.compiler.ast.expr import ExprIdentifier +from starkware.cairo.lang.compiler.ast.formatting_utils import LocationField +from starkware.cairo.lang.compiler.ast.node import AstNode +from starkware.cairo.lang.compiler.error_handling import Location + + +@dataclasses.dataclass +class ImportItem(AstNode): + orig_identifier: ExprIdentifier + local_name: Optional[ExprIdentifier] + location: Optional[Location] = LocationField + + def format(self): + return f'{self.orig_identifier.format()}' + \ + (f' as {self.local_name.format()}' if self.local_name else '') + + @property + def identifier(self): + return self.local_name if self.local_name is not None else self.orig_identifier + + def get_children(self) -> Sequence[Optional[AstNode]]: + return [self.orig_identifier, self.local_name] diff --git a/src/starkware/cairo/lang/compiler/ast/rvalue.py b/src/starkware/cairo/lang/compiler/ast/rvalue.py index 6eea6e05..4b5c541a 100644 --- a/src/starkware/cairo/lang/compiler/ast/rvalue.py +++ b/src/starkware/cairo/lang/compiler/ast/rvalue.py @@ -8,6 +8,7 @@ particles_in_lines) from starkware.cairo.lang.compiler.ast.instructions import CallInstruction from starkware.cairo.lang.compiler.ast.node import AstNode +from starkware.cairo.lang.compiler.ast.notes import Notes from starkware.cairo.lang.compiler.error_handling import Location @@ -100,15 +101,22 @@ class RvalueFuncCall(RvalueCall): """ func_ident: ExprIdentifier exprs: List[ArgListItem] + notes: List[Notes] location: Optional[Location] = LocationField + def assert_no_comments(self): + for note in self.notes: + note.assert_no_comments() + def get_particles(self): + self.assert_no_comments() expr_codes = [x.format() for x in self.exprs] particles = [ f'{self.func_ident.format()}(', create_particle_sublist(expr_codes, ')')] return particles def format(self, allowed_line_length): + self.assert_no_comments() return particles_in_lines( particles=self.get_particles(), config=ParticleFormattingConfig( diff --git a/src/starkware/cairo/lang/compiler/ast_objects_test.py b/src/starkware/cairo/lang/compiler/ast_objects_test.py index 1c643584..0ab4db00 100644 --- a/src/starkware/cairo/lang/compiler/ast_objects_test.py +++ b/src/starkware/cairo/lang/compiler/ast_objects_test.py @@ -318,16 +318,22 @@ def test_parse_struct(): def test_parse_namespace(): before = """\ -namespace MyNS: +namespace MyNamespace: x = 5 y=3 end # Comment. + +namespace MyNamespace2: + end """ after = """\ -namespace MyNS: +namespace MyNamespace: x = 5 y = 3 end # Comment. + +namespace MyNamespace2: +end """ assert parse_file(before).format() == after @@ -354,7 +360,6 @@ def test_parse_func(): [ap] = 2; ap++ ap += 3 ret - end # Comment. call fib diff --git a/src/starkware/cairo/lang/compiler/cairo.ebnf b/src/starkware/cairo/lang/compiler/cairo.ebnf index 8ce33234..5ca2628e 100644 --- a/src/starkware/cairo/lang/compiler/cairo.ebnf +++ b/src/starkware/cairo/lang/compiler/cairo.ebnf @@ -82,7 +82,12 @@ function_call: identifier "(" arg_list ")" directive: "%builtins" identifier+ -> directive_builtins // Import statement. -_import: "from" identifier "import" identifier_def ("as" identifier_def)? +import_item: identifier_def ("as" identifier_def)? + +_import_body: import_item ("," import_item)* + | "(" notes (import_item notes "," notes)* import_item notes ","? notes ")" + +_import: "from" identifier "import" _import_body // Function/Namespace/Struct definition. _returns: "->" _NEWLINE* "(" identifier_list ")" @@ -117,7 +122,7 @@ code_element: instruction -> code_element_ | "alloc_locals" -> code_element_alloc_locals | -> code_element_empty_line commented_code_element: code_element [COMMENT] -code_block: (commented_code_element _NEWLINE)+ +code_block: (commented_code_element _NEWLINE)* cairo_file: code_block diff --git a/src/starkware/cairo/lang/compiler/cairo_compile.py b/src/starkware/cairo/lang/compiler/cairo_compile.py index 6c8ce517..42d5cea7 100644 --- a/src/starkware/cairo/lang/compiler/cairo_compile.py +++ b/src/starkware/cairo/lang/compiler/cairo_compile.py @@ -14,6 +14,7 @@ from starkware.cairo.lang.compiler.preprocessor.preprocessor import Preprocessor, preprocess_codes from starkware.cairo.lang.compiler.program import Program from starkware.cairo.lang.compiler.scoped_name import ScopedName +from starkware.cairo.lang.version import __version__ DEFAULT_PRIME = 2**251 + 17 * 2**192 + 1 @@ -21,8 +22,8 @@ def main(): start_time = time.time() - parser = argparse.ArgumentParser( - description='A tool to compile Cairo code.') + parser = argparse.ArgumentParser(description='A tool to compile Cairo code.') + parser.add_argument('-v', '--version', action='version', version=f'%(prog)s {__version__}') parser.add_argument('files', metavar='file', type=str, nargs='+', help='File names') parser.add_argument( '--prime', type=int, default=DEFAULT_PRIME, help='The size of the finite field.') diff --git a/src/starkware/cairo/lang/compiler/cairo_format.py b/src/starkware/cairo/lang/compiler/cairo_format.py index 7c517333..976281f5 100644 --- a/src/starkware/cairo/lang/compiler/cairo_format.py +++ b/src/starkware/cairo/lang/compiler/cairo_format.py @@ -2,11 +2,13 @@ import sys from starkware.cairo.lang.compiler.parser import parse_file +from starkware.cairo.lang.version import __version__ def main(): parser = argparse.ArgumentParser( description='A tool to automatically format Cairo code.') + parser.add_argument('-v', '--version', action='version', version=f'%(prog)s {__version__}') parser.add_argument('files', metavar='file', type=str, nargs='+', help='File names') action = parser.add_mutually_exclusive_group(required=False) action.add_argument('-i', dest='inplace', action='store_true', help='Edit files inplace.') diff --git a/src/starkware/cairo/lang/compiler/error_handling.py b/src/starkware/cairo/lang/compiler/error_handling.py index d27b7eb4..ff2ff70e 100644 --- a/src/starkware/cairo/lang/compiler/error_handling.py +++ b/src/starkware/cairo/lang/compiler/error_handling.py @@ -43,6 +43,15 @@ def with_parent_location(self, new_parent_location: 'Location', message: str): return dataclasses.replace(self, parent_location=( new_self_parent_location, self_parent_location_message)) + def topmost_location(self): + """ + Returns the location of the topmost parent. + """ + location = self + while location.parent_location is not None: + location = location.parent_location[0] + return location + @post_dump def remove_none_values(self, data, many=False): return { diff --git a/src/starkware/cairo/lang/compiler/parser.py b/src/starkware/cairo/lang/compiler/parser.py index df9f5767..ed13f4c0 100644 --- a/src/starkware/cairo/lang/compiler/parser.py +++ b/src/starkware/cairo/lang/compiler/parser.py @@ -1,4 +1,5 @@ import os +from functools import lru_cache from typing import List, Optional import lark @@ -146,6 +147,7 @@ def parse_instruction(code: str) -> InstructionAst: return parse(None, code, 'instruction', InstructionAst) +@lru_cache(None) def parse_expr(code: str) -> Expression: """ Parses the given string and returns an Expression instance. diff --git a/src/starkware/cairo/lang/compiler/parser_test.py b/src/starkware/cairo/lang/compiler/parser_test.py index 0298fb66..a0b4205e 100644 --- a/src/starkware/cairo/lang/compiler/parser_test.py +++ b/src/starkware/cairo/lang/compiler/parser_test.py @@ -5,6 +5,7 @@ from starkware.cairo.lang.compiler.ast.expr import ( ExprConst, ExprDeref, ExprIdentifier, ExprOperator, ExprPyConst, ExprReg) from starkware.cairo.lang.compiler.ast.formatting_utils import FormattingError +from starkware.cairo.lang.compiler.ast.imports import ImportItem from starkware.cairo.lang.compiler.ast.instructions import ( AddApInstruction, AssertEqInstruction, CallInstruction, CallLabelInstruction, InstructionAst, JnzInstruction, JumpInstruction, JumpToLabelInstruction, RetInstruction) @@ -330,25 +331,50 @@ def test_import(): res = parse_code_element('from a import b') assert res == CodeElementImport( path=ExprIdentifier(name='a'), - orig_identifier=ExprIdentifier(name='b'), - local_name=None) + import_items=[ImportItem( + orig_identifier=ExprIdentifier(name='b'), + local_name=None)]) assert res.format(allowed_line_length=100) == 'from a import b' # Test module names without periods, with aliasing. res = parse_code_element('from a import b as c') assert res == CodeElementImport( path=ExprIdentifier(name='a'), - orig_identifier=ExprIdentifier(name='b'), - local_name=ExprIdentifier(name='c')) + import_items=[ImportItem( + orig_identifier=ExprIdentifier(name='b'), + local_name=ExprIdentifier(name='c'))]) assert res.format(allowed_line_length=100) == 'from a import b as c' # Test module names with periods. res = parse_code_element('from a.b12.c4 import lib345') assert res == CodeElementImport( path=ExprIdentifier(name='a.b12.c4'), - orig_identifier=ExprIdentifier(name='lib345')) + import_items=[ImportItem( + orig_identifier=ExprIdentifier(name='lib345'), + local_name=None)]) assert res.format(allowed_line_length=100) == 'from a.b12.c4 import lib345' + # Test multiple imports. + res = parse_code_element('from lib import a,b as b2, c') + + assert res == CodeElementImport( + path=ExprIdentifier(name='lib'), + import_items=[ + ImportItem( + orig_identifier=ExprIdentifier(name='a'), + local_name=None), + ImportItem( + orig_identifier=ExprIdentifier(name='b'), + local_name=ExprIdentifier(name='b2')), + ImportItem( + orig_identifier=ExprIdentifier(name='c'), + local_name=None), + ]) + assert res.format(allowed_line_length=100) == 'from lib import a, b as b2, c' + assert res.format(allowed_line_length=20) == 'from lib import (\n a, b as b2, c)' + + assert res == parse_code_element('from lib import (\n a, b as b2, c)') + # Test module with bad identifier (with periods). with pytest.raises(ParserError): parse_expr('from a.b import c.d') diff --git a/src/starkware/cairo/lang/compiler/parser_transformer.py b/src/starkware/cairo/lang/compiler/parser_transformer.py index f3552565..9edc195f 100644 --- a/src/starkware/cairo/lang/compiler/parser_transformer.py +++ b/src/starkware/cairo/lang/compiler/parser_transformer.py @@ -16,6 +16,7 @@ from starkware.cairo.lang.compiler.ast.expr import ( ArgList, EllipsisSymbol, ExprAddressOf, ExprAssignment, ExprCast, ExprConst, ExprDeref, ExprIdentifier, ExprNeg, ExprOperator, ExprParentheses, ExprPyConst, ExprReg, ExprTuple) +from starkware.cairo.lang.compiler.ast.imports import ImportItem from starkware.cairo.lang.compiler.ast.instructions import ( AddApInstruction, AssertEqInstruction, CallInstruction, CallLabelInstruction, InstructionAst, JnzInstruction, JumpInstruction, JumpToLabelInstruction, RetInstruction) @@ -274,10 +275,9 @@ def rvalue_call_instruction(self, value): @v_args(meta=True) def function_call(self, value, meta): func_ident, arg_list = value - for note in arg_list.notes: - note.assert_no_comments() return RvalueFuncCall( - func_ident=func_ident, exprs=arg_list.args, location=self.meta2loc(meta)) + func_ident=func_ident, exprs=arg_list.args, notes=arg_list.notes, + location=self.meta2loc(meta)) # CairoFile. @@ -442,25 +442,43 @@ def directive_builtins(self, value, meta): return BuiltinsDirective(builtins=builtins, location=self.meta2loc(meta)) @v_args(meta=True) - def code_element_import(self, value, meta): - if len(value) == 2: - # Statement of the form: - # from import . - path, identifier = value + def import_item(self, value, meta): + if len(value) == 1: + # Element of the form: . + identifier, = value local_name = None - elif len(value) == 3: - # Statement of the form: - # from import as . - path, identifier, local_name = value + elif len(value) == 2: + # Element of the form: as . + identifier, local_name = value else: raise NotImplementedError(f'Unexpected argument: value={value}') - return CodeElementImport( - path=path, + return ImportItem( orig_identifier=identifier, local_name=local_name, location=self.meta2loc(meta)) + @v_args(meta=True) + def code_element_import(self, value, meta): + path = value[0] + if isinstance(value[1], ImportItem): + # Single line. + import_items = value[1:] + notes = [] + else: + # Multiline. + assert len(value) % 3 == 2, f'Unexpected value {value}.' + import_items = value[2::3] + # Join the notes before and after the comma. + notes = [value[1]] + [value[i] + value[i + 1] for i in range(3, len(value) - 1, 3)] + + return CodeElementImport( + path=path, + import_items=import_items, + notes=notes, + location=self.meta2loc(meta), + ) + @v_args(meta=True) def code_element_alloc_locals(self, value, meta): return CodeElementAllocLocals(location=self.meta2loc(meta)) diff --git a/src/starkware/cairo/lang/compiler/preprocessor/flow.py b/src/starkware/cairo/lang/compiler/preprocessor/flow.py index efe19a44..41ba4bf0 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/flow.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/flow.py @@ -109,8 +109,18 @@ def converge( reference = reference_manager.get_ref(ref_id) other_ref = reference_manager.get_ref(other_ref_id) try: - if simplifier.visit(reference.eval(new_ap_tracking)) == \ - simplifier.visit(other_ref.eval(new_ap_tracking)): + ref_expr = reference.eval(self.ap_tracking) + if simplifier.visit(ref_expr) == \ + simplifier.visit(other_ref.eval(other.ap_tracking)): + # Same expression. + if self.ap_tracking != new_ap_tracking: + # Create a new reference on the new ap tracking. + new_reference = Reference( + pc=reference.pc, + value=ref_expr, + ap_tracking_data=new_ap_tracking, + ) + ref_id = reference_manager.get_id(new_reference) reference_ids[name] = ref_id except FlowTrackingError: pass diff --git a/src/starkware/cairo/lang/compiler/preprocessor/flow_test.py b/src/starkware/cairo/lang/compiler/preprocessor/flow_test.py index bc56e94f..165f5712 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/flow_test.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/flow_test.py @@ -95,8 +95,8 @@ def test_flow_tracking_labels_diverge(changes): ]) def test_flow_tracking_converge_references(refs): flow_tracking = FlowTracking() - flow_tracking.add_flow_to_label(ScopedName.from_string('a'), 7) - flow_tracking.add_flow_to_label(ScopedName.from_string('b'), 5) + flow_tracking.add_flow_to_label(ScopedName.from_string('a'), RegChangeUnknown()) + flow_tracking.add_flow_to_label(ScopedName.from_string('b'), RegChangeUnknown()) # Label a. flow_tracking.revoke() diff --git a/src/starkware/cairo/lang/compiler/preprocessor/identifier_collector.py b/src/starkware/cairo/lang/compiler/preprocessor/identifier_collector.py index dad08cdb..bf429f85 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/identifier_collector.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/identifier_collector.py @@ -181,18 +181,19 @@ def visit_CodeBlock(self, code_block: CodeBlock): self.visit(elm.code_elm) def visit_CodeElementImport(self, elm: CodeElementImport): - alias_dst = ScopedName.from_string(elm.path.name) + elm.orig_identifier.name - local_identifier = elm.identifier + for import_item in elm.import_items: + alias_dst = ScopedName.from_string(elm.path.name) + import_item.orig_identifier.name + local_identifier = import_item.identifier - # Ensure destination is a valid identifier. - if self.identifiers.get_by_full_name(alias_dst) is None: - raise PreprocessorError( - f"Scope '{elm.path.name}' does not include identifier " - f"'{elm.orig_identifier.name}'.", - location=elm.orig_identifier.location) - - # Add alias to identifiers. - self.add_identifier( - name=self.current_scope + local_identifier.name, - identifier_definition=AliasDefinition(destination=alias_dst), - location=elm.identifier.location) + # Ensure destination is a valid identifier. + if self.identifiers.get_by_full_name(alias_dst) is None: + raise PreprocessorError( + f"Scope '{elm.path.name}' does not include identifier " + f"'{import_item.orig_identifier.name}'.", + location=import_item.orig_identifier.location) + + # Add alias to identifiers. + self.add_identifier( + name=self.current_scope + local_identifier.name, + identifier_definition=AliasDefinition(destination=alias_dst), + location=import_item.identifier.location) diff --git a/src/starkware/cairo/lang/compiler/preprocessor/preprocessor.py b/src/starkware/cairo/lang/compiler/preprocessor/preprocessor.py index e326f977..2b35ce38 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/preprocessor.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/preprocessor.py @@ -938,6 +938,7 @@ def visit_AssertEqInstruction(self, instruction: AssertEqInstruction): location=instruction.location) def visit_JumpInstruction(self, instruction: JumpInstruction): + self.revoke_function_ap_change() return InstructionFlows(), JumpInstruction( val=self.simplify_expr_as_felt(instruction.val), relative=instruction.relative, @@ -969,6 +970,9 @@ def visit_JumpToLabelInstruction(self, instruction: JumpToLabelInstruction): condition=self.simplify_expr_as_felt(instruction.condition), location=instruction.location) + if label_pc <= self.current_pc: + self.revoke_function_ap_change() + flow_next = None if instruction.condition is None else RegChangeKnown(0) if label_full_name is None: raise PreprocessorError( @@ -978,11 +982,20 @@ def visit_JumpToLabelInstruction(self, instruction: JumpToLabelInstruction): return InstructionFlows(next_inst=flow_next, jumps=jumps), res_instruction def visit_JnzInstruction(self, instruction: JnzInstruction): + self.revoke_function_ap_change() return InstructionFlows(next_inst=RegChangeKnown(0)), JnzInstruction( jump_offset=self.simplify_expr_as_felt(instruction.jump_offset), condition=self.simplify_expr_as_felt(instruction.condition), location=instruction.location) + def revoke_function_ap_change(self): + """ + Revokes the total_ap_change tracking of the function (which implies that calling it will + revoke the ap tracking). + """ + if self.current_scope in self.function_metadata: + self.function_metadata[self.current_scope].total_ap_change = RegChangeUnknown() + def visit_CallInstruction(self, instruction: CallInstruction): return InstructionFlows(next_inst=RegChangeUnknown()), CallInstruction( val=self.simplify_expr_as_felt(instruction.val), diff --git a/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test.py b/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test.py index a9494532..6d6cb195 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test.py @@ -701,6 +701,33 @@ def test_func_by_value_return(): """ +@pytest.mark.parametrize('jmp_code', [ + 'jmp loop if [ap] != 0', + 'jmp rel 3', + 'jmp abs 3', + 'jmp rel [ap + 3] if [ap] != 0', +]) +def test_function_flow_revoke(jmp_code): + verify_exception(f""" +func foo(): + loop: + {jmp_code} + ret +end + +func bar(): + tempvar x = 0 + foo() + assert x = 0 + ret +end +""", """ +file:?:?: Reference 'x' was revoked. + assert x = 0 + ^ +""") + + def test_scope_label(): code = """\ x: @@ -736,13 +763,18 @@ def test_scope_label(): def test_import(): files = { '.': """ -from a import f as g +from a import f as g, h as h2 call g +call h2 """, 'a': """ func f(): jmp f end + +func h(): + jmp h +end """ } program = preprocess_codes( @@ -750,7 +782,9 @@ def test_import(): assert program.format() == """\ jmp rel 0 -call rel -2 +jmp rel 0 +call rel -4 +call rel -4 """ diff --git a/src/starkware/cairo/lang/ide/vscode-cairo/package.json b/src/starkware/cairo/lang/ide/vscode-cairo/package.json index 01a4a33b..9c1f4835 100644 --- a/src/starkware/cairo/lang/ide/vscode-cairo/package.json +++ b/src/starkware/cairo/lang/ide/vscode-cairo/package.json @@ -2,7 +2,7 @@ "name": "cairo", "displayName": "Cairo", "description": "Support Cairo syntax", - "version": "0.0.1", + "version": "0.0.2", "engines": { "vscode": "^1.30.0" }, diff --git a/src/starkware/cairo/lang/scripts/CMakeLists.txt b/src/starkware/cairo/lang/scripts/CMakeLists.txt index 5bbff8ba..fb51f83d 100644 --- a/src/starkware/cairo/lang/scripts/CMakeLists.txt +++ b/src/starkware/cairo/lang/scripts/CMakeLists.txt @@ -1,7 +1,8 @@ python_lib(cairo_script_lib PREFIX starkware/cairo/lang/scripts FILES - cairo-format cairo-compile + cairo-format + cairo-hash-program cairo-run ) diff --git a/src/starkware/cairo/lang/scripts/cairo-hash-program b/src/starkware/cairo/lang/scripts/cairo-hash-program new file mode 100755 index 00000000..beafc0c3 --- /dev/null +++ b/src/starkware/cairo/lang/scripts/cairo-hash-program @@ -0,0 +1,10 @@ +#!/usr/bin/env python3 + +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../../..')) +from starkware.cairo.bootloader.hash_program import main # noqa + +if __name__ == '__main__': + sys.exit(main()) diff --git a/src/starkware/cairo/lang/setup.py b/src/starkware/cairo/lang/setup.py index 20796d7d..b64aa9c9 100644 --- a/src/starkware/cairo/lang/setup.py +++ b/src/starkware/cairo/lang/setup.py @@ -4,18 +4,23 @@ DIR = os.path.abspath(os.path.dirname(__file__)) requirements = open(os.path.join(DIR, 'requirements.txt')).read().splitlines() +version = open(os.path.join(DIR, 'starkware/cairo/lang/VERSION')).read().strip() +long_description = open('README.md', 'r', encoding='utf-8').read() setuptools.setup( - name='cairo-starkware', - version='0.0.1', + name='cairo-lang', + version=version, author='Starkware', author_email='info@starkware.co', description='Compiler and runner for the Cairo language', + install_requires=requirements, + long_description=long_description, packages=setuptools.find_packages(), python_requires='>=3.6', setup_requires=['wheel'], - install_requires=requirements, + url='https://cairo-lang.org/', package_data={ + 'starkware.cairo.lang': ['VERSION'], 'starkware.cairo.lang.compiler': ['cairo.ebnf'], 'starkware.cairo.lang.tracer': ['*.html', '*.css', '*.js', '*.png'], 'starkware.cairo.common': ['*.cairo'], @@ -25,5 +30,6 @@ 'starkware/cairo/lang/scripts/cairo-format', 'starkware/cairo/lang/scripts/cairo-compile', 'starkware/cairo/lang/scripts/cairo-run', + 'starkware/cairo/lang/scripts/cairo-hash-program', ] ) diff --git a/src/starkware/cairo/lang/version.py b/src/starkware/cairo/lang/version.py new file mode 100644 index 00000000..d639a71d --- /dev/null +++ b/src/starkware/cairo/lang/version.py @@ -0,0 +1,3 @@ +import os + +__version__ = open(os.path.join(os.path.dirname(__file__), 'VERSION')).read().strip() diff --git a/src/starkware/cairo/lang/vm/CMakeLists.txt b/src/starkware/cairo/lang/vm/CMakeLists.txt index 5badd124..15b1316d 100644 --- a/src/starkware/cairo/lang/vm/CMakeLists.txt +++ b/src/starkware/cairo/lang/vm/CMakeLists.txt @@ -58,6 +58,7 @@ python_lib(cairo_run_lib cairo_instances_lib cairo_run_builtins_lib cairo_tracer_lib + cairo_version_lib cairo_vm_lib starkware_python_utils_lib ) diff --git a/src/starkware/cairo/lang/vm/cairo_run.py b/src/starkware/cairo/lang/vm/cairo_run.py index efbf04fc..65ecf574 100644 --- a/src/starkware/cairo/lang/vm/cairo_run.py +++ b/src/starkware/cairo/lang/vm/cairo_run.py @@ -12,6 +12,7 @@ from starkware.cairo.lang.compiler.debug_info import DebugInfo from starkware.cairo.lang.compiler.program import Program, ProgramBase from starkware.cairo.lang.instances import LAYOUTS +from starkware.cairo.lang.version import __version__ from starkware.cairo.lang.vm.air_public_input import PublicInput, PublicMemoryEntry from starkware.cairo.lang.vm.cairo_pie import CairoPie from starkware.cairo.lang.vm.cairo_runner import CairoRunner @@ -28,6 +29,7 @@ def main(): parser = argparse.ArgumentParser( description='A tool to run Cairo programs.') + parser.add_argument('-v', '--version', action='version', version=f'%(prog)s {__version__}') parser.add_argument( '--program', type=argparse.FileType('r'), help='The name of the program json file.') parser.add_argument( diff --git a/src/starkware/cairo/lang/vm/memory_segments.py b/src/starkware/cairo/lang/vm/memory_segments.py index ee9a81a0..ffeaef46 100644 --- a/src/starkware/cairo/lang/vm/memory_segments.py +++ b/src/starkware/cairo/lang/vm/memory_segments.py @@ -116,7 +116,7 @@ def gen_arg(self, arg, apply_modulo_to_args=True): def write_arg(self, ptr, arg, apply_modulo_to_args=True): assert isinstance(arg, Iterable) data = [self.gen_arg(arg=x, apply_modulo_to_args=apply_modulo_to_args) for x in arg] - self.load_data(ptr, data) + return self.load_data(ptr, data) def get_segment_used_size(segment_index: int, memory: MemoryDict) -> int: diff --git a/src/starkware/python/CMakeLists.txt b/src/starkware/python/CMakeLists.txt index 6900ff0d..575aac43 100644 --- a/src/starkware/python/CMakeLists.txt +++ b/src/starkware/python/CMakeLists.txt @@ -16,6 +16,12 @@ python_lib(starkware_expression_string_lib expression_string.py ) +python_lib(starkware_merkle_tree_lib + PREFIX starkware/python + FILES + merkle_tree.py +) + python_lib(starkware_python_test_utils_lib PREFIX starkware/python FILES diff --git a/src/starkware/python/merkle_tree.py b/src/starkware/python/merkle_tree.py new file mode 100644 index 00000000..fc575747 --- /dev/null +++ b/src/starkware/python/merkle_tree.py @@ -0,0 +1,26 @@ +from typing import Any, Collection, Tuple + + +def build_update_tree(height, modifications: Collection[Tuple[int, Any]]): + """ + Constructs a tree from leaf updates. This is not a full binary tree. It is just the subtree + induced by the modification leaves. + Returns a tree. A tree is either: + * None + * a pair of trees + * A leaf, which is a pair (leaf_index, modification) + """ + # Bottom layer. This will prefer the last modification to an index. + if len(modifications) == 0: + return None + + # A layer is a dictionary from index in current merkle layer (0 to 2**layer_height) to a tree. + # A tree is either None, a leaf, or a pair of trees. + layer = dict(modifications) + + for _ in range(height): + parents = set(index // 2 for index in layer.keys()) + layer = {index: (layer.get(index * 2), layer.get(index * 2 + 1)) for index in parents} + assert len(layer) == 1 + # We reached layer_height=0, the top layer with only the root (with index 0). + return layer[0] diff --git a/src/starkware/python/python_dependencies.py b/src/starkware/python/python_dependencies.py index 31c02ff3..69dd3256 100644 --- a/src/starkware/python/python_dependencies.py +++ b/src/starkware/python/python_dependencies.py @@ -2,7 +2,7 @@ import sys ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')) -assert os.path.basename(ROOT_DIR) in ['src', 'site-packages'] or \ +assert os.path.basename(ROOT_DIR) in ['src', 'site-packages', 'dist-packages'] or \ os.path.basename(ROOT_DIR).endswith('-site') diff --git a/src/starkware/python/utils.py b/src/starkware/python/utils.py index 7b674f84..3fb3c1e8 100644 --- a/src/starkware/python/utils.py +++ b/src/starkware/python/utils.py @@ -77,6 +77,8 @@ def indent(code, indentation): Indent code by 'indentation' spaces. For example, indent('hello\nworld\n', 2) -> ' hello\n world\n'. """ + if len(code) == 0: + return code if isinstance(indentation, int): indentation = ' ' * indentation elif not isinstance(indentation, str):