From 56d40da1d5e8f6b0dc59bf3da54848e6e62d65ff Mon Sep 17 00:00:00 2001 From: Raymond Weitekamp <19483938+rawwerks@users.noreply.github.com> Date: Fri, 9 Feb 2024 09:52:39 -0800 Subject: [PATCH 1/3] allow `@` expressions --- src/RestrictedPython/transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/RestrictedPython/transformer.py b/src/RestrictedPython/transformer.py index c6e2e78..7b60b14 100644 --- a/src/RestrictedPython/transformer.py +++ b/src/RestrictedPython/transformer.py @@ -768,8 +768,8 @@ def visit_BitAnd(self, node): return self.node_contents_visit(node) def visit_MatMult(self, node): - """Matrix multiplication (`@`) is currently not allowed.""" - self.not_allowed(node) + """Allow `@` expressions.""" + return self.node_contents_visit(node) def visit_BoolOp(self, node): """Allow bool operator without restrictions.""" From 959ac23ee015ba2c4aad0f3e98d1cc502689e350 Mon Sep 17 00:00:00 2001 From: Raymond Weitekamp <19483938+rawwerks@users.noreply.github.com> Date: Tue, 13 Feb 2024 15:38:15 -0800 Subject: [PATCH 2/3] test that matmult works --- .../operators/test_arithmetic_operators.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/tests/transformer/operators/test_arithmetic_operators.py b/tests/transformer/operators/test_arithmetic_operators.py index 209eb50..78b7fdc 100644 --- a/tests/transformer/operators/test_arithmetic_operators.py +++ b/tests/transformer/operators/test_arithmetic_operators.py @@ -33,8 +33,16 @@ def test_FloorDiv(): def test_MatMult(): - result = compile_restricted_eval('(8, 3, 5) @ (2, 7, 1)') - assert result.errors == ( - 'Line None: MatMult statements are not allowed.', - ) - assert result.code is None + source_code = """ +class Vector: + def __init__(self, values): + self.values = values + + def __matmul__(self, other): + return sum(x * y for x, y in zip(self.values, other.values)) + +result = Vector((8, 3, 5)) @ Vector((2, 7, 1)) +""" + # Assuming restricted_eval can execute the source_code and return the value of 'result' + assert restricted_eval(source_code) == 42 + From 39da405d5474a5fd9a91d530b97677c75a055d45 Mon Sep 17 00:00:00 2001 From: Raymond Weitekamp <19483938+rawwerks@users.noreply.github.com> Date: Fri, 16 Feb 2024 15:19:29 -0800 Subject: [PATCH 3/3] try to pass the tests: update formatting and allow matmult --- src/RestrictedPython/transformer.py | 1 + tests/transformer/operators/test_arithmetic_operators.py | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/RestrictedPython/transformer.py b/src/RestrictedPython/transformer.py index 7b60b14..0a88779 100644 --- a/src/RestrictedPython/transformer.py +++ b/src/RestrictedPython/transformer.py @@ -61,6 +61,7 @@ '__ne__', '__gt__', '__ge__', + '__matmult__' ]) diff --git a/tests/transformer/operators/test_arithmetic_operators.py b/tests/transformer/operators/test_arithmetic_operators.py index 78b7fdc..893c61b 100644 --- a/tests/transformer/operators/test_arithmetic_operators.py +++ b/tests/transformer/operators/test_arithmetic_operators.py @@ -32,8 +32,7 @@ def test_FloorDiv(): assert restricted_eval('7 // 2') == 3 -def test_MatMult(): - source_code = """ +test_matmult = """\ class Vector: def __init__(self, values): self.values = values @@ -43,6 +42,7 @@ def __matmul__(self, other): result = Vector((8, 3, 5)) @ Vector((2, 7, 1)) """ - # Assuming restricted_eval can execute the source_code and return the value of 'result' - assert restricted_eval(source_code) == 42 + +def test_MatMult(): + assert restricted_eval(test_matmult) == 42