Skip to content

Commit

Permalink
Merge pull request #1245 from CityOfZion/CU-86drpncqy
Browse files Browse the repository at this point in the history
Add support to match-case (switch) with dict without conditionals
  • Loading branch information
meevee98 authored May 2, 2024
2 parents c71affa + 3adb0ff commit a5627c2
Show file tree
Hide file tree
Showing 6 changed files with 221 additions and 15 deletions.
2 changes: 1 addition & 1 deletion boa3/internal/analyser/moduleanalyser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1465,7 +1465,7 @@ def visit_For(self, for_node: ast.For):

def visit_Match(self, match_node: ast.Match):
for case in match_node.cases:
if not (isinstance(case.pattern, (ast.MatchValue, ast.MatchSingleton)) or
if not (isinstance(case.pattern, (ast.MatchValue, ast.MatchSingleton, ast.MatchMapping)) or
isinstance(case.pattern, ast.MatchAs) and case.pattern.pattern is None and case.guard is None
):
self._log_error(
Expand Down
95 changes: 95 additions & 0 deletions boa3/internal/compiler/codegenerator/codegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2453,3 +2453,98 @@ def generate_implicit_init_user_class(self, init_method: Method):
self.remove_stack_top_item()

self.convert_end_method()

def compare_dicts_match_case(self):
# stack: MATCH var (bottom), CASE pattern (top)
self.swap_reverse_stack_items(2)
self.duplicate_stack_top_item()
# pattern_keys = pattern.keys()
self.convert_builtin_method_call(Builtin.DictKeys)
# stack: var, pattern, pattern_keys

self.duplicate_stack_top_item()
# list_length = len(pattern_keys)
self.convert_builtin_method_call(Builtin.Len)
# index = 0
self.convert_literal(0)
# is_ok = True
self.convert_literal(True)
# stack: var, pattern, pattern_keys, list_length, index, is_ok

# while index < list_length and is_ok:
begin_map_comparison = self.convert_begin_while()

self.duplicate_stack_item(4)
self.duplicate_stack_item(3)
# stack: var, pattern, pattern_keys, list_length, index, is_ok, pattern_keys, index
self.convert_get_item(index_inserted_internally=True, index_is_positive=True, test_is_negative_index=False)
# current_key = pattern_keys[index]
# stack: var, pattern, pattern_keys, list_length, index, is_ok, current_key

self.duplicate_stack_top_item()
self.duplicate_stack_item(8)
self.swap_reverse_stack_items(2)
# stack: var, pattern, pattern_keys, list_length, index, is_ok, current_key, var, current_key

self.swap_reverse_stack_items(2)
# key_in_var = current_key in var
self.convert_operation(BinaryOp.In, is_internal=True)
# stack: var, pattern, pattern_keys, list_length, index, is_ok, current_key, key_in_var

# if key_in_var:
if_key_in_var = self.convert_begin_if()
self.duplicate_stack_top_item()
self.duplicate_stack_item(8)
self.swap_reverse_stack_items(2)
# stack: var, pattern, pattern_keys, list_length, index, is_ok, current_key, var, current_key

# var_current_value == var[current_key]
self.convert_get_item(index_inserted_internally=True, index_is_positive=True, test_is_negative_index=False)
# stack: var, pattern, pattern_keys, list_length, index, is_ok, current_key, var_current_value
self.swap_reverse_stack_items(2)
self.duplicate_stack_item(7)
self.swap_reverse_stack_items(2)
# stack: var, pattern, pattern_keys, list_length, index, is_ok, var_current_value, pattern, current_key
# pattern_current_value == pattern[current_key]
self.convert_get_item(index_inserted_internally=True, index_is_positive=True, test_is_negative_index=False)
# stack: var, pattern, pattern_keys, list_length, index, is_ok, var_current_value, pattern_current_value
# not_same_values = var_current_value == pattern_current_value
self.convert_operation(BinaryOp.NotEq, is_internal=True)
# stack: var, pattern, pattern_keys, list_length, index, is_ok, not_same_values

# if not_same_values:
values_are_not_equal = self.convert_begin_if()
self.remove_stack_top_item()
# return False
self.convert_literal(False)

# else:
self.convert_begin_else(if_key_in_var)
self.remove_stack_top_item()
self.remove_stack_top_item()
# return False
self.convert_literal(False)

self.convert_end_if(values_are_not_equal, is_internal=True)

# index += 1
self.swap_reverse_stack_items(2)
self.__insert1(OpcodeInfo.INC)
self.swap_reverse_stack_items(2)

# stack: var, pattern, pattern_keys, list_length, index, is_ok
# checking if `index < list_length and is_ok` is still true
condition_address = self.bytecode_size
self.duplicate_stack_item(2)
self.duplicate_stack_item(4)
self.convert_operation(BinaryOp.Lt, is_internal=True)
self.duplicate_stack_item(2)
self.convert_operation(BinaryOp.And, is_internal=True)
self.convert_end_while(begin_map_comparison, condition_address, is_internal=True)

# when finishing while loop, clean the stack and only `is_ok` will remain
self.remove_stack_item(2)
self.remove_stack_item(2)
self.remove_stack_item(2)
self.remove_stack_item(2)
self.remove_stack_item(2)
27 changes: 24 additions & 3 deletions boa3/internal/compiler/codegenerator/codegeneratorvisitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,16 +707,37 @@ def visit_Match(self, match_node: ast.Match) -> GeneratorData:
if isinstance(case.pattern, ast.MatchSingleton):
self.generator.convert_literal(case.pattern.value)
pattern_type = self.get_type(case.pattern.value)
elif isinstance(case.pattern, ast.MatchMapping):
pattern_type = Type.dict

self.generator.convert_new_map(pattern_type)

for map_index in range(len(case.pattern.keys)):
start_address = VMCodeMapping.instance().bytecode_size
self.generator.duplicate_stack_top_item()
self.visit_to_generate(case.pattern.keys[map_index])
self.visit_to_generate(case.pattern.patterns[map_index].value)
self.generator.convert_set_item(start_address)
else:
pattern = self.visit_to_generate(case.pattern.value)
pattern_type = pattern.type

self.generator.duplicate_stack_item(2)
pattern_type.generate_is_instance_type_check(self.generator)

self.generator.swap_reverse_stack_items(3)
self.generator.convert_operation(BinaryOp.NumEq)
self.generator.convert_operation(BinaryOp.And)
is_same_type = self.generator.convert_begin_if()

self.generator.swap_reverse_stack_items(2)
if isinstance(case.pattern, ast.MatchMapping):
self.generator.compare_dicts_match_case()
else:
self.generator.convert_operation(BinaryOp.NumEq)

is_not_same_type = self.generator.convert_begin_else(is_same_type)
self.generator.remove_stack_top_item()
self.generator.remove_stack_top_item()
self.generator.convert_literal(False)
self.generator.convert_end_if(is_not_same_type)

case_addresses.append(self.generator.convert_begin_if())
for stmt in case.body:
Expand Down
2 changes: 2 additions & 0 deletions boa3_test/test_sc/match_case_test/AnyTypeMatchCase.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ def main(x: Any) -> str:
return "one"
case "2":
return "2 string"
case {}:
return "dictionary"
case _:
# this is the default case, when all others are False
return "other"
20 changes: 20 additions & 0 deletions boa3_test/test_sc/match_case_test/DictTypeMatchCase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from boa3.builtin.compile_time import public


@public
def main(dict_: dict) -> str:
match dict_:
case {
'ccccc': None,
'ab': 'cd',
'12': '34',
'xy': 'zy',
'00': '55',
}:
return "big dictionary"
case {'key': 'value'}:
return "key and value"
case {}:
return "empty dict"
case _:
return "default return"
90 changes: 79 additions & 11 deletions boa3_test/tests/compiler_tests/test_match_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,39 @@ class TestMatchCase(boatestcase.BoaTestCase):
async def test_any_type_match_case(self):
await self.set_up_contract('AnyTypeMatchCase.py')

result, _ = await self.call('main', [True], return_type=str)
self.assertEqual("True", result)

result, _ = await self.call('main', [1], return_type=str)
self.assertEqual("one", result)

result, _ = await self.call('main', ["2"], return_type=str)
self.assertEqual("2 string", result)

result, _ = await self.call('main', ['other value'], return_type=str)
self.assertEqual("other", result)
def match_case(x) -> str:
match x:
case True:
return "True"
case 1:
return "one"
case "2":
return "2 string"
case {}:
return "dictionary"
case _:
# this is the default case, when all others are False
return "other"

arg = True
result, _ = await self.call('main', [arg], return_type=str)
self.assertEqual(match_case(arg), result)

arg = 1
result, _ = await self.call('main', [arg], return_type=str)
self.assertEqual(match_case(arg), result)

arg = "2"
result, _ = await self.call('main', [arg], return_type=str)
self.assertEqual(match_case(arg), result)

arg = {'any': 'dict'}
result, _ = await self.call('main', [arg], return_type=str)
self.assertEqual(match_case(arg), result)

arg = 'other value'
result, _ = await self.call('main', [arg], return_type=str)
self.assertEqual(match_case(arg), result)

async def test_bool_type_match_case(self):
await self.set_up_contract('BoolTypeMatchCase.py')
Expand Down Expand Up @@ -64,6 +86,52 @@ async def test_str_type_match_case(self):
result, _ = await self.call('main', ['unit test'], return_type=str)
self.assertEqual("other", result)

async def test_dict_type_match_case(self):
await self.set_up_contract('DictTypeMatchCase.py')

def match_case(dict_: dict) -> str:
match dict_:
case {
'ccccc': None,
'ab': 'cd',
'12': '34',
'xy': 'zy',
'00': '55',
}:
return "big dictionary"
case {'key': 'value'}:
return "key and value"
case {}:
return "empty dict"
case _:
return "default return"

arg = {}
result, _ = await self.call('main', [arg], return_type=str)
self.assertEqual(match_case(arg), result)

arg = {'key': 'value'}
result, _ = await self.call('main', [arg], return_type=str)
self.assertEqual(match_case(arg), result)

arg = {'key': 'value', 'unit': 'test'}
result, _ = await self.call('main', [arg], return_type=str)
self.assertEqual(match_case(arg), result)

arg = {'another': 'pair'}
result, _ = await self.call('main', [arg], return_type=str)
self.assertEqual(match_case(arg), result)

arg = {
'ccccc': None,
'ab': 'cd',
'12': '34',
'xy': 'zy',
'00': '55',
}
result, _ = await self.call('main', [arg], return_type=str)
self.assertEqual(match_case(arg), result)

async def test_outer_var_inside_match(self):
await self.set_up_contract('OuterVariableInsideMatch.py')

Expand Down

0 comments on commit a5627c2

Please sign in to comment.