From e1886a025e547eee572a3b4f14c9c3946d669882 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Tue, 28 Jan 2025 09:48:41 +0100 Subject: [PATCH] Added `FindSingleUseData` Pass (#1906) Added `FindSingleUseData` analysis pass This new pass scans the SDFG and computes the data that is only used in one single place. Essentially, this boils down to the check how often a data descriptor is accessed A data descriptor is classified if the following statements are true: - There is exactly one AccessNode referring to the data descriptor. - The data descriptor is not referred to on an interstate edge, nor in the condition parts of `ConditionBlock`s or `LoopRegion`s. This pass will be needed to speed up fusion passes. --- .../passes/analysis/analysis.py | 73 +++ tests/passes/find_single_use_data_test.py | 434 ++++++++++++++++++ 2 files changed, 507 insertions(+) create mode 100644 tests/passes/find_single_use_data_test.py diff --git a/dace/transformation/passes/analysis/analysis.py b/dace/transformation/passes/analysis/analysis.py index 94c24399ee..7798340884 100644 --- a/dace/transformation/passes/analysis/analysis.py +++ b/dace/transformation/passes/analysis/analysis.py @@ -333,6 +333,79 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[SDFGState]]]: return top_result +@properties.make_properties +@transformation.explicit_cf_compatible +class FindSingleUseData(ppl.Pass): + """ + For each SDFG find all data descriptors that are referenced in exactly one location. + + In addition to the requirement that there exists exactly one AccessNode that + refers to a data descriptor the following conditions have to be meet as well: + - The data is not read on an interstate edge. + - The data is not accessed in the branch condition, loop condition, etc. of + control flow regions. + - There must be at least one AccessNode that refers to the data. I.e. if it exists + inside `SDFG.arrays` but there is no AccessNode, then it is _not_ included. + + It is also important to note that the degree of the AccessNodes are ignored. + """ + + CATEGORY: str = 'Analysis' + + def modifies(self) -> ppl.Modifies: + return ppl.Modifies.Nothing + + def should_reapply(self, modified: ppl.Modifies) -> bool: + # If anything was modified, reapply + return modified & ppl.Modifies.AccessNodes & ppl.Modifies.CFG + + def apply_pass(self, sdfg: SDFG, _) -> Dict[SDFG, Set[str]]: + """ + :return: A dictionary mapping SDFGs to a `set` of strings containing the name + of the data descriptors that are only used once. + """ + # TODO(pschaad): Should we index on cfg or the SDFG itself. + exclusive_data: Dict[SDFG, Set[str]] = {} + for nsdfg in sdfg.all_sdfgs_recursive(): + exclusive_data[nsdfg] = self._find_single_use_data_in_sdfg(nsdfg) + return exclusive_data + + def _find_single_use_data_in_sdfg(self, sdfg: SDFG) -> Set[str]: + """Scans an SDFG and computes the data that is only used once in the SDFG. + + The rules used to classify data descriptors are outlined above. The function + will not scan nested SDFGs. + + :return: The set of data descriptors that are used once in the SDFG. + """ + # If we encounter a data descriptor for the first time we immediately + # classify it as single use. We will undo this decision as soon as + # learn that it is used somewhere else. + single_use_data: Set[str] = set() + previously_seen: Set[str] = set() + + for state in sdfg.states(): + for dnode in state.data_nodes(): + data_name: str = dnode.data + if data_name in single_use_data: + single_use_data.discard(data_name) # Classified too early -> Undo + elif data_name not in previously_seen: + single_use_data.add(data_name) # Never seen -> Assume single use + previously_seen.add(data_name) + + # By definition, data that is referenced by interstate edges is not single + # use data, also remove it. + for edge in sdfg.all_interstate_edges(): + single_use_data.difference_update(edge.data.free_symbols) + + # By definition, data that is referenced by the conditions (branching condition, + # loop condition, ...) is not single use data, also remove that. + for cfr in sdfg.all_control_flow_regions(): + single_use_data.difference_update(cfr.used_symbols(all_symbols=True, with_contents=False)) + + return single_use_data + + @properties.make_properties @transformation.explicit_cf_compatible class FindAccessNodes(ppl.Pass): diff --git a/tests/passes/find_single_use_data_test.py b/tests/passes/find_single_use_data_test.py new file mode 100644 index 0000000000..a3f60dd7a7 --- /dev/null +++ b/tests/passes/find_single_use_data_test.py @@ -0,0 +1,434 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +from typing import Dict, Set, Tuple +import dace +from dace.transformation.passes.analysis import FindSingleUseData + +def perform_scan(sdfg: dace.SDFG) -> Dict[dace.SDFG, Set[str]]: + scanner = FindSingleUseData() + return scanner.apply_pass(sdfg, None) + + +def _make_all_single_use_data_but_one_unused_sdfg() -> dace.SDFG: + sdfg = dace.SDFG('all_single_use_data_but_one_unused_sdfg') + state1 = sdfg.add_state(is_start_block=True) + state2 = sdfg.add_state_after(state1) + + for name in 'abcde': + sdfg.add_array( + name, + shape=(10, 10), + dtype=dace.float64, + transient=False, + ) + + state1.add_nedge( + state1.add_access('a'), + state1.add_access('b'), + sdfg.make_array_memlet('a') + ) + state2.add_nedge( + state2.add_access('c'), + state2.add_access('d'), + sdfg.make_array_memlet('c') + ) + sdfg.validate() + return sdfg + + +def test_all_single_use_data_but_one_unused(): + sdfg = _make_all_single_use_data_but_one_unused_sdfg() + assert len(sdfg.arrays) == 5 + + # Because `e` is not used inside the SDFG, it is not included in the returned set, + # all other descriptors are included because they appear once. + expected_single_use_set = {aname for aname in sdfg.arrays.keys() if aname != 'e'} + + single_use_set = perform_scan(sdfg) + + assert len(single_use_set[sdfg]) == 4 + assert len(single_use_set) == 1 + assert single_use_set[sdfg] == expected_single_use_set + + +def _make_multiple_access_same_state_sdfg() -> dace.SDFG: + sdfg = dace.SDFG('multiple_access_same_state_sdfg') + state = sdfg.add_state(is_start_block=True) + + for name in 'abd': + sdfg.add_array( + name, + shape=(10, 10), + dtype=dace.float64, + transient=False, + ) + + state.add_nedge( + state.add_access('a'), + state.add_access('b'), + sdfg.make_array_memlet('a') + ) + state.add_nedge( + state.add_access('a'), + state.add_access('d'), + sdfg.make_array_memlet('a') + ) + sdfg.validate() + return sdfg + + +def test_multiple_access_same_state(): + sdfg = _make_multiple_access_same_state_sdfg() + assert len(sdfg.arrays) == 3 + + # `a` is not single use data because there are multiple access nodes for it + # in a single state. + expected_single_use_set = {aname for aname in sdfg.arrays.keys() if aname != 'a'} + single_use_set = perform_scan(sdfg) + assert len(single_use_set) == 1 + assert len(single_use_set[sdfg]) == 2 + assert expected_single_use_set == single_use_set[sdfg] + + +def _make_multiple_single_access_node_same_state_sdfg() -> dace.SDFG: + sdfg = dace.SDFG('multiple_single_access_node_same_state_sdfg') + state = sdfg.add_state(is_start_block=True) + + for name in 'abd': + sdfg.add_array( + name, + shape=(10, 10), + dtype=dace.float64, + transient=False, + ) + + a = state.add_access('a') + state.add_nedge( + a, + state.add_access('b'), + sdfg.make_array_memlet('a') + ) + state.add_nedge( + a, + state.add_access('d'), + sdfg.make_array_memlet('a') + ) + assert state.out_degree(a) == 2 + sdfg.validate() + return sdfg + + +def test_multiple_single_access_node_same_state_sdfg() -> dace.SDFG: + sdfg = _make_multiple_single_access_node_same_state_sdfg() + assert len(sdfg.arrays) == 3 + + # Unlike `test_multiple_access_same_state()` here `a` is included in the single use + # set, because, there is only a single AccessNode, that is used multiple times, + # i.e. has an output degree larger than one. + expected_single_use_set = sdfg.arrays.keys() + single_use_set = perform_scan(sdfg) + assert len(single_use_set) == 1 + assert len(single_use_set[sdfg]) == 3 + assert expected_single_use_set == single_use_set[sdfg] + + +def _make_multiple_access_different_states_sdfg() -> dace.SDFG: + sdfg = dace.SDFG('multiple_access_different_states_sdfg') + state1 = sdfg.add_state(is_start_block=True) + state2 = sdfg.add_state_after(state1) + + for name in 'abd': + sdfg.add_array( + name, + shape=(10, 10), + dtype=dace.float64, + transient=False, + ) + + # Note these edges are useless as `a` is written to twice. It is just to generate + # an additional case, i.e. the data are also written to. + state1.add_nedge( + state1.add_access('b'), + state1.add_access('a'), + sdfg.make_array_memlet('a') + ) + state2.add_nedge( + state2.add_access('d'), + state2.add_access('a'), + sdfg.make_array_memlet('a') + ) + sdfg.validate() + return sdfg + + +def test_multiple_access_different_states(): + sdfg = _make_multiple_access_different_states_sdfg() + assert len(sdfg.arrays) == 3 + + # `a` is not included in the single use set, because it is used in two different states. + single_use_set = perform_scan(sdfg) + expected_single_use_set = {aname for aname in sdfg.arrays.keys() if aname != 'a'} + assert len(single_use_set) == 1 + assert len(single_use_set[sdfg]) == 2 + assert expected_single_use_set == single_use_set[sdfg] + + +def _make_access_only_on_interstate_edge_sdfg() -> dace.SDFG: + sdfg = dace.SDFG('access_on_interstate_edge_sdfg') + + for name in 'abcd': + sdfg.add_array( + name, + shape=(10, 10), + dtype=dace.float64, + transient=False, + ) + sdfg.add_scalar('e', dtype=dace.float64, transient=False) + + state1 = sdfg.add_state(is_start_block=True) + state2 = sdfg.add_state_after(state1, assignments={'e_sym': 'e'}) + + state1.add_nedge( + state1.add_access('a'), + state1.add_access('b'), + sdfg.make_array_memlet('a') + ) + state2.add_nedge( + state2.add_access('c'), + state2.add_access('d'), + sdfg.make_array_memlet('c') + ) + sdfg.validate() + return sdfg + + +def test_access_only_on_interstate_edge(): + sdfg = _make_access_only_on_interstate_edge_sdfg() + assert len(sdfg.arrays) == 5 + + # `e` is only accessed on the interstate edge. So it is technically an single use + # data. But by definition we handle this case as non single_use. + expected_single_use_set = {aname for aname in sdfg.arrays.keys() if aname != 'e'} + single_use_set = perform_scan(sdfg) + assert len(single_use_set) == 1 + assert len(single_use_set[sdfg]) == 4 + assert single_use_set[sdfg] == expected_single_use_set + + +def _make_additional_access_on_interstate_edge_sdfg() -> dace.SDFG: + sdfg = dace.SDFG('additional_access_on_interstate_edge_sdfg') + + for name in 'abcd': + sdfg.add_array( + name, + shape=(10, 10), + dtype=dace.float64, + transient=False, + ) + sdfg.add_scalar('e', dtype=dace.float64, transient=False) + sdfg.add_scalar('f', dtype=dace.float64, transient=False) + + state1 = sdfg.add_state(is_start_block=True) + state2 = sdfg.add_state_after(state1, assignments={'e_sym': 'e'}) + + state1.add_nedge( + state1.add_access('a'), + state1.add_access('b'), + sdfg.make_array_memlet('a') + ) + state2.add_nedge( + state2.add_access('c'), + state2.add_access('d'), + sdfg.make_array_memlet('c') + ) + state2.add_nedge( + state2.add_access('e'), + state2.add_access('f'), + dace.Memlet('f[0] -> [0]') + ) + sdfg.validate() + return sdfg + + +def test_additional_access_on_interstate_edge(): + sdfg = _make_additional_access_on_interstate_edge_sdfg() + assert len(sdfg.arrays) == 6 + + # There is one AccessNode for `a`, but as in `test_access_only_on_interstate_edge` + # `e` is also used on the inter state edge, so it is not included. + expected_single_use_set = {aname for aname in sdfg.arrays.keys() if aname != 'e'} + single_use_set = perform_scan(sdfg) + assert len(single_use_set) == 1 + assert len(single_use_set[sdfg]) == 5 + assert single_use_set[sdfg] == expected_single_use_set + + +def _make_access_nested_nsdfg() -> dace.SDFG: + sdfg = dace.SDFG('access_nested_nsdfg') + + for aname in 'ab': + sdfg.add_array( + aname, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + + state = sdfg.add_state(is_start_block=True) + state.add_nedge( + state.add_access('a'), + state.add_access('b'), + sdfg.make_array_memlet('a') + ) + sdfg.validate() + return sdfg + + +def _make_access_nested_sdfg() -> Tuple[dace.SDFG, dace.SDFG]: + sdfg = dace.SDFG('access_nested_sdfg') + nsdfg = _make_access_nested_nsdfg() + + for aname in 'ab': + sdfg.add_array( + aname, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + + state = sdfg.add_state(is_start_block=True) + nsdfg_node = state.add_nested_sdfg( + nsdfg, + parent=sdfg, + inputs={'a'}, + outputs={'b'}, + symbol_mapping={}, + ) + + state.add_edge( + state.add_access('a'), + None, + nsdfg_node, + 'a', + sdfg.make_array_memlet('a'), + ) + state.add_edge( + nsdfg_node, + 'b', + state.add_access('b'), + None, + sdfg.make_array_memlet('b'), + ) + sdfg.validate() + return sdfg, nsdfg + + +def test_access_nested_sdfg(): + sdfg, nested_sdfg = _make_access_nested_sdfg() + assert all(len(nsdfg.arrays) == 2 for nsdfg in [sdfg, nested_sdfg]) + + # In the top and the nested SDFG `a` and `b` are both used once, so for + # both they are included in the single use set. + # Essentially tests if there is separation between the two. + expected_single_use_set = {'a', 'b'} + single_use_sets = perform_scan(sdfg) + + assert len(single_use_sets) == 2 + assert all(single_use_sets[nsdfg] == expected_single_use_set for nsdfg in [sdfg, nested_sdfg]) + + +def _make_conditional_block_sdfg() -> dace.SDFG: + sdfg = dace.SDFG("conditional_block_sdfg") + + for name in ["a", "b", "c", "d", "cond", "cond2"]: + sdfg.add_scalar( + name, + dtype=dace.bool_ if name.startswith("cond") else dace.float64, + transient=False + ) + sdfg.arrays["b"].transient = True + sdfg.arrays["cond2"].transient = True + + entry_state = sdfg.add_state("entry", is_start_block=True) + entry_state.add_nedge( + entry_state.add_access("a"), + entry_state.add_access("b"), + sdfg.make_array_memlet("a") + ) + cond_tasklet: dace.nodes.Tasklet = entry_state.add_tasklet( + "cond_processing", + inputs={"__in"}, + code="__out = not __in", + outputs={"__out"}, + ) + entry_state.add_edge( + entry_state.add_access("cond"), + None, + cond_tasklet, + "__in", + dace.Memlet("cond[0]") + ) + entry_state.add_edge( + cond_tasklet, + "__out", + entry_state.add_access("cond2"), + None, + dace.Memlet("cond2[0]") + ) + + if_region = dace.sdfg.state.ConditionalBlock("if") + sdfg.add_node(if_region) + sdfg.add_edge( + entry_state, + if_region, + dace.InterstateEdge() + ) + + then_body = dace.sdfg.state.ControlFlowRegion("then_body", sdfg=sdfg) + tstate = then_body.add_state("true_branch", is_start_block=True) + tstate.add_nedge( + tstate.add_access("b"), + tstate.add_access("c"), + sdfg.make_array_memlet("b") + ) + if_region.add_branch( + dace.sdfg.state.CodeBlock("cond2"), + then_body + ) + + else_body = dace.sdfg.state.ControlFlowRegion("else_body", sdfg=sdfg) + fstate = else_body.add_state("false_branch", is_start_block=True) + fstate.add_nedge( + fstate.add_access("b"), + fstate.add_access("d"), + sdfg.make_array_memlet("d") + ) + if_region.add_branch( + dace.sdfg.state.CodeBlock("not (cond2)"), + else_body + ) + sdfg.validate() + return sdfg + + +def test_conditional_block(): + sdfg = _make_conditional_block_sdfg() + + # `b` is not in no single use data, because there are three AccessNodes for it. + # `cond2` is no single use data, although there is exactly one AccessNode for + # it, it is used in the condition expression. + expected_single_use_set = {a for a in sdfg.arrays.keys() if a not in ["b", "cond2"]} + single_use_set = perform_scan(sdfg) + + assert len(single_use_set) == 1 + assert single_use_set[sdfg] == expected_single_use_set + + +if __name__ == '__main__': + test_all_single_use_data_but_one_unused() + test_multiple_access_same_state() + test_multiple_single_access_node_same_state_sdfg() + test_multiple_access_different_states() + test_access_only_on_interstate_edge() + test_additional_access_on_interstate_edge() + test_access_nested_sdfg() + test_conditional_block()