Skip to content

Commit

Permalink
Python: Allow python-style unpacking assignments (#1849)
Browse files Browse the repository at this point in the history
Co-authored-by: Tal Ben-Nun <[email protected]>
  • Loading branch information
phschaad and tbennun authored Jan 22, 2025
1 parent 1a56352 commit f0ca36b
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 1 deletion.
16 changes: 15 additions & 1 deletion dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3297,7 +3297,21 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False):
for n in node.value.elts:
results.extend(self._gettype(n))
else:
results.extend(self._gettype(node.value))
rval = self._gettype(node.value)
if (len(elts) > 1 and len(rval) == 1 and rval[0][1] == data.Array and rval[0][0] in self.sdfg.arrays and
self.sdfg.arrays[rval[0][0]].shape[0] == len(elts)):
# In the case where the rhs is an array (not being accessed with a slice) of exactly the same length as
# the number of elements in the lhs, the array can be expanded with a series of slice/subscript accesses
# to constant indexes (according to the number of elements in the lhs). These expansions can then be
# used to perform an unpacking assignment, similar to what Python does natively.
for i in range(len(elts)):
const_node = NumConstant(i)
ast.copy_location(const_node, node)
slice_node = ast.Subscript(rval[0][0], const_node, ast.Load)
ast.copy_location(slice_node, node)
results.extend(self._gettype(slice_node))
else:
results.extend(rval)

if len(results) != len(elts):
raise DaceSyntaxError(self, node, 'Function returns %d values but %d provided' % (len(results), len(elts)))
Expand Down
58 changes: 58 additions & 0 deletions tests/python_frontend/assignment_statements_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,60 @@ def test_multiple_targets_parentheses():
assert (c[0] == np.float32(2) * np.float32(np.pi))


@dace.program
def multiple_targets_unpacking(a: dace.float32[2]):
b, c = a
return b, c


def test_multiple_targets_unpacking():
a = np.zeros((2, ), dtype=np.float32)
a[0] = np.pi
a[1] = 2 * np.pi
b, c = multiple_targets_unpacking(a=a)
assert (b[0] == a[0])
assert (c[0] == a[1])


@dace.program
def multiple_targets_unpacking_multidim(a: dace.float64[2, 3, 4]):
b, c = a
return b, c


def test_multiple_targets_unpacking_multidim():
a = np.random.rand(2, 3, 4)
b, c = multiple_targets_unpacking_multidim(a)
bref, cref = a
assert np.allclose(b, bref)
assert np.allclose(c, cref)


@dace.program
def multiple_targets_unpacking_func(a: dace.float64[2, 3, 4]):
b, c = np.square(a)
return b, c


def test_multiple_targets_unpacking_func():
a = np.random.rand(2, 3, 4)
b, c = multiple_targets_unpacking_func(a)
bref, cref = np.square(a)
assert np.allclose(b, bref)
assert np.allclose(c, cref)


def test_multiple_targets_unpacking_invalid():

@dace.program
def tester(a: dace.float64[2, 3, 4]):
b, c, d = np.square(a)
return b, c, d

with pytest.raises(DaceSyntaxError):
tester.to_sdfg()


@dace.program
def starred_target(a: dace.float32[1]):
b, *c, d, e = a, 2 * a, 3 * a, 4 * a, 5 * a, 6 * a
Expand Down Expand Up @@ -175,6 +229,10 @@ def method(self):
test_single_target_parentheses()
test_multiple_targets()
test_multiple_targets_parentheses()
test_multiple_targets_unpacking()
test_multiple_targets_unpacking_multidim()
test_multiple_targets_unpacking_func()
test_multiple_targets_unpacking_invalid()

# test_starred_target()
# test_attribute_reference()
Expand Down

0 comments on commit f0ca36b

Please sign in to comment.