diff --git a/src/RestrictedPython/transformer.py b/src/RestrictedPython/transformer.py index c6e2e78d..0a88779c 100644 --- a/src/RestrictedPython/transformer.py +++ b/src/RestrictedPython/transformer.py @@ -61,6 +61,7 @@ '__ne__', '__gt__', '__ge__', + '__matmult__' ]) @@ -768,8 +769,8 @@ def visit_BitAnd(self, node): return self.node_contents_visit(node) def visit_MatMult(self, node): - """Matrix multiplication (`@`) is currently not allowed.""" - self.not_allowed(node) + """Allow `@` expressions.""" + return self.node_contents_visit(node) def visit_BoolOp(self, node): """Allow bool operator without restrictions.""" diff --git a/tests/transformer/operators/test_arithmetic_operators.py b/tests/transformer/operators/test_arithmetic_operators.py index 209eb501..893c61bd 100644 --- a/tests/transformer/operators/test_arithmetic_operators.py +++ b/tests/transformer/operators/test_arithmetic_operators.py @@ -32,9 +32,17 @@ def test_FloorDiv(): assert restricted_eval('7 // 2') == 3 +test_matmult = """\ +class Vector: + def __init__(self, values): + self.values = values + + def __matmul__(self, other): + return sum(x * y for x, y in zip(self.values, other.values)) + +result = Vector((8, 3, 5)) @ Vector((2, 7, 1)) +""" + def test_MatMult(): - result = compile_restricted_eval('(8, 3, 5) @ (2, 7, 1)') - assert result.errors == ( - 'Line None: MatMult statements are not allowed.', - ) - assert result.code is None + assert restricted_eval(test_matmult) == 42 +