From eeb772a0589b90b31e1ab1e415f678f6bfa05920 Mon Sep 17 00:00:00 2001 From: Luigi Fusco Date: Fri, 6 Sep 2024 12:00:42 +0200 Subject: [PATCH 01/10] add if extraction transformation --- dace/transformation/interstate/__init__.py | 1 + .../interstate/if_extraction.py | 136 ++++++++++++++++++ 2 files changed, 137 insertions(+) create mode 100644 dace/transformation/interstate/if_extraction.py diff --git a/dace/transformation/interstate/__init__.py b/dace/transformation/interstate/__init__.py index b8bcc716e6..c7e71506e1 100644 --- a/dace/transformation/interstate/__init__.py +++ b/dace/transformation/interstate/__init__.py @@ -16,3 +16,4 @@ from .trivial_loop_elimination import TrivialLoopElimination from .multistate_inline import InlineMultistateSDFG from .move_assignment_outside_if import MoveAssignmentOutsideIf +from .if_extraction import IfExtraction diff --git a/dace/transformation/interstate/if_extraction.py b/dace/transformation/interstate/if_extraction.py new file mode 100644 index 0000000000..6c9c47728f --- /dev/null +++ b/dace/transformation/interstate/if_extraction.py @@ -0,0 +1,136 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. 
+""" If extraction transformation """ + +from dace import data as dt, sdfg as sd +from dace.sdfg import utils as sdutil +from dace.sdfg.state import SDFGState +from dace.transformation import transformation +from copy import deepcopy +from dace.properties import make_properties + + +def eliminate_branch(sdfg: sd.SDFG, initial_state: sd.SDFGState): + state_list = [initial_state] + while len(state_list) > 0: + new_state_list = [] + for s in state_list: + for e in sdfg.out_edges(s): + if len(sdfg.in_edges(e.dst)) == 1: + new_state_list.append(e.dst) + sdfg.remove_node(s) + state_list = new_state_list + + +@make_properties +class IfExtraction(transformation.MultiStateTransformation): + """ + Detects an if statement as the root of a nested sdfg, and extracts it by computing it in the outer sdfg and + replicating the state containing the nested sdfg + """ + + + root_state = transformation.PatternNode(sd.SDFGState) + + + @staticmethod + def annotates_memlets(): + return False + + + @classmethod + def expressions(cls): + return [sdutil.node_path_graph(cls.root_state)] + + + def can_be_applied(self, graph, expr_index, sdfg, permissive=False): + root_state: SDFGState = self.root_state + + out_edges = graph.out_edges(root_state) + in_edges = graph.in_edges(root_state) + + if len(in_edges) > 0: + return False + + if len(out_edges) != 2: + return False + + # needs to be a nested sdfg + if not sdfg.parent: + return False + + # check if edges can be moved out (used symbols can be computed in the outer scope) + if_symbols = set(s for e in out_edges for s in e.data.free_symbols) + if not if_symbols.issubset(sdfg.parent_nsdfg_node.symbol_mapping.keys()): + return False + + return True + + + def apply(self, _, if_sdfg: sd.SDFG): + if_root_state: SDFGState = self.root_state + if_branch: SDFGState = if_sdfg.parent + outer_sdfg: sd.SDFG = if_branch.sdfg + if_nested_sdfg_node: sd.nodes.NestedSDFG = if_sdfg.parent_nsdfg_node + + if_edge, else_edge = if_sdfg.out_edges(if_root_state) + + # 
create new state to perform the if, and have it replace the state containing the nested SDFG + new_state = outer_sdfg.add_state() + sdutil.change_edge_dest(outer_sdfg, if_branch, new_state) + + # take the old state as the if branch, and create a copy to act as the else branch + else_branch = sd.SDFGState.from_json(if_branch.to_json(), context={'sdfg': outer_sdfg}) + else_branch.label = dt.find_new_name(else_branch.label, outer_sdfg._labels) + outer_sdfg.add_node(else_branch) + + # find the corresponding elements in the new state + else_nested_sdfg_node = None + for n in else_branch.nodes(): + if n.label == if_nested_sdfg_node.label: + else_nested_sdfg_node = n + break + else_sdfg = else_nested_sdfg_node.sdfg + + else_root_state = None + for s in else_nested_sdfg_node.sdfg.states(): + if s.label == if_root_state.label: + else_root_state = s + break + + # delete the else subgraph in the if state + eliminate_branch(if_sdfg, if_sdfg.out_edges(if_root_state)[1].dst) + # optimization: delete new base state if useless + new_base_state = if_sdfg.out_edges(if_root_state)[0].dst + if len(new_base_state.nodes()) == 0 and len(if_sdfg.out_edges(new_base_state)) == 1: + out_edge = if_sdfg.out_edges(new_base_state)[0] + if len(out_edge.data.assignments) == 0 and out_edge.data.is_unconditional(): + if_sdfg.remove_node(new_base_state) + if_sdfg.remove_node(if_root_state) + + # do the opposite for else state + eliminate_branch(else_sdfg, else_sdfg.out_edges(else_root_state)[0].dst) + new_base_state = else_sdfg.out_edges(else_root_state)[0].dst + if len(new_base_state.nodes()) == 0 and len(else_sdfg.out_edges(new_base_state)) == 1: + out_edge = else_sdfg.out_edges(new_base_state)[0] + if len(out_edge.data.assignments) == 0 and out_edge.data.is_unconditional(): + else_sdfg.remove_node(new_base_state) + else_sdfg.remove_node(else_root_state) + + # connect the if and else state + if_edge.data.replace_dict(if_nested_sdfg_node.symbol_mapping) + 
else_edge.data.replace_dict(if_nested_sdfg_node.symbol_mapping) + + # translate interstate edge assignemnts to symbol mappings + if_nested_sdfg_node.symbol_mapping.update(if_edge.data.assignments) + else_nested_sdfg_node.symbol_mapping.update(else_edge.data.assignments) + + # connect everyting + outer_sdfg.add_edge(new_state, if_branch, sd.InterstateEdge(if_edge.data.condition)) + outer_sdfg.add_edge(new_state, else_branch, sd.InterstateEdge(else_edge.data.condition)) + + # make sure the sdfg is valid + if len(outer_sdfg.out_edges(if_branch)) == 0: + outer_sdfg.add_state_after(if_branch) + + for e in outer_sdfg.out_edges(if_branch): + outer_sdfg.add_edge(else_branch, e.dst, sd.InterstateEdge(e.data.condition, e.data.assignments)) From 5cf7df9df026fbdf58483d356a6b4d8c5ff53560 Mon Sep 17 00:00:00 2001 From: Luigi Fusco Date: Fri, 6 Sep 2024 14:16:10 +0200 Subject: [PATCH 02/10] update symbol checking logic --- .../interstate/if_extraction.py | 21 ++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/dace/transformation/interstate/if_extraction.py b/dace/transformation/interstate/if_extraction.py index 6c9c47728f..55f37ef123 100644 --- a/dace/transformation/interstate/if_extraction.py +++ b/dace/transformation/interstate/if_extraction.py @@ -1,7 +1,7 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. 
""" If extraction transformation """ -from dace import data as dt, sdfg as sd +from dace import data as dt, sdfg as sd, symbolic from dace.sdfg import utils as sdutil from dace.sdfg.state import SDFGState from dace.transformation import transformation @@ -58,9 +58,24 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): if not sdfg.parent: return False + nested_sdfg: sd.nodes.NestedSDFG = sdfg.parent_nsdfg_node + parent_sdfg: sd.SDFG = sdfg.parent.sdfg + + # collect outer symbols used in the interstate edges outgoing the if guard + if_symbols = set(str(nested) for e in out_edges for s in e.data.free_symbols + for nested in symbolic.pystr_to_symbolic(nested_sdfg.symbol_mapping[s]).free_symbols) + + # collect symbols available to state containing the nested sdfg + available_symbols = parent_sdfg.symbols.keys() | parent_sdfg.arglist().keys() + for desc in parent_sdfg.arrays.values(): + available_symbols |= {str(s) for s in desc.free_symbols} + + start_state = sdfg.start_state + for e in sdfg.predecessor_state_transitions(start_state): + available_symbols |= e.data.new_symbols(sdfg, available_symbols).keys() + # check if edges can be moved out (used symbols can be computed in the outer scope) - if_symbols = set(s for e in out_edges for s in e.data.free_symbols) - if not if_symbols.issubset(sdfg.parent_nsdfg_node.symbol_mapping.keys()): + if not if_symbols.issubset(available_symbols): return False return True From cc5c10e22fe13c16cd643435b61e39d4856d3870 Mon Sep 17 00:00:00 2001 From: Luigi Fusco Date: Fri, 6 Sep 2024 14:24:31 +0200 Subject: [PATCH 03/10] check if symbol is written in the state containing the nested sdfg --- dace/transformation/interstate/if_extraction.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/dace/transformation/interstate/if_extraction.py b/dace/transformation/interstate/if_extraction.py index 55f37ef123..71a886e61f 100644 --- a/dace/transformation/interstate/if_extraction.py +++ 
b/dace/transformation/interstate/if_extraction.py @@ -74,10 +74,15 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): for e in sdfg.predecessor_state_transitions(start_state): available_symbols |= e.data.new_symbols(sdfg, available_symbols).keys() - # check if edges can be moved out (used symbols can be computed in the outer scope) + # check if used symbols can be computed in the outer scope if not if_symbols.issubset(available_symbols): return False + # check if symbols are not written in the state containing the nested sdfg + _, wset = sdfg.parent.read_and_write_sets() + if len(if_symbols.intersection(wset)) != 0: + return False + return True From 8095aeb0b3ecd3abb09288e1f05f0d4f5b7a4638 Mon Sep 17 00:00:00 2001 From: Luigi Fusco Date: Fri, 13 Sep 2024 11:40:42 +0200 Subject: [PATCH 04/10] add tests --- tests/transformations/if_extraction_test.py | 53 +++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 tests/transformations/if_extraction_test.py diff --git a/tests/transformations/if_extraction_test.py b/tests/transformations/if_extraction_test.py new file mode 100644 index 0000000000..b6338f2a7e --- /dev/null +++ b/tests/transformations/if_extraction_test.py @@ -0,0 +1,53 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. 
+ +import dace +from dace.transformation.interstate import IfExtraction +from dace.sdfg.nodes import NestedSDFG + +N = dace.symbol('N', dtype=dace.int32) + +@dace.program +def simple_application(flag: dace.bool, arr: dace.float32[N]): + for i in dace.map[0:N]: + if flag: + outval = 1 + else: + outval = 2 + arr[i] = outval + + +@dace.program +def dependant_application(flag: dace.bool, arr: dace.float32[N]): + for i in dace.map[0:N]: + if i == 0: + outval = 1 + else: + outval = 2 + arr[i] = outval + + +def test_simple_application(): + sdfg = simple_application.to_sdfg() + + assert len(sdfg.nodes()) == 1 + + assert sdfg.apply_transformations_repeated([IfExtraction]) == 1 + + assert len(sdfg.nodes()) == 4 + assert sdfg.start_state.is_empty() + + sdfg.simplify() + + for s in sdfg.nodes(): + for n in s.nodes(): + assert not isinstance(n, NestedSDFG) + +def test_fails_due_to_dependency(): + sdfg = dependant_application.to_sdfg() + + assert sdfg.apply_transformations_repeated([IfExtraction]) == 0 + + +if __name__ == '__main__': + test_simple_application() + test_fails_due_to_dependency() From 96706d327c8867076a6460d98b04217c1138508c Mon Sep 17 00:00:00 2001 From: Luigi Fusco Date: Fri, 13 Sep 2024 15:03:01 +0200 Subject: [PATCH 05/10] fix to_sdfg in test --- tests/transformations/if_extraction_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/transformations/if_extraction_test.py b/tests/transformations/if_extraction_test.py index b6338f2a7e..bc7a48930d 100644 --- a/tests/transformations/if_extraction_test.py +++ b/tests/transformations/if_extraction_test.py @@ -27,7 +27,7 @@ def dependant_application(flag: dace.bool, arr: dace.float32[N]): def test_simple_application(): - sdfg = simple_application.to_sdfg() + sdfg = simple_application.to_sdfg(simplify=True) assert len(sdfg.nodes()) == 1 @@ -43,7 +43,7 @@ def test_simple_application(): assert not isinstance(n, NestedSDFG) def test_fails_due_to_dependency(): - sdfg = 
dependant_application.to_sdfg() + sdfg = dependant_application.to_sdfg(simplify=True) assert sdfg.apply_transformations_repeated([IfExtraction]) == 0 From b1f10ac57c487f1e1d6a9be6420097ef4c08a519 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Mon, 28 Oct 2024 23:27:36 +0100 Subject: [PATCH 06/10] Add further explanation (especially about explicit assumptions) and some cosmetic rephrasing of expressions. --- .../interstate/if_extraction.py | 190 +++++++++--------- tests/transformations/if_extraction_test.py | 32 ++- 2 files changed, 120 insertions(+), 102 deletions(-) diff --git a/dace/transformation/interstate/if_extraction.py b/dace/transformation/interstate/if_extraction.py index 71a886e61f..52b77d5d14 100644 --- a/dace/transformation/interstate/if_extraction.py +++ b/dace/transformation/interstate/if_extraction.py @@ -1,133 +1,138 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ If extraction transformation """ - -from dace import data as dt, sdfg as sd, symbolic -from dace.sdfg import utils as sdutil +from dace import SDFG, data, InterstateEdge +from dace.properties import make_properties +from dace.sdfg import utils +from dace.sdfg.nodes import NestedSDFG from dace.sdfg.state import SDFGState +from dace.symbolic import pystr_to_symbolic from dace.transformation import transformation -from copy import deepcopy -from dace.properties import make_properties -def eliminate_branch(sdfg: sd.SDFG, initial_state: sd.SDFGState): - state_list = [initial_state] - while len(state_list) > 0: - new_state_list = [] - for s in state_list: - for e in sdfg.out_edges(s): - if len(sdfg.in_edges(e.dst)) == 1: - new_state_list.append(e.dst) +def eliminate_branch(sdfg: SDFG, initial_state: SDFGState): + """ + Eliminates all nodes that are reachable _only_ from `initial_state`. + + Assumptions: + - The topmost level of each branch consists of `SDFGState` states connected by interstate edges. 
+ + Example: + - If we start from `state_1` for the following graph, only `state_1` will be removed. + initial_state + / \\ + state_1 state_2 + \\ / + state_3 + | + terminal_state + - If we start from `state_1` for the following graph, `state_1` and `state_3` will be removed. But after that, + starting from `state_2` will remove the other four intermediate states too. + initial_state + / \\ + state_1 state_2 + | | + state_3 state_4 + \\ / + state_5 + | + state_6 + | + terminal_state + """ + assert len(sdfg.in_edges(initial_state)) == 1 + states_to_remove = {initial_state} + while states_to_remove: + assert all(isinstance(st, SDFGState) for st in states_to_remove) + new_states_to_remove = {e.dst for s in states_to_remove for e in sdfg.out_edges(s) + if len(sdfg.in_edges(e.dst)) == 1} + for s in states_to_remove: sdfg.remove_node(s) - state_list = new_state_list + states_to_remove = new_states_to_remove @make_properties class IfExtraction(transformation.MultiStateTransformation): """ - Detects an if statement as the root of a nested sdfg, and extracts it by computing it in the outer sdfg and - replicating the state containing the nested sdfg + Detects an If statement as the root of a nested SDFG, and if so, extracts it by computing it in the outer SDFG and + replicating the state containing the nested SDFG. 
""" - - root_state = transformation.PatternNode(sd.SDFGState) - - - @staticmethod - def annotates_memlets(): - return False - + root_state = transformation.PatternNode(SDFGState) @classmethod def expressions(cls): - return [sdutil.node_path_graph(cls.root_state)] - - - def can_be_applied(self, graph, expr_index, sdfg, permissive=False): - root_state: SDFGState = self.root_state - - out_edges = graph.out_edges(root_state) - in_edges = graph.in_edges(root_state) + return [utils.node_path_graph(cls.root_state)] - if len(in_edges) > 0: - return False - - if len(out_edges) != 2: - return False - - # needs to be a nested sdfg + def can_be_applied(self, graph, expr_index: int, sdfg, permissive=False): if not sdfg.parent: + # Must be a nested SDFG. return False - nested_sdfg: sd.nodes.NestedSDFG = sdfg.parent_nsdfg_node - parent_sdfg: sd.SDFG = sdfg.parent.sdfg + in_edges, out_edges = graph.in_edges(self.root_state), graph.out_edges(self.root_state) + if not (len(in_edges) == 0 and len(out_edges) == 2): + # Such an If state must have an inverted V shape. + return False - # collect outer symbols used in the interstate edges outgoing the if guard + # Collect outer symbols used in the interstate edges going out of the If guard. if_symbols = set(str(nested) for e in out_edges for s in e.data.free_symbols - for nested in symbolic.pystr_to_symbolic(nested_sdfg.symbol_mapping[s]).free_symbols) + for nested in pystr_to_symbolic(sdfg.parent_nsdfg_node.symbol_mapping[s]).free_symbols) - # collect symbols available to state containing the nested sdfg + # Collect symbols available to state containing the nested SDFG. 
+ parent_sdfg = sdfg.parent.sdfg available_symbols = parent_sdfg.symbols.keys() | parent_sdfg.arglist().keys() for desc in parent_sdfg.arrays.values(): available_symbols |= {str(s) for s in desc.free_symbols} - - start_state = sdfg.start_state - for e in sdfg.predecessor_state_transitions(start_state): + for e in sdfg.predecessor_state_transitions(sdfg.start_state): available_symbols |= e.data.new_symbols(sdfg, available_symbols).keys() - # check if used symbols can be computed in the outer scope if not if_symbols.issubset(available_symbols): + # The symbols used on the branch must be computable in the outer scope. return False - # check if symbols are not written in the state containing the nested sdfg _, wset = sdfg.parent.read_and_write_sets() - if len(if_symbols.intersection(wset)) != 0: + if if_symbols.intersection(wset): + # The symbols used on the branch must not be written in the parent state of the nested SDFG. return False return True - - def apply(self, _, if_sdfg: sd.SDFG): + def apply(self, graph: SDFGState, sdfg: SDFG): if_root_state: SDFGState = self.root_state - if_branch: SDFGState = if_sdfg.parent - outer_sdfg: sd.SDFG = if_branch.sdfg - if_nested_sdfg_node: sd.nodes.NestedSDFG = if_sdfg.parent_nsdfg_node + if_branch: SDFGState = sdfg.parent + outer_sdfg: SDFG = if_branch.sdfg + if_nested_sdfg_node: NestedSDFG = sdfg.parent_nsdfg_node - if_edge, else_edge = if_sdfg.out_edges(if_root_state) + if_edge, else_edge = sdfg.out_edges(if_root_state) - # create new state to perform the if, and have it replace the state containing the nested SDFG + # Create new state to perform the If, and have it replace the state containing the nested SDFG. 
new_state = outer_sdfg.add_state() - sdutil.change_edge_dest(outer_sdfg, if_branch, new_state) + utils.change_edge_dest(outer_sdfg, if_branch, new_state) - # take the old state as the if branch, and create a copy to act as the else branch - else_branch = sd.SDFGState.from_json(if_branch.to_json(), context={'sdfg': outer_sdfg}) - else_branch.label = dt.find_new_name(else_branch.label, outer_sdfg._labels) + # Take the old state as the If branch, and create a copy to act as the else branch. + else_branch = SDFGState.from_json(if_branch.to_json(), context={'sdfg': outer_sdfg}) + else_branch.label = data.find_new_name(else_branch.label, outer_sdfg._labels) outer_sdfg.add_node(else_branch) - # find the corresponding elements in the new state - else_nested_sdfg_node = None - for n in else_branch.nodes(): - if n.label == if_nested_sdfg_node.label: - else_nested_sdfg_node = n - break - else_sdfg = else_nested_sdfg_node.sdfg - - else_root_state = None - for s in else_nested_sdfg_node.sdfg.states(): - if s.label == if_root_state.label: - else_root_state = s - break - - # delete the else subgraph in the if state - eliminate_branch(if_sdfg, if_sdfg.out_edges(if_root_state)[1].dst) - # optimization: delete new base state if useless - new_base_state = if_sdfg.out_edges(if_root_state)[0].dst - if len(new_base_state.nodes()) == 0 and len(if_sdfg.out_edges(new_base_state)) == 1: - out_edge = if_sdfg.out_edges(new_base_state)[0] + # Find the corresponding elements in the new state. + else_nested_sdfg_node = [n for n in else_branch.nodes() if n.label == if_nested_sdfg_node.label] + assert len(else_nested_sdfg_node) == 1 + else_nested_sdfg_node = else_nested_sdfg_node[0] + else_root_state = [s for s in else_nested_sdfg_node.sdfg.states() if s.label == if_root_state.label] + assert len(else_root_state) == 1 + else_root_state = else_root_state[0] + + # Delete the else subgraph in the If state. 
+ eliminate_branch(sdfg, sdfg.out_edges(if_root_state)[1].dst) + # Optimization: Delete new base state if useless. + new_base_state = sdfg.out_edges(if_root_state)[0].dst + if not new_base_state.nodes() and len(sdfg.out_edges(new_base_state)) == 1: + out_edge = sdfg.out_edges(new_base_state)[0] if len(out_edge.data.assignments) == 0 and out_edge.data.is_unconditional(): - if_sdfg.remove_node(new_base_state) - if_sdfg.remove_node(if_root_state) + sdfg.remove_node(new_base_state) + sdfg.remove_node(if_root_state) - # do the opposite for else state + # Do the opposite for Else state. + else_sdfg = else_nested_sdfg_node.sdfg eliminate_branch(else_sdfg, else_sdfg.out_edges(else_root_state)[0].dst) new_base_state = else_sdfg.out_edges(else_root_state)[0].dst if len(new_base_state.nodes()) == 0 and len(else_sdfg.out_edges(new_base_state)) == 1: @@ -136,21 +141,20 @@ def apply(self, _, if_sdfg: sd.SDFG): else_sdfg.remove_node(new_base_state) else_sdfg.remove_node(else_root_state) - # connect the if and else state + # Connect the If and Else state. if_edge.data.replace_dict(if_nested_sdfg_node.symbol_mapping) else_edge.data.replace_dict(if_nested_sdfg_node.symbol_mapping) - # translate interstate edge assignemnts to symbol mappings + # Translate interstate edge assignments to symbol mappings. if_nested_sdfg_node.symbol_mapping.update(if_edge.data.assignments) else_nested_sdfg_node.symbol_mapping.update(else_edge.data.assignments) - # connect everyting - outer_sdfg.add_edge(new_state, if_branch, sd.InterstateEdge(if_edge.data.condition)) - outer_sdfg.add_edge(new_state, else_branch, sd.InterstateEdge(else_edge.data.condition)) + # Connect everything. + outer_sdfg.add_edge(new_state, if_branch, InterstateEdge(if_edge.data.condition)) + outer_sdfg.add_edge(new_state, else_branch, InterstateEdge(else_edge.data.condition)) - # make sure the sdfg is valid - if len(outer_sdfg.out_edges(if_branch)) == 0: + # Make sure the SDFG is valid. 
+ if not outer_sdfg.out_edges(if_branch): outer_sdfg.add_state_after(if_branch) - for e in outer_sdfg.out_edges(if_branch): - outer_sdfg.add_edge(else_branch, e.dst, sd.InterstateEdge(e.data.condition, e.data.assignments)) + outer_sdfg.add_edge(else_branch, e.dst, InterstateEdge(e.data.condition, e.data.assignments)) diff --git a/tests/transformations/if_extraction_test.py b/tests/transformations/if_extraction_test.py index bc7a48930d..28a6d44117 100644 --- a/tests/transformations/if_extraction_test.py +++ b/tests/transformations/if_extraction_test.py @@ -1,4 +1,7 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +import os + +from dace.symbolic import simplify import dace from dace.transformation.interstate import IfExtraction @@ -6,6 +9,7 @@ N = dace.symbol('N', dtype=dace.int32) + @dace.program def simple_application(flag: dace.bool, arr: dace.float32[N]): for i in dace.map[0:N]: @@ -27,21 +31,31 @@ def dependant_application(flag: dace.bool, arr: dace.float32[N]): def test_simple_application(): - sdfg = simple_application.to_sdfg(simplify=True) + g = simple_application.to_sdfg(simplify=False, validate=False, use_cache=False) + g.simplify(verbose=True) # Simplify (for convenience) to get the actual test graph. + g.save(os.path.join('_dacegraphs', 'simple-0.sdfg')) + g.validate() + g.compile() + + # Before, the outer graph had only one nested SDFG. + assert len(g.nodes()) == 1 - assert len(sdfg.nodes()) == 1 + assert g.apply_transformations_repeated([IfExtraction]) == 1 + g.save(os.path.join('_dacegraphs', 'simple-1.sdfg')) + g.validate() + g.compile() - assert sdfg.apply_transformations_repeated([IfExtraction]) == 1 + # But now, the outer graph have four: two copies of the original nested SDFGs and two for branch management. 
+ assert len(g.nodes()) == 4 + assert g.start_state.is_empty() - assert len(sdfg.nodes()) == 4 - assert sdfg.start_state.is_empty() - - sdfg.simplify() - - for s in sdfg.nodes(): + g.simplify() + + for s in g.nodes(): for n in s.nodes(): assert not isinstance(n, NestedSDFG) + def test_fails_due_to_dependency(): sdfg = dependant_application.to_sdfg(simplify=True) From 75eb10ddf32b22343331f8201e09bd3fa552358a Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Tue, 29 Oct 2024 00:23:14 +0100 Subject: [PATCH 07/10] Don't rely on `simplify()` for testing and construct the intended graphs as inconvenient as it is. --- tests/transformations/if_extraction_test.py | 69 ++++++++++++++++----- 1 file changed, 52 insertions(+), 17 deletions(-) diff --git a/tests/transformations/if_extraction_test.py b/tests/transformations/if_extraction_test.py index 28a6d44117..45bfe9dc7c 100644 --- a/tests/transformations/if_extraction_test.py +++ b/tests/transformations/if_extraction_test.py @@ -1,23 +1,47 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import os +from copy import deepcopy -from dace.symbolic import simplify +import numpy as np import dace +from dace import SDFG, InterstateEdge, Memlet from dace.transformation.interstate import IfExtraction -from dace.sdfg.nodes import NestedSDFG N = dace.symbol('N', dtype=dace.int32) -@dace.program -def simple_application(flag: dace.bool, arr: dace.float32[N]): - for i in dace.map[0:N]: - if flag: - outval = 1 - else: - outval = 2 - arr[i] = outval +def make_simple_branched_sdfg(): + # First prepare the map-body. 
+ subg = SDFG('body') + subg.add_array('tmp', (1,), dace.float32) + subg.add_symbol('outval', dace.float32) + ifh = subg.add_state('if_head') + if1 = subg.add_state('if_b1') + if2 = subg.add_state('if_b2') + ift = subg.add_state('if_tail') + subg.add_edge(ifh, if1, InterstateEdge(condition='(flag)', assignments={'outval': 1})) + subg.add_edge(ifh, if2, InterstateEdge(condition='(not flag)', assignments={'outval': 2})) + subg.add_edge(if1, ift, InterstateEdge()) + subg.add_edge(if2, ift, InterstateEdge()) + t0 = ift.add_tasklet('copy', inputs={}, outputs={'__out'}, code='__out = outval') + tmp = ift.add_access('tmp') + ift.add_memlet_path(t0, tmp, src_conn='__out', memlet=Memlet(expr='tmp[0]')) + subg.fill_scope_connectors() + + g = SDFG('prog') + g.add_array('A', (10,), dace.float32) + g.add_symbol('flag', dace.bool) + st0 = g.add_state('outer', is_start_block=True) + en, ex = st0.add_map('map', {'i': '0:10'}) + body = st0.add_nested_sdfg(subg, None, {}, {'tmp'}, symbol_mapping={'flag': 'flag'}) + A = st0.add_access('A') + st0.add_memlet_path(en, body, memlet=Memlet()) + st0.add_memlet_path(body, ex, src_conn='tmp', dst_conn='IN_A', memlet=Memlet(expr='A[i]')) + st0.add_memlet_path(ex, A, src_conn='OUT_A', memlet=Memlet(expr='A[0:10]')) + g.fill_scope_connectors() + + return g @dace.program @@ -31,12 +55,19 @@ def dependant_application(flag: dace.bool, arr: dace.float32[N]): def test_simple_application(): - g = simple_application.to_sdfg(simplify=False, validate=False, use_cache=False) - g.simplify(verbose=True) # Simplify (for convenience) to get the actual test graph. + origA = np.zeros((10,), np.float32) + + g = make_simple_branched_sdfg() g.save(os.path.join('_dacegraphs', 'simple-0.sdfg')) g.validate() g.compile() + # Get the expected values. + wantA_1 = deepcopy(origA) + wantA_2 = deepcopy(origA) + g(A=wantA_1, flag=True) + g(A=wantA_2, flag=False) + # Before, the outer graph had only one nested SDFG. 
assert len(g.nodes()) == 1 @@ -45,15 +76,19 @@ def test_simple_application(): g.validate() g.compile() + # Get the values from transformed program. + gotA_1 = deepcopy(origA) + gotA_2 = deepcopy(origA) + g(A=gotA_1, flag=True) + g(A=gotA_2, flag=False) + # But now, the outer graph have four: two copies of the original nested SDFGs and two for branch management. assert len(g.nodes()) == 4 assert g.start_state.is_empty() - g.simplify() - - for s in g.nodes(): - for n in s.nodes(): - assert not isinstance(n, NestedSDFG) + # Verify numerically. + assert all(np.equal(wantA_1, gotA_1)) + assert all(np.equal(wantA_2, gotA_2)) def test_fails_due_to_dependency(): From ce1f733aad865bba01e3d7af85239ebe10bbc69a Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Tue, 29 Oct 2024 00:33:39 +0100 Subject: [PATCH 08/10] Construct the graph for the "non-dependency" test too. --- tests/transformations/if_extraction_test.py | 61 +++++++++++++++------ 1 file changed, 45 insertions(+), 16 deletions(-) diff --git a/tests/transformations/if_extraction_test.py b/tests/transformations/if_extraction_test.py index 45bfe9dc7c..bc0b6d7452 100644 --- a/tests/transformations/if_extraction_test.py +++ b/tests/transformations/if_extraction_test.py @@ -8,10 +8,11 @@ from dace import SDFG, InterstateEdge, Memlet from dace.transformation.interstate import IfExtraction -N = dace.symbol('N', dtype=dace.int32) - -def make_simple_branched_sdfg(): +def make_branched_sdfg_that_does_not_depend_on_loop_var(): + """ + Construct a simple SDFG that does not depend on symbols defined or updated in the outer state, e.g., loop variables. + """ # First prepare the map-body. subg = SDFG('body') subg.add_array('tmp', (1,), dace.float32) @@ -29,6 +30,7 @@ def make_simple_branched_sdfg(): ift.add_memlet_path(t0, tmp, src_conn='__out', memlet=Memlet(expr='tmp[0]')) subg.fill_scope_connectors() + # Then prepare the parent graph. 
g = SDFG('prog') g.add_array('A', (10,), dace.float32) g.add_symbol('flag', dace.bool) @@ -44,20 +46,46 @@ def make_simple_branched_sdfg(): return g -@dace.program -def dependant_application(flag: dace.bool, arr: dace.float32[N]): - for i in dace.map[0:N]: - if i == 0: - outval = 1 - else: - outval = 2 - arr[i] = outval +def make_branched_sdfg_that_depends_on_loop_var(): + """ + Construct a simple SDFG that depends on symbols defined or updated in the outer state, e.g., loop variables. + """ + # First prepare the map-body. + subg = SDFG('body') + subg.add_array('tmp', (1,), dace.float32) + subg.add_symbol('outval', dace.float32) + ifh = subg.add_state('if_head') + if1 = subg.add_state('if_b1') + if2 = subg.add_state('if_b2') + ift = subg.add_state('if_tail') + subg.add_edge(ifh, if1, InterstateEdge(condition='(i == 0)', assignments={'outval': 1})) + subg.add_edge(ifh, if2, InterstateEdge(condition='(not (i == 0))', assignments={'outval': 2})) + subg.add_edge(if1, ift, InterstateEdge()) + subg.add_edge(if2, ift, InterstateEdge()) + t0 = ift.add_tasklet('copy', inputs={}, outputs={'__out'}, code='__out = outval') + tmp = ift.add_access('tmp') + ift.add_memlet_path(t0, tmp, src_conn='__out', memlet=Memlet(expr='tmp[0]')) + subg.fill_scope_connectors() + + # Then prepare the parent graph. 
+ g = SDFG('prog') + g.add_array('A', (10,), dace.float32) + st0 = g.add_state('outer', is_start_block=True) + en, ex = st0.add_map('map', {'i': '0:10'}) + body = st0.add_nested_sdfg(subg, None, {}, {'tmp'}) + A = st0.add_access('A') + st0.add_memlet_path(en, body, memlet=Memlet()) + st0.add_memlet_path(body, ex, src_conn='tmp', dst_conn='IN_A', memlet=Memlet(expr='A[i]')) + st0.add_memlet_path(ex, A, src_conn='OUT_A', memlet=Memlet(expr='A[0:10]')) + g.fill_scope_connectors() + + return g def test_simple_application(): origA = np.zeros((10,), np.float32) - g = make_simple_branched_sdfg() + g = make_branched_sdfg_that_does_not_depend_on_loop_var() g.save(os.path.join('_dacegraphs', 'simple-0.sdfg')) g.validate() g.compile() @@ -91,12 +119,13 @@ def test_simple_application(): assert all(np.equal(wantA_2, gotA_2)) -def test_fails_due_to_dependency(): - sdfg = dependant_application.to_sdfg(simplify=True) +def test_fails_due_to_dependency_on_loop_var(): + g = make_branched_sdfg_that_depends_on_loop_var() + g.save(os.path.join('_dacegraphs', 'dependent-0.sdfg')) - assert sdfg.apply_transformations_repeated([IfExtraction]) == 0 + assert g.apply_transformations_repeated([IfExtraction]) == 0 if __name__ == '__main__': test_simple_application() - test_fails_due_to_dependency() + test_fails_due_to_dependency_on_loop_var() From 5bdff31f1593e02e93382b968b7447a62e60d45b Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Tue, 29 Oct 2024 01:07:05 +0100 Subject: [PATCH 09/10] Add one more test with a more complicated structure (and to cover a tiny bit more of corner cases). 
--- tests/transformations/if_extraction_test.py | 98 ++++++++++++++++++++- 1 file changed, 96 insertions(+), 2 deletions(-) diff --git a/tests/transformations/if_extraction_test.py b/tests/transformations/if_extraction_test.py index bc0b6d7452..cb37932849 100644 --- a/tests/transformations/if_extraction_test.py +++ b/tests/transformations/if_extraction_test.py @@ -46,6 +46,65 @@ def make_branched_sdfg_that_does_not_depend_on_loop_var(): return g +def make_branched_sdfg_that_has_intermediate_branchlike_structure(): + """ + Construct an SDFG that has this structure: + initial_state + / \\ + state_1 state_2 + | | + state_3 state_4 + \\ / + state_5 + / \\ + state_6 state_7 + \\ / + terminal_state + """ + # First prepare the map-body. + subg = SDFG('body') + subg.add_array('tmp', (1,), dace.float32) + subg.add_symbol('outval', dace.float32) + ifh = subg.add_state('if_head') + if1 = subg.add_state('state_1') + if3 = subg.add_state('state_2') + if2 = subg.add_state('state_3') + if4 = subg.add_state('state_4') + if5 = subg.add_state('state_5') + if6 = subg.add_state('state_6') + if7 = subg.add_state('state_7') + ift = subg.add_state('if_tail') + subg.add_edge(ifh, if1, InterstateEdge(condition='(flag)', assignments={'outval': 1})) + subg.add_edge(ifh, if2, InterstateEdge(condition='(not flag)', assignments={'outval': 2})) + subg.add_edge(if1, if3, InterstateEdge()) + subg.add_edge(if3, if5, InterstateEdge()) + subg.add_edge(if2, if4, InterstateEdge()) + subg.add_edge(if4, if5, InterstateEdge()) + subg.add_edge(if5, if6, InterstateEdge()) + subg.add_edge(if5, if7, InterstateEdge()) + subg.add_edge(if6, ift, InterstateEdge()) + subg.add_edge(if7, ift, InterstateEdge()) + t0 = ift.add_tasklet('copy', inputs={}, outputs={'__out'}, code='__out = outval') + tmp = ift.add_access('tmp') + ift.add_memlet_path(t0, tmp, src_conn='__out', memlet=Memlet(expr='tmp[0]')) + subg.fill_scope_connectors() + + # Then prepare the parent graph.
+ g = SDFG('prog') + g.add_array('A', (10,), dace.float32) + g.add_symbol('flag', dace.bool) + st0 = g.add_state('outer', is_start_block=True) + en, ex = st0.add_map('map', {'i': '0:10'}) + body = st0.add_nested_sdfg(subg, None, {}, {'tmp'}, symbol_mapping={'flag': 'flag'}) + A = st0.add_access('A') + st0.add_memlet_path(en, body, memlet=Memlet()) + st0.add_memlet_path(body, ex, src_conn='tmp', dst_conn='IN_A', memlet=Memlet(expr='A[i]')) + st0.add_memlet_path(ex, A, src_conn='OUT_A', memlet=Memlet(expr='A[0:10]')) + g.fill_scope_connectors() + + return g + + def make_branched_sdfg_that_depends_on_loop_var(): """ Construct a simple SDFG that depends on symbols defined or updated in the outer state, e.g., loop variables. @@ -119,7 +178,42 @@ def test_simple_application(): assert all(np.equal(wantA_2, gotA_2)) -def test_fails_due_to_dependency_on_loop_var(): +def test_extracts_even_with_intermediate_branchlike_structure(): + origA = np.zeros((10,), np.float32) + + g = make_branched_sdfg_that_has_intermediate_branchlike_structure() + g.save(os.path.join('_dacegraphs', 'intermediate_branch-0.sdfg')) + g.validate() + g.compile() + + # Get the expected values. + wantA_1 = deepcopy(origA) + wantA_2 = deepcopy(origA) + g(A=wantA_1, flag=True) + g(A=wantA_2, flag=False) + + # Before, the outer graph had only one nested SDFG. + assert len(g.nodes()) == 1 + + assert g.apply_transformations_repeated([IfExtraction]) == 1 + g.save(os.path.join('_dacegraphs', 'intermediate_branch-1.sdfg')) + + # Get the values from the transformed program. + gotA_1 = deepcopy(origA) + gotA_2 = deepcopy(origA) + g(A=gotA_1, flag=True) + g(A=gotA_2, flag=False) + + # But now, the outer graph has four: two copies of the original nested SDFGs and two for branch management. + assert len(g.nodes()) == 4 + assert g.start_state.is_empty() + + # Verify numerically.
+ assert all(np.equal(wantA_1, gotA_1)) + assert all(np.equal(wantA_2, gotA_2)) + + +def test_no_extraction_due_to_dependency_on_loop_var(): g = make_branched_sdfg_that_depends_on_loop_var() g.save(os.path.join('_dacegraphs', 'dependent-0.sdfg')) @@ -128,4 +222,4 @@ def test_fails_due_to_dependency_on_loop_var(): if __name__ == '__main__': test_simple_application() - test_fails_due_to_dependency_on_loop_var() + test_no_extraction_due_to_dependency_on_loop_var() From eb6b83853fb812a72eb98cc04dedbd1c936f3f9f Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Wed, 30 Oct 2024 00:01:49 +0100 Subject: [PATCH 10/10] Replace `add_memlet_path()` with `add_edge()`. --- tests/transformations/if_extraction_test.py | 47 ++++++++++++++------- 1 file changed, 31 insertions(+), 16 deletions(-) diff --git a/tests/transformations/if_extraction_test.py b/tests/transformations/if_extraction_test.py index cb37932849..e149a6a351 100644 --- a/tests/transformations/if_extraction_test.py +++ b/tests/transformations/if_extraction_test.py @@ -1,14 +1,29 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. 
import os from copy import deepcopy +from typing import Dict, Collection import numpy as np import dace -from dace import SDFG, InterstateEdge, Memlet +from dace import SDFG, InterstateEdge, Memlet, SDFGState from dace.transformation.interstate import IfExtraction +def _add_map_with_connectors(st: SDFGState, name: str, ndrange: Dict[str, str], + en_conn_bases: Collection[str] = None, ex_conn_bases: Collection[str] = None): + en, ex = st.add_map(name, ndrange) + if en_conn_bases: + for c in en_conn_bases: + en.add_in_connector(f"IN_{c}") + en.add_out_connector(f"OUT_{c}") + if ex_conn_bases: + for c in ex_conn_bases: + ex.add_in_connector(f"IN_{c}") + ex.add_out_connector(f"OUT_{c}") + return en, ex + + def make_branched_sdfg_that_does_not_depend_on_loop_var(): """ Construct a simple SDFG that does not depend on symbols defined or updated in the outer state, e.g., loop variables. @@ -27,7 +42,7 @@ def make_branched_sdfg_that_does_not_depend_on_loop_var(): subg.add_edge(if2, ift, InterstateEdge()) t0 = ift.add_tasklet('copy', inputs={}, outputs={'__out'}, code='__out = outval') tmp = ift.add_access('tmp') - ift.add_memlet_path(t0, tmp, src_conn='__out', memlet=Memlet(expr='tmp[0]')) + ift.add_edge(t0, '__out', tmp, None, Memlet(expr='tmp[0]')) subg.fill_scope_connectors() # Then prepare the parent graph. 
@@ -35,12 +50,12 @@ def make_branched_sdfg_that_does_not_depend_on_loop_var(): g.add_array('A', (10,), dace.float32) g.add_symbol('flag', dace.bool) st0 = g.add_state('outer', is_start_block=True) - en, ex = st0.add_map('map', {'i': '0:10'}) + en, ex = _add_map_with_connectors(st0, 'map', {'i': '0:10'}, [], ['A']) body = st0.add_nested_sdfg(subg, None, {}, {'tmp'}, symbol_mapping={'flag': 'flag'}) A = st0.add_access('A') - st0.add_memlet_path(en, body, memlet=Memlet()) - st0.add_memlet_path(body, ex, src_conn='tmp', dst_conn='IN_A', memlet=Memlet(expr='A[i]')) - st0.add_memlet_path(ex, A, src_conn='OUT_A', memlet=Memlet(expr='A[0:10]')) + st0.add_nedge(en, body, Memlet()) + st0.add_edge(body, 'tmp', ex, 'IN_A', Memlet(expr='A[i]')) + st0.add_edge(ex, 'OUT_A', A, None, Memlet(expr='A[0:10]')) g.fill_scope_connectors() return g @@ -86,7 +101,7 @@ def make_branched_sdfg_that_has_intermediate_branchlike_structure(): subg.add_edge(if7, ift, InterstateEdge()) t0 = ift.add_tasklet('copy', inputs={}, outputs={'__out'}, code='__out = outval') tmp = ift.add_access('tmp') - ift.add_memlet_path(t0, tmp, src_conn='__out', memlet=Memlet(expr='tmp[0]')) + ift.add_edge(t0, '__out', tmp, None, Memlet(expr='tmp[0]')) subg.fill_scope_connectors() # Then prepare the parent graph. 
@@ -94,12 +109,12 @@ def make_branched_sdfg_that_has_intermediate_branchlike_structure(): g.add_array('A', (10,), dace.float32) g.add_symbol('flag', dace.bool) st0 = g.add_state('outer', is_start_block=True) - en, ex = st0.add_map('map', {'i': '0:10'}) + en, ex = _add_map_with_connectors(st0, 'map', {'i': '0:10'}, [], ['A']) body = st0.add_nested_sdfg(subg, None, {}, {'tmp'}, symbol_mapping={'flag': 'flag'}) A = st0.add_access('A') - st0.add_memlet_path(en, body, memlet=Memlet()) - st0.add_memlet_path(body, ex, src_conn='tmp', dst_conn='IN_A', memlet=Memlet(expr='A[i]')) - st0.add_memlet_path(ex, A, src_conn='OUT_A', memlet=Memlet(expr='A[0:10]')) + st0.add_nedge(en, body, Memlet()) + st0.add_edge(body, 'tmp', ex, 'IN_A', Memlet(expr='A[i]')) + st0.add_edge(ex, 'OUT_A', A, None, Memlet(expr='A[0:10]')) g.fill_scope_connectors() return g @@ -123,19 +138,19 @@ def make_branched_sdfg_that_depends_on_loop_var(): subg.add_edge(if2, ift, InterstateEdge()) t0 = ift.add_tasklet('copy', inputs={}, outputs={'__out'}, code='__out = outval') tmp = ift.add_access('tmp') - ift.add_memlet_path(t0, tmp, src_conn='__out', memlet=Memlet(expr='tmp[0]')) + ift.add_edge(t0, '__out', tmp, None, Memlet(expr='tmp[0]')) subg.fill_scope_connectors() # Then prepare the parent graph. g = SDFG('prog') g.add_array('A', (10,), dace.float32) st0 = g.add_state('outer', is_start_block=True) - en, ex = st0.add_map('map', {'i': '0:10'}) + en, ex = _add_map_with_connectors(st0, 'map', {'i': '0:10'}, [], ['A']) body = st0.add_nested_sdfg(subg, None, {}, {'tmp'}) A = st0.add_access('A') - st0.add_memlet_path(en, body, memlet=Memlet()) - st0.add_memlet_path(body, ex, src_conn='tmp', dst_conn='IN_A', memlet=Memlet(expr='A[i]')) - st0.add_memlet_path(ex, A, src_conn='OUT_A', memlet=Memlet(expr='A[0:10]')) + st0.add_nedge(en, body, Memlet()) + st0.add_edge(body, 'tmp', ex, 'IN_A', Memlet(expr='A[i]')) + st0.add_edge(ex, 'OUT_A', A, None, Memlet(expr='A[0:10]')) g.fill_scope_connectors() return g