diff --git a/dace/transformation/passes/scalar_to_symbol.py b/dace/transformation/passes/scalar_to_symbol.py index 050cc9f44e..50bf5e9fb2 100644 --- a/dace/transformation/passes/scalar_to_symbol.py +++ b/dace/transformation/passes/scalar_to_symbol.py @@ -343,10 +343,21 @@ def _get_requested_range(self, node: ast.Subscript, memlet_subset: subsets.Subse arrname, tasklet_slice = astutils.subscript_to_ast_slice(node) arrname = arrname if arrname in self.arrays else None if len(tasklet_slice) < len(memlet_subset): + new_tasklet_slice = [(None, None, None)] * len(memlet_subset) # Unsqueeze all index dimensions from orig_subset into tasklet_subset - for i, (start, end, _) in reversed(list(enumerate(memlet_subset.ndrange()))): - if start == end: - tasklet_slice.insert(i, (None, None, None)) + j = 0 + for i, (start, end, _) in enumerate(memlet_subset.ndrange()): + if start != end: + new_tasklet_slice[i] = tasklet_slice[j] + j += 1 + + # Sanity check + if j != len(tasklet_slice): + raise IndexError(f'Only {j} out of {len(tasklet_slice)} indices were provided in subset expression ' + f'"{astutils.unparse(node)}", found during composing with memlet of subset ' + f'"{memlet_subset}".') + tasklet_slice = new_tasklet_slice + tasklet_subset = subsets.Range(astutils.astrange_to_symrange(tasklet_slice, self.arrays, arrname)) return memlet_subset.compose(tasklet_subset) diff --git a/tests/passes/scalar_to_symbol_test.py b/tests/passes/scalar_to_symbol_test.py index 9d8791182e..f1682a3667 100644 --- a/tests/passes/scalar_to_symbol_test.py +++ b/tests/passes/scalar_to_symbol_test.py @@ -747,6 +747,43 @@ def test_reversed_order(): sdfg.compile() +@pytest.mark.parametrize('memlet_volume_n', (False, True)) +def test_scalar_index_regression(memlet_volume_n): + """ + Tests a reported failure with an invalid promotion of a scalar index. + """ + N = dace.symbol('N') + volume = 1 if not memlet_volume_n else N + sdfg = dace.SDFG('tester') + sdfg.add_array('A', [10, 10, N], dace.float64) + sdfg.add_scalar('scal', dace.int64) + sdfg.add_scalar('tmp', dace.int64, transient=True) + + init_state = sdfg.add_state() + t = init_state.add_tasklet('set', {}, {'t'}, 't = 1') + w = init_state.add_write('tmp') + init_state.add_edge(t, 't', w, None, dace.Memlet('tmp')) + + state = sdfg.add_state_after(init_state) + r = state.add_read('scal') + rt = state.add_read('tmp') + t = state.add_tasklet('setone', {'s', 't'}, {'a'}, 'a[s + t] = -1') + w = state.add_write('A') + state.add_edge(rt, None, t, 't', dace.Memlet('tmp')) + state.add_edge(r, None, t, 's', dace.Memlet('scal')) + state.add_edge(t, 'a', w, None, dace.Memlet(data='A', subset='0, 0, 0:N', volume=volume)) + + sdfg.validate() + scalar_to_symbol.ScalarToSymbolPromotion().apply_pass(sdfg, {}) + + a = np.random.rand(10, 10, 20) + scal = np.int64(5) + ref = np.copy(a) + ref[0, 0, scal + 1] = -1 + sdfg(A=a, scal=scal, N=20) + assert np.allclose(a, ref) + + if __name__ == '__main__': test_find_promotable() test_promote_simple() @@ -772,3 +809,5 @@ def test_reversed_order(): test_ternary_expression(True) test_double_index_bug() test_reversed_order() + test_scalar_index_regression(False) + test_scalar_index_regression(True)