diff --git a/dace/subsets.py b/dace/subsets.py index 0fdc36c22e..9c79e7d7d1 100644 --- a/dace/subsets.py +++ b/dace/subsets.py @@ -80,7 +80,7 @@ def covers(self, other): # Subsets of different dimensionality can never cover each other. if self.dims() != other.dims(): return ValueError( - f"A subset of dimensionality {self.dim()} cannot test covering a subset of dimensionality {other.dims()}" + f"A subset of dimensionality {self.dims()} cannot test covering a subset of dimensionality {other.dims()}" ) if not Config.get('optimizer', 'symbolic_positive'): @@ -106,7 +106,7 @@ def covers_precise(self, other): # Subsets of different dimensionality can never cover each other. if self.dims() != other.dims(): return ValueError( - f"A subset of dimensionality {self.dim()} cannot test covering a subset of dimensionality {other.dims()}" + f"A subset of dimensionality {self.dims()} cannot test covering a subset of dimensionality {other.dims()}" ) # If self does not cover other with a bounding box union, return false. diff --git a/dace/transformation/auto/auto_optimize.py b/dace/transformation/auto/auto_optimize.py index 0c74842634..a195c0fab2 100644 --- a/dace/transformation/auto/auto_optimize.py +++ b/dace/transformation/auto/auto_optimize.py @@ -58,7 +58,10 @@ def greedy_fuse(graph_or_subgraph: GraphViewType, # If we have an SDFG, recurse into graphs graph_or_subgraph.simplify(validate_all=validate_all) # MapFusion for trivial cases - graph_or_subgraph.apply_transformations_repeated(MapFusion, validate_all=validate_all) + graph_or_subgraph.apply_transformations_repeated( + MapFusion(strict_dataflow=True), + validate_all=validate_all, + ) # recurse into graphs for graph in graph_or_subgraph.nodes(): @@ -76,7 +79,10 @@ def greedy_fuse(graph_or_subgraph: GraphViewType, sdfg, graph, subgraph = None, None, None if isinstance(graph_or_subgraph, SDFGState): sdfg = graph_or_subgraph.parent - sdfg.apply_transformations_repeated(MapFusion, validate_all=validate_all) + sdfg.apply_transformations_repeated( + MapFusion(strict_dataflow=True), + validate_all=validate_all, + ) graph = graph_or_subgraph subgraph = SubgraphView(graph, graph.nodes()) else: diff --git a/dace/transformation/dataflow/buffer_tiling.py b/dace/transformation/dataflow/buffer_tiling.py index a418e167d8..af966d8a32 100644 --- a/dace/transformation/dataflow/buffer_tiling.py +++ b/dace/transformation/dataflow/buffer_tiling.py @@ -98,7 +98,13 @@ def apply(self, graph, sdfg): # Fuse maps some_buffer = next(iter(buffers)) # some dummy to pass to MapFusion.apply_to() - MapFusion.apply_to(sdfg, first_map_exit=tile_map1_exit, array=some_buffer, second_map_entry=tile_map2_entry) + MapFusion.apply_to( + sdfg, + first_map_exit=tile_map1_exit, + array=some_buffer, + second_map_entry=tile_map2_entry, + verify=True, + ) # Optimize the simple cases map1_entry.range.ranges = [ diff --git a/dace/transformation/dataflow/map_fusion.py b/dace/transformation/dataflow/map_fusion.py index a6762d45c4..dccebe9727 100644 --- a/dace/transformation/dataflow/map_fusion.py +++ b/dace/transformation/dataflow/map_fusion.py @@ -1,537 +1,1777 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -""" This module contains classes that implement the map fusion transformation. 
-""" - -from copy import deepcopy as dcpy -from dace.sdfg.sdfg import SDFG -from dace.sdfg.state import SDFGState -from dace import data, dtypes, symbolic, subsets -from dace.sdfg import nodes -from dace.memlet import Memlet -from dace.sdfg import replace -from dace.sdfg import utils as sdutil -from dace.transformation import transformation -from typing import List, Union -import networkx as nx +"""Implements the serial map fusing transformation.""" +import copy +from typing import Any, Dict, List, Optional, Set, Tuple, Union, Iterable +import dace +from dace import data, dtypes, properties, subsets, symbolic, transformation +from dace.sdfg import SDFG, SDFGState, graph, nodes, validation +from dace.transformation import helpers + + +@properties.make_properties class MapFusion(transformation.SingleStateTransformation): - """ Implements the MapFusion transformation. - It wil check for all patterns MapExit -> AccessNode -> MapEntry, and - based on the following rules, fuse them and remove the transient in - between. There are several possibilities of what it does to this - transient in between. - - Essentially, if there is some other place in the - sdfg where it is required, or if it is not a transient, then it will - not be removed. In such a case, it will be linked to the MapExit node - of the new fused map. - - Rules for fusing maps: - 0. The map range of the second map should be a permutation of the - first map range. - 1. Each of the access nodes that are adjacent to the first map exit - should have an edge to the second map entry. If it doesn't, then the - second map entry should not be reachable from this access node. - 2. Any node that has a wcr from the first map exit should not be - adjacent to the second map entry. - 3. Access pattern for the access nodes in the second map should be - the same permutation of the map parameters as the map ranges of the - two maps. Alternatively, this access node should not be adjacent to - the first map entry. + """Implements the MapFusion transformation. + + + From a high level perspective it will remove the MapExit node for the first and the MapEntry node of + the second Map. Then it will rewire and modify the Memlets to bypass the intermediate node and instead + go through a new intermediate node. This new intermediate node is much smaller because it has no longer + to absorb the whole output of the first map, but only the data that is produced by a single iteration + of the first map. It is important to note that it is not always possible to fully remove the intermediate + node, for example it is used somewhere else, see `is_shared_data()`. Thus by merging the two Maps together + the transformation will reduce the memory footprint because the intermediate nodes can be removed. + An example would be the following: + ```python + for i in range(N): + T[i] = foo(A[i]) + for j in range(N): + B[j] = bar(T[i]) + ``` + which would be translated into: + ```python + for i in range(N): + temp: scalar = foo(A[i]) + B[i] = bar(temp) + ``` + + The checks that two Maps can be fused are quite involved, however, they essentially check: + * If the two Maps cover the same iteration space, essentially have the same start, stop and + iteration , see `find_parameter_remapping()`. + * Furthermore, they verify if the new fused Map did not introduce read write conflict, + essentially it tests if the data is pointwise, i.e. what is read is also written, + see `has_read_write_dependency()`. + * Then it will examine the intermediate data. 
This will essentially test if the data that + is needed by a single iteration of the second Map is produced by a single iteration of + the first Map, see `partition_first_outputs()`. + + By default `strict_dataflow` is enabled. In this mode the transformation is + more conservative. The main difference is, that it will not adjust the + subsets of the intermediate, i.e. turning an array with shape `(1, 1, 1, 1)` + into a scalar. + Furthermore, shared intermediates, see `partition_first_outputs()` will only + be created if the data is not referred downstream in the dataflow. + + :param only_inner_maps: Only match Maps that are internal, i.e. inside another Map. + :param only_toplevel_maps: Only consider Maps that are at the top. + :param strict_dataflow: Which dataflow mode should be used, see above. + :param assume_always_shared: Assume that all intermediates are shared. + + :note: This transformation modifies more nodes than it matches. + :note: An instance of MapFusion can be reused multiple times, with one exception. + Because the test if an intermediate can be removed or not is very expensive, + the transformation computes this information once in the beginning and then + caches it. However, the transformation lacks the means to detect if this data + has become out of data. Thus if new AccessNodes are added the cache is outdated + and the transformation should no longer be used. + :note: If `assume_always_shared` is `True` then the transformation will assume that + all intermediates are shared. This avoids the problems mentioned above with + the cache at the expense of the creation of dead dataflow. """ - first_map_exit = transformation.PatternNode(nodes.ExitNode) - array = transformation.PatternNode(nodes.AccessNode) - second_map_entry = transformation.PatternNode(nodes.EntryNode) - @staticmethod - def annotates_memlets(): - return False + # Pattern Nodes + first_map_exit = transformation.transformation.PatternNode(nodes.MapExit) + array = transformation.transformation.PatternNode(nodes.AccessNode) + second_map_entry = transformation.transformation.PatternNode(nodes.MapEntry) + + + # Settings + only_toplevel_maps = properties.Property( + dtype=bool, + default=False, + desc="Only perform fusing if the Maps are in the top level.", + ) + only_inner_maps = properties.Property( + dtype=bool, + default=False, + desc="Only perform fusing if the Maps are inner Maps, i.e. does not have top level scope.", + ) + strict_dataflow = properties.Property( + dtype=bool, + default=True, + desc="If `True` then the transformation will ensure a more stricter data flow.", + ) + assume_always_shared = properties.Property( + dtype=bool, + default=False, + desc="If `True` then all intermediates will be classified as shared.", + ) + + # Maps SDFGs to the set of data that can not be removed, + # because they transmit data _between states_, such data will be made 'shared'. + # This variable acts as a cache, and is managed by 'is_shared_data()'. 
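+    # For example (illustrative): a transient that is also read in another state,
+    #  or by an interstate edge, must survive the fusion; `is_shared_data()` would
+    #  classify such a container as 'shared' and record that verdict in this cache.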
+ _shared_data: Dict[SDFG, Set[str]] + + + def __init__( + self, + only_inner_maps: Optional[bool] = None, + only_toplevel_maps: Optional[bool] = None, + strict_dataflow: Optional[bool] = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if only_toplevel_maps is not None: + self.only_toplevel_maps = only_toplevel_maps + if only_inner_maps is not None: + self.only_inner_maps = only_inner_maps + if strict_dataflow is not None: + self.strict_dataflow = strict_dataflow + self._shared_data = {} + @classmethod - def expressions(cls): - return [sdutil.node_path_graph(cls.first_map_exit, cls.array, cls.second_map_entry)] - - @staticmethod - def find_permutation(first_map: nodes.Map, second_map: nodes.Map) -> Union[List[int], None]: - """ Find permutation between two map ranges. - - :param first_map: First map. - :param second_map: Second map. - :return: None if no such permutation exists, otherwise a list of - indices L such that L[x]'th parameter of second map has the same range as x'th - parameter of the first map. + def expressions(cls) -> Any: + """Get the match expression. + + The transformation matches the exit node of the top Map that is connected to + an access node that again is connected to the entry node of the second Map. + An important note is, that the transformation operates not just on the + matched nodes, but more or less on anything that has an incoming connection + from the first Map or an outgoing connection to the second Map entry. """ - result = [] - - if len(first_map.range) != len(second_map.range): - return None - - # Match map ranges with reduce ranges - for i, tmap_rng in enumerate(first_map.range): - found = False - for j, rng in enumerate(second_map.range): - if tmap_rng == rng and j not in result: - result.append(j) - found = True - break - if not found: - break - - # Ensure all map ranges matched - if len(result) != len(first_map.range): - return None + return [dace.sdfg.utils.node_path_graph(cls.first_map_exit, cls.array, cls.second_map_entry)] + + + def can_be_applied( + self, + graph: Union[SDFGState, SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Tests if the matched Maps can be merged. + + The two Maps are mergeable iff: + * Checks general requirements, see `can_topologically_be_fused()`. + * Tests if there are read write dependencies. + * Tests if the decomposition exists. + """ + # To ensures that the `{src,dst}_subset` are properly set, run initialization. + # See [issue 1708](https://github.com/spcl/dace/issues/1703) + for edge in graph.edges(): + edge.data.try_initialize(sdfg, graph, edge) + + first_map_entry: nodes.MapEntry = graph.entry_node(self.first_map_exit) + first_map_exit: nodes.MapExit = self.first_map_exit + second_map_entry: nodes.MapEntry = self.second_map_entry + + # Check the structural properties of the Maps. The function will return + # the `dict` that describes how the parameters must be renamed (for caching) + # or `None` if the maps can not be structurally fused. 
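+        # A small illustration of the returned mapping (hypothetical maps): for
+        #  `first_map` over `i=0:N` and `second_map` over `j=0:N` the function
+        #  would return `{'j': 'i'}`, i.e. renaming `j` to `i` makes the ranges
+        #  of the two maps identical; `None` would mean no such renaming exists.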
+ param_repl = self.can_topologically_be_fused( + first_map_entry=first_map_entry, + second_map_entry=second_map_entry, + graph=graph, + sdfg=sdfg + ) + if param_repl is None: + return False - return result - - def can_be_applied(self, graph, expr_index, sdfg: SDFG, permissive=False): - first_map_exit = self.first_map_exit - first_map_entry = graph.entry_node(first_map_exit) - second_map_entry = self.second_map_entry - second_map_exit = graph.exit_node(second_map_entry) - - for _in_e in graph.in_edges(first_map_exit): - if _in_e.data.wcr is not None: - for _out_e in graph.out_edges(second_map_entry): - if _out_e.data.data == _in_e.data.data: - # wcr is on a node that is used in the second map, quit - return False - # Check whether there is a pattern map -> access -> map. - intermediate_nodes = set() - intermediate_data = set() - for _, _, dst, _, _ in graph.out_edges(first_map_exit): - if isinstance(dst, nodes.AccessNode): - intermediate_nodes.add(dst) - intermediate_data.add(dst.data) - - # If array is used anywhere else in this state. - num_occurrences = len([n for n in sdfg.data_nodes() if n.data == dst.data]) - if num_occurrences > 1: - return False - else: - return False - # Check map ranges - perm = self.find_permutation(first_map_entry.map, second_map_entry.map) - if perm is None: + # Tests if there are read write dependencies that are caused by the bodies + # of the Maps, such as referring to the same data. Note that this tests are + # different from the ones performed by `has_read_write_dependency()`, which + # only checks the data dependencies that go through the scope nodes. + if self.has_inner_read_write_dependency( + first_map_entry=first_map_entry, + second_map_entry=second_map_entry, + state=graph, + sdfg=sdfg, + ): return False - # Check if any intermediate transient is also going to another location - second_inodes = set(e.src for e in graph.in_edges(second_map_entry) if isinstance(e.src, nodes.AccessNode)) - transients_to_remove = intermediate_nodes & second_inodes - # if any(e.dst != second_map_entry for n in transients_to_remove - # for e in graph.out_edges(n)): - if any(graph.out_degree(n) > 1 for n in transients_to_remove): + # Tests for read write conflicts of the two maps, this is only checking + # the data that goes through the scope nodes. `has_inner_read_write_dependency()` + # if used to check if there are internal dependencies. + if self.has_read_write_dependency( + first_map_entry=first_map_entry, + second_map_entry=second_map_entry, + param_repl=param_repl, + state=graph, + sdfg=sdfg, + ): return False - # Create a dict that maps parameters of the first map to those of the - # second map. - params_dict = {} - for _index, _param in enumerate(second_map_entry.map.params): - params_dict[_param] = first_map_entry.map.params[perm[_index]] + # Two maps can be serially fused if the node decomposition exists and + # at least one of the intermediate output sets is not empty. The state + # of the pure outputs is irrelevant for serial map fusion. 
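+        # Sketch of the decomposition (illustrative): for the dataflow
+        #   (first MapExit) -> T -> (second MapEntry)   and   (first MapExit) -> P
+        #  the edge to `P` is a pure output, while the edge to `T` belongs to the
+        #  exclusive set if `T` is used nowhere else, and to the shared set otherwise.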
+ output_partition = self.partition_first_outputs( + state=graph, + sdfg=sdfg, + first_map_exit=first_map_exit, + second_map_entry=second_map_entry, + param_repl=param_repl, + ) + if output_partition is None: + return False + _, exclusive_outputs, shared_outputs = output_partition + if not (exclusive_outputs or shared_outputs): + return False + return True - out_memlets = [e.data for e in graph.in_edges(first_map_exit)] - # Check that input set of second map is provided by the output set - # of the first map, or other unrelated maps - for second_edge in graph.out_edges(second_map_entry): - # NOTE: We ignore edges that do not carry data (e.g., connecting a tasklet with no inputs to the MapEntry) - if second_edge.data.is_empty(): - continue - # Memlets that do not come from one of the intermediate arrays - if second_edge.data.data not in intermediate_data: - # however, if intermediate_data eventually leads to - # second_memlet.data, need to fail. - for _n in intermediate_nodes: - source_node = _n - destination_node = graph.memlet_path(second_edge)[0].src - # NOTE: Assumes graph has networkx version - if destination_node in nx.descendants(graph._nx, source_node): - return False - continue + def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None: + """Performs the serial Map fusing. - provided = False + The function first computes the map decomposition and then handles the + three sets. The pure outputs are handled by `relocate_nodes()` while + the two intermediate sets are handled by `handle_intermediate_set()`. - # Compute second subset with respect to first subset's symbols - sbs_permuted = dcpy(second_edge.data.subset) - if sbs_permuted: - # Create intermediate dicts to avoid conflicts, such as {i:j, j:i} - symbolic.safe_replace(params_dict, lambda m: sbs_permuted.replace(m)) + By assumption we do not have to rename anything. - for first_memlet in out_memlets: - if first_memlet.data != second_edge.data.data: - continue + :param graph: The SDFG state we are operating on. + :param sdfg: The SDFG we are operating on. + """ + # To ensures that the `{src,dst}_subset` are properly set, run initialization. + # See [issue 1708](https://github.com/spcl/dace/issues/1703) + for edge in graph.edges(): + edge.data.try_initialize(sdfg, graph, edge) + + first_map_exit: nodes.MapExit = self.first_map_exit + second_map_entry: nodes.MapEntry = self.second_map_entry + second_map_exit: nodes.MapExit = graph.exit_node(self.second_map_entry) + first_map_entry: nodes.MapEntry = graph.entry_node(self.first_map_exit) + + # Before we do anything we perform the renaming. + self.rename_map_parameters( + first_map=first_map_exit.map, + second_map=second_map_entry.map, + second_map_entry=second_map_entry, + state=graph, + ) - # If there is a covered subset, it is provided - if first_memlet.subset.covers(sbs_permuted): - provided = True - break + # Now compute the partition. Because we have already renamed the parameters + # of the second Map, there is no need to perform any renaming, thus we can + # pass an empty `dict`. + output_partition = self.partition_first_outputs( + state=graph, + sdfg=sdfg, + first_map_exit=first_map_exit, + second_map_entry=second_map_entry, + param_repl=dict(), + ) + assert output_partition is not None # Make MyPy happy. + pure_outputs, exclusive_outputs, shared_outputs = output_partition + + # Now perform the actual rewiring, we handle each partition separately. 
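+        # Rough before/after picture of the rewiring (illustrative):
+        #   before: (body 1) -> MapExit_1 -> T -> MapEntry_2 -> (body 2) -> MapExit_2
+        #   after:  (body 1) -> t -> (body 2) -> MapExit_2 [ -> T, if T is shared ]
+        #  where `t` is the new, per-iteration sized intermediate.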
+ if len(exclusive_outputs) != 0: + self.handle_intermediate_set( + intermediate_outputs=exclusive_outputs, + state=graph, + sdfg=sdfg, + first_map_exit=first_map_exit, + second_map_entry=second_map_entry, + second_map_exit=second_map_exit, + is_exclusive_set=True, + ) + if len(shared_outputs) != 0: + self.handle_intermediate_set( + intermediate_outputs=shared_outputs, + state=graph, + sdfg=sdfg, + first_map_exit=first_map_exit, + second_map_entry=second_map_entry, + second_map_exit=second_map_exit, + is_exclusive_set=False, + ) + assert pure_outputs == set(graph.out_edges(first_map_exit)) + if len(pure_outputs) != 0: + self.relocate_nodes( + from_node=first_map_exit, + to_node=second_map_exit, + state=graph, + sdfg=sdfg, + ) - # If none of the output memlets of the first map provide the info, - # fail. - if provided is False: - return False + # Now move the input of the second map, that has no connection to the first + # map, to the first map. This is needed because we will later delete the + # exit of the first map (which we have essentially handled above). Now + # we must handle the input of the second map (that has no connection to the + # first map) to the input of the first map. + self.relocate_nodes( + from_node=second_map_entry, + to_node=first_map_entry, + state=graph, + sdfg=sdfg, + ) - # Checking for stencil pattern and common input/output data - # (after fusing the maps) - first_map_inputnodes = { - e.src: e.src.data - for e in graph.in_edges(first_map_entry) if isinstance(e.src, nodes.AccessNode) - } - input_views = set() - viewed_inputnodes = dict() - for n in first_map_inputnodes.keys(): - if isinstance(n.desc(sdfg), data.View): - input_views.add(n) - for v in input_views: - del first_map_inputnodes[v] - e = sdutil.get_view_edge(graph, v) - if e: - while not isinstance(e.src, nodes.AccessNode): - e = graph.memlet_path(e)[0] - first_map_inputnodes[e.src] = e.src.data - viewed_inputnodes[e.src.data] = v - second_map_outputnodes = { - e.dst: e.dst.data - for e in graph.out_edges(second_map_exit) if isinstance(e.dst, nodes.AccessNode) - } - output_views = set() - viewed_outputnodes = dict() - for n in second_map_outputnodes: - if isinstance(n.desc(sdfg), data.View): - output_views.add(n) - for v in output_views: - del second_map_outputnodes[v] - e = sdutil.get_view_edge(graph, v) - if e: - while not isinstance(e.dst, nodes.AccessNode): - e = graph.memlet_path(e)[-1] - second_map_outputnodes[e.dst] = e.dst.data - viewed_outputnodes[e.dst.data] = v - common_data = set(first_map_inputnodes.values()).intersection(set(second_map_outputnodes.values())) - if common_data: - input_data = [viewed_inputnodes[d].data if d in viewed_inputnodes.keys() else d for d in common_data] - input_accesses = [ - graph.memlet_path(e)[-1].data.src_subset for e in graph.out_edges(first_map_entry) - if e.data.data in input_data - ] - if len(input_accesses) > 1: - for i, a in enumerate(input_accesses[:-1]): - for b in input_accesses[i + 1:]: - if isinstance(a, subsets.Indices): - c = subsets.Range.from_indices(a) - c.offset(b, negative=True) - else: - c = a.offset_new(b, negative=True) - for r in c: - if r != (0, 0, 1): - return False - - output_data = [viewed_outputnodes[d].data if d in viewed_outputnodes.keys() else d for d in common_data] - output_accesses = [ - graph.memlet_path(e)[0].data.dst_subset for e in graph.in_edges(second_map_exit) - if e.data.data in output_data - ] - - # Compute output accesses with respect to first map's symbols - oacc_permuted = [dcpy(a) for a in output_accesses] - for a in 
oacc_permuted: - # Create intermediate dicts to avoid conflicts, such as {i:j, j:i} - symbolic.safe_replace(params_dict, lambda m: a.replace(m)) - - a = input_accesses[0] - for b in oacc_permuted: - if isinstance(a, subsets.Indices): - c = subsets.Range.from_indices(a) - c.offset(b, negative=True) - else: - c = a.offset_new(b, negative=True) - for r in c: - if r != (0, 0, 1): - return False + for node_to_remove in [first_map_exit, second_map_entry]: + assert graph.degree(node_to_remove) == 0 + graph.remove_node(node_to_remove) + + # Now turn the second output node into the output node of the first Map. + second_map_exit.map = first_map_entry.map + + + def partition_first_outputs( + self, + state: SDFGState, + sdfg: SDFG, + first_map_exit: nodes.MapExit, + second_map_entry: nodes.MapEntry, + param_repl: Dict[str, str], + ) -> Union[ + Tuple[ + Set[graph.MultiConnectorEdge[dace.Memlet]], + Set[graph.MultiConnectorEdge[dace.Memlet]], + Set[graph.MultiConnectorEdge[dace.Memlet]], + ], + None, + ]: + """Partition the output edges of `first_map_exit` for serial map fusion. + + The output edges of the first map are partitioned into three distinct sets, + defined as follows: + * Pure Output Set `\mathbb{P}`: + These edges exits the first map and does not enter the second map. These + outputs will be simply be moved to the output of the second map. + * Exclusive Intermediate Set `\mathbb{E}`: + Edges in this set leaves the first map exit, enters an access node, from + where a Memlet then leads immediately to the second map. The memory + referenced by this access node is not used anywhere else, thus it can + be removed. + * Shared Intermediate Set `\mathbb{S}`: + These edges are very similar to the one in `\mathbb{E}` except that they + are used somewhere else, thus they can not be removed and have to be + recreated as output of the second map. + + If strict data flow mode is enabled the function is rather strict if an + output can be added to either intermediate set and might fail to compute + the partition, even if it would exist. + + :return: If such a decomposition exists the function will return the three sets + mentioned above in the same order. In case the decomposition does not exist, + i.e. the maps can not be fused the function returns `None`. + + :param state: The in which the two maps are located. + :param sdfg: The full SDFG in whcih we operate. + :param first_map_exit: The exit node of the first map. + :param second_map_entry: The entry node of the second map. + :param param_repl: Use this map to rename the parameter of the second Map, such + that they match the one of the first map. + """ + # The three outputs set. + pure_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() + exclusive_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() + shared_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() + + # Set of intermediate nodes that we have already processed. + processed_inter_nodes: Set[nodes.Node] = set() + + # Now scan all output edges of the first exit and classify them + for out_edge in state.out_edges(first_map_exit): + intermediate_node: nodes.Node = out_edge.dst + + # We already processed the node, this should indicate that we should + # run simplify again, or we should start implementing this case. + # TODO(phimuell): Handle this case, already partially handled here. + if intermediate_node in processed_inter_nodes: + return None + processed_inter_nodes.add(intermediate_node) + + # The intermediate can only have one incoming degree. 
It might be possible
+            #  to handle multiple incoming edges, if they all come from the top map.
+            #  However, the resulting SDFG might be invalid.
+            # NOTE: Allow this to happen (in certain cases) if the only producer
+            #   is the top map.
+            if state.in_degree(intermediate_node) != 1:
+                return None
+
+            # If the second map is not reachable from the intermediate node, then
+            #  the output is pure and we can end here.
+            if not self.is_node_reachable_from(
+                    graph=state,
+                    begin=intermediate_node,
+                    end=second_map_entry,
+            ):
+                pure_outputs.add(out_edge)
+                continue
+
+            # The following tests are _after_ we have determined if we have a pure
+            #  output node, because this allows us to handle more exotic pure node
+            #  cases, as handling them is essentially rerouting an edge, whereas
+            #  handling intermediate nodes is much more complicated.
+
+            # Empty Memlets are only allowed if they are in `\mathbb{P}`, which
+            #  is also the only place they really make sense (for a map exit).
+            #  Thus if we now find an empty Memlet we reject it.
+            if out_edge.data.is_empty():
+                return None
+
+            # For us an intermediate node must always be an access node, because
+            #  everything else we do not know how to handle. It is important that
+            #  we do not test for non transient data here, because it can be
+            #  handled as a shared intermediate.
+            if not isinstance(intermediate_node, nodes.AccessNode):
+                return None
+            intermediate_desc: dace.data.Data = intermediate_node.desc(sdfg)
+            if self.is_view(intermediate_desc, sdfg):
+                return None
+
+            # It can happen that multiple edges converge at the `IN_` connector
+            #  of the first map exit, but there is only one edge leaving the exit.
+            #  It is complicated to handle this, so for now we ignore it.
+            # TODO(phimuell): Handle this case properly.
+            #   To handle this we need to associate a consumer edge (the outgoing edges
+            #   of the second map) with exactly one producer.
+            producer_edges: List[graph.MultiConnectorEdge[dace.Memlet]] = list(
+                state.in_edges_by_connector(first_map_exit, "IN_" + out_edge.src_conn[4:]))
+            if len(producer_edges) > 1:
+                return None
+
+            # Now check the constraints we have on the producers.
+            #   - The source of the producer can not be a view (we do not handle this).
+            #   - The edge shall also not be a reduction edge.
+            #   - Defined location to where they write.
+            #   - No dynamic Memlets.
+            #  Furthermore, we will also extract the subsets, i.e. the location they
+            #  modify inside the intermediate array.
+            #  Since we do not allow for WCR, we do not check if the producer subsets intersect.
+            producer_subsets: List[subsets.Subset] = []
+            for producer_edge in producer_edges:
+                if isinstance(producer_edge.src, nodes.AccessNode) and self.is_view(producer_edge.src, sdfg):
+                    return None
+                if producer_edge.data.dynamic:
+                    # TODO(phimuell): Find out if this restriction could be lifted, but it is unlikely.
+                    return None
+                if producer_edge.data.wcr is not None:
+                    return None
+                if producer_edge.data.dst_subset is None:
+                    return None
+                producer_subsets.append(producer_edge.data.dst_subset)
+
+            # Check that the producers do not intersect.
+            if len(producer_subsets) == 1:
+                pass
+            elif len(producer_subsets) == 2:
+                if producer_subsets[0].intersects(producer_subsets[1]):
+                    return None
+            else:
+                for i, psbs1 in enumerate(producer_subsets):
+                    for j, psbs2 in enumerate(producer_subsets):
+                        if i == j:
+                            continue
+                        if psbs1.intersects(psbs2):
+                            return None
+
+            # Now we determine the consumers of the intermediate node. For this we are
+            #  using the edges that leave the second map entry. It is not necessary
+            #  to find the actual consumer nodes, as they might depend on symbols of
+            #  nested Maps. For the covering test we only need their subsets, but we
+            #  will perform some scan and filtering on them.
+            found_second_map = False
+            consumer_subsets: List[subsets.Subset] = []
+            for intermediate_node_out_edge in state.out_edges(intermediate_node):
+
+                # If the second map entry is not immediately reachable from the intermediate
+                #  node, then ensure that there is no path that goes to it.
+                if intermediate_node_out_edge.dst is not second_map_entry:
+                    if self.is_node_reachable_from(graph=state, begin=intermediate_node_out_edge.dst, end=second_map_entry):
+                        return None
+                    continue
+
+                # Ensure that the second map is found exactly once.
+                # TODO(phimuell): Lift this restriction.
+                if found_second_map:
+                    return None
+                found_second_map = True
+
+                # The output of the top map can not define a dynamic map range in the
+                #  second map.
+                if not intermediate_node_out_edge.dst_conn.startswith("IN_"):
+                    return None
+
+                # Now we look at all edges that leave the second map entry, i.e. the
+                #  edges that feed the consumers and define what is read inside the map.
+                #  We do not check them here, but collect them and inspect them later.
+                # NOTE1: The subsets still use the old iteration variables.
+                # NOTE2: In case of a consumer Memlet we explicitly allow dynamic Memlets.
+                #   This is different compared to the producer Memlets. The reason is
+                #   that in a consumer the data is conditionally read, so the data
+                #   has to exist anyway.
+                for inner_consumer_edge in state.out_edges_by_connector(
+                        second_map_entry, "OUT_" + intermediate_node_out_edge.dst_conn[3:]):
+                    if inner_consumer_edge.data.src_subset is None:
+                        return None
+                    consumer_subsets.append(inner_consumer_edge.data.src_subset)
+            assert found_second_map, f"Found '{intermediate_node}' which looked like a pure node, but is not one."
+            assert len(consumer_subsets) != 0
+
+            # The consumers still use the original symbols of the second map, so we must rename them.
+            if param_repl:
+                consumer_subsets = copy.deepcopy(consumer_subsets)
+                for consumer_subset in consumer_subsets:
+                    symbolic.safe_replace(mapping=param_repl, replace_callback=consumer_subset.replace)
+
+            # Now we are checking if a single iteration of the first (top) map
+            #  can satisfy all data requirements of the second (bottom) map.
+            #  For this we look if the producer covers the consumer. A consumer must
+            #  be covered by exactly one producer.
+            for consumer_subset in consumer_subsets:
+                nb_coverings = sum(producer_subset.covers(consumer_subset) for producer_subset in producer_subsets)
+                if nb_coverings != 1:
+                    return None
+
+            # After we have ensured coverage, we have to decide if the intermediate
+            #  node can be removed (`\mathbb{E}`) or has to be restored (`\mathbb{S}`).
+            #  Note that "removed" here means that it is reconstructed by a new
+            #  output of the second map.
+            if self.is_shared_data(intermediate_node, sdfg):
+                # The intermediate data is used somewhere else, either in this or another state.
+                # NOTE: If the intermediate is shared, then we will turn it into a
+                #   sink node attached to the combined map exit. Technically this
+                #   should be enough, even if the same data appears again in the
+                #   dataflow downstream. However, some DaCe transformations,
+                #   I am looking at you `auto_optimizer()`, do not like that. Thus
+                #   if the intermediate is used further down in the same dataflow,
+                #   then we consider that the maps can not be fused.
But we only + # do this in the strict data flow mode. + if self.strict_dataflow: + if self._is_data_accessed_downstream( + data=intermediate_node.data, + graph=state, + begin=intermediate_node, # is ignored itself. + ): + return None + shared_outputs.add(out_edge) + else: + # The intermediate can be removed, as it is not used anywhere else. + exclusive_outputs.add(out_edge) + + assert len(processed_inter_nodes) == sum(len(x) for x in [pure_outputs, exclusive_outputs, shared_outputs]) + return (pure_outputs, exclusive_outputs, shared_outputs) + + + def relocate_nodes( + self, + from_node: Union[nodes.MapExit, nodes.MapEntry], + to_node: Union[nodes.MapExit, nodes.MapEntry], + state: SDFGState, + sdfg: SDFG, + ) -> None: + """Move the connectors and edges from `from_node` to `to_nodes` node. + + This function will only rewire the edges, it does not remove the nodes + themselves. Furthermore, this function should be called twice per Map, + once for the entry and then for the exit. + While it does not remove the node themselves if guarantees that the + `from_node` has degree zero. + The function assumes that the parameter renaming was already done. + + :param from_node: Node from which the edges should be removed. + :param to_node: Node to which the edges should reconnect. + :param state: The state in which the operation happens. + :param sdfg: The SDFG that is modified. """ - This method applies the mapfusion transformation. - Other than the removal of the second map entry node (SME), and the first - map exit (FME) node, it has the following side effects: - 1. Any transient adjacent to both FME and SME with degree = 2 will be removed. - The tasklets that use/produce it shall be connected directly with a - scalar/new transient (if the dataflow is more than a single scalar) + # Now we relocate empty Memlets, from the `from_node` to the `to_node` + for empty_edge in list(filter(lambda e: e.data.is_empty(), state.out_edges(from_node))): + helpers.redirect_edge(state, empty_edge, new_src=to_node) + for empty_edge in list(filter(lambda e: e.data.is_empty(), state.in_edges(from_node))): + helpers.redirect_edge(state, empty_edge, new_dst=to_node) + + # We now ensure that there is only one empty Memlet from the `to_node` to any other node. + # Although it is allowed, we try to prevent it. + empty_targets: Set[nodes.Node] = set() + for empty_edge in list(filter(lambda e: e.data.is_empty(), state.all_edges(to_node))): + if empty_edge.dst in empty_targets: + state.remove_edge(empty_edge) + empty_targets.add(empty_edge.dst) + + # We now determine which edges we have to migrate, for this we are looking at + # the incoming edges, because this allows us also to detect dynamic map ranges. + # TODO(phimuell): If there is already a connection to the node, reuse this. + for edge_to_move in list(state.in_edges(from_node)): + assert isinstance(edge_to_move.dst_conn, str) + + if not edge_to_move.dst_conn.startswith("IN_"): + # Dynamic Map Range + # The connector name simply defines a variable name that is used, + # inside the Map scope to define a variable. We handle it directly. + dmr_symbol = edge_to_move.dst_conn + + # TODO(phimuell): Check if the symbol is really unused in the target scope. 
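+                # Illustrative example of a dynamic map range (hypothetical names):
+                #  an edge (bound_node) --Memlet 'bound[0]'--> (MapEntry, dst_conn='N')
+                #  ends in the plain connector `N` instead of an `IN_*` connector and
+                #  makes `N` usable inside the scope, e.g. in the range `i=0:N`.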
+ if dmr_symbol in to_node.in_connectors: + raise NotImplementedError(f"Tried to move the dynamic map range '{dmr_symbol}' from {from_node}'" + f" to '{to_node}', but the symbol is already known there, but the" + " renaming is not implemented.") + if not to_node.add_in_connector(dmr_symbol, force=False): + raise RuntimeError( # Might fail because of out connectors. + f"Failed to add the dynamic map range symbol '{dmr_symbol}' to '{to_node}'.") + helpers.redirect_edge(state=state, edge=edge_to_move, new_dst=to_node) + from_node.remove_in_connector(dmr_symbol) - 2. If this transient is adjacent to FME and SME and has other - uses, it will be adjacent to the new map exit post fusion. - Tasklet-> Tasklet edges will ALSO be added as mentioned above. + else: + # We have a Passthrough connection, i.e. there exists a matching `OUT_`. + old_conn = edge_to_move.dst_conn[3:] # The connection name without prefix + new_conn = to_node.next_connector(old_conn) + + to_node.add_in_connector("IN_" + new_conn) + for e in list(state.in_edges_by_connector(from_node, "IN_" + old_conn)): + helpers.redirect_edge(state, e, new_dst=to_node, new_dst_conn="IN_" + new_conn) + to_node.add_out_connector("OUT_" + new_conn) + for e in list(state.out_edges_by_connector(from_node, "OUT_" + old_conn)): + helpers.redirect_edge(state, e, new_src=to_node, new_src_conn="OUT_" + new_conn) + from_node.remove_in_connector("IN_" + old_conn) + from_node.remove_out_connector("OUT_" + old_conn) + + # Check if we succeeded. + if state.out_degree(from_node) != 0: + raise validation.InvalidSDFGError( + f"Failed to relocate the outgoing edges from `{from_node}`, there are still `{state.out_edges(from_node)}`", + sdfg, + sdfg.node_id(state), + ) + if state.in_degree(from_node) != 0: + raise validation.InvalidSDFGError( + f"Failed to relocate the incoming edges from `{from_node}`, there are still `{state.in_edges(from_node)}`", + sdfg, + sdfg.node_id(state), + ) + assert len(from_node.in_connectors) == 0 + assert len(from_node.out_connectors) == 0 + + + def handle_intermediate_set( + self, + intermediate_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]], + state: SDFGState, + sdfg: SDFG, + first_map_exit: nodes.MapExit, + second_map_entry: nodes.MapEntry, + second_map_exit: nodes.MapExit, + is_exclusive_set: bool, + ) -> None: + """This function handles the intermediate sets. + + The function is able to handle both the shared and exclusive intermediate + output set, see `partition_first_outputs()`. The main difference is that + in exclusive mode the intermediate nodes will be fully removed from + the SDFG. While in shared mode the intermediate node will be preserved. + The function assumes that the parameter renaming was already done. + + :param intermediate_outputs: The set of outputs, that should be processed. + :param state: The state in which the map is processed. + :param sdfg: The SDFG that should be optimized. + :param first_map_exit: The exit of the first/top map. + :param second_map_entry: The entry of the second map. + :param second_map_exit: The exit of the second map. + :param is_exclusive_set: If `True` `intermediate_outputs` is the exclusive set. + + :note: Before the transformation the `state` does not have to be valid and + after this function has run the state is (most likely) invalid. + """ - 3. If an access node is adjacent to FME but not SME, it will be - adjacent to new map exit post fusion. + map_params = first_map_exit.map.params.copy() + + # Now we will iterate over all intermediate edges and process them. 
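+        # At a high level, each edge is handled in five steps (sketch): compute the
+        #  reduced shape, allocate the new intermediate, rewire the producer tree,
+        #  rewire the consumer tree, and finally delete (exclusive set) or restore
+        #  (shared set) the old intermediate node.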
+ # If not stated otherwise the comments assume that we run in exclusive mode. + for out_edge in intermediate_outputs: + # This is the intermediate node that, that we want to get rid of. + # In shared mode we want to recreate it after the second map. + inter_node: nodes.AccessNode = out_edge.dst + inter_name = inter_node.data + inter_desc = inter_node.desc(sdfg) + inter_shape = inter_desc.shape + + # Now we will determine the shape of the new intermediate. This size of + # this temporary is given by the Memlet that goes into the first map exit. + pre_exit_edges = list(state.in_edges_by_connector(first_map_exit, "IN_" + out_edge.src_conn[4:])) + if len(pre_exit_edges) != 1: + raise NotImplementedError() + pre_exit_edge = pre_exit_edges[0] + + (new_inter_shape_raw, new_inter_shape, squeezed_dims) = self.compute_reduced_intermediate( + producer_subset=pre_exit_edge.data.dst_subset, + inter_desc=inter_desc, + ) - 4. If an access node is adjacent to SME but not FME, it will be - adjacent to the new map entry node post fusion. + # This is the name of the new "intermediate" node that we will create. + # It will only have the shape `new_inter_shape` which is basically its + # output within one Map iteration. + # NOTE: The insertion process might generate a new name. + new_inter_name: str = f"__s{self.state_id}_n{state.node_id(out_edge.src)}{out_edge.src_conn}_n{state.node_id(out_edge.dst)}{out_edge.dst_conn}" + + # Now generate the intermediate data container. + if len(new_inter_shape) == 0: + assert pre_exit_edge.data.subset.num_elements() == 1 + is_scalar = True + new_inter_name, new_inter_desc = sdfg.add_scalar( + new_inter_name, + dtype=inter_desc.dtype, + transient=True, + find_new_name=True, + ) - """ - first_exit = self.first_map_exit - first_entry = graph.entry_node(first_exit) - second_entry = self.second_map_entry - second_exit = graph.exit_node(second_entry) - - intermediate_nodes = set() - for _, _, dst, _, _ in graph.out_edges(first_exit): - intermediate_nodes.add(dst) - assert isinstance(dst, nodes.AccessNode) - - # Check if an access node refers to non transient memory, or transient - # is used at another location (cannot erase) - do_not_erase = set() - for node in intermediate_nodes: - if sdfg.arrays[node.data].transient is False: - do_not_erase.add(node) else: - for edge in graph.in_edges(node): - if edge.src != first_exit: - do_not_erase.add(node) - break + assert (pre_exit_edge.data.subset.num_elements() > 1) or all(x == 1 for x in new_inter_shape) + is_scalar = False + new_inter_name, new_inter_desc = sdfg.add_transient( + new_inter_name, + shape=new_inter_shape, + dtype=inter_desc.dtype, + find_new_name=True, + ) + new_inter_node: nodes.AccessNode = state.add_access(new_inter_name) + + # Get the subset that defined into which part of the old intermediate + # the old output edge wrote to. We need that to adjust the producer + # Memlets, since they now write into the new (smaller) intermediate. + producer_offset = self.compute_offset_subset( + original_subset=pre_exit_edge.data.dst_subset, + intermediate_desc=inter_desc, + map_params=map_params, + producer_offset=None, + ) + + # Memlets have a lot of additional informations, to ensure that we get + # all of them, we have to do it this way. The main reason for this is + # to handle the case were the "Memlet reverse direction", i.e. `data` + # refers to the other end of the connection than before. 
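+            # Illustrative example of the "reverse direction" case: for an edge
+            #  (tasklet) -> (T), the Memlet may carry `data='T'`, so `subset` then
+            #  describes the write side; only `dst_subset` reliably answers "where
+            #  is the data written", regardless of which end `data` names.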
+ assert pre_exit_edge.data.dst_subset is not None + new_pre_exit_memlet_src_subset = copy.deepcopy(pre_exit_edge.data.src_subset) + new_pre_exit_memlet_dst_subset = subsets.Range.from_array(new_inter_desc) + + new_pre_exit_memlet = copy.deepcopy(pre_exit_edge.data) + new_pre_exit_memlet.data = new_inter_name + + new_pre_exit_edge = state.add_edge( + pre_exit_edge.src, + pre_exit_edge.src_conn, + new_inter_node, + None, + new_pre_exit_memlet, + ) + + # We can update `{src, dst}_subset` only after we have inserted the + # edge, this is because the direction of the Memlet might change. + new_pre_exit_edge.data.src_subset = new_pre_exit_memlet_src_subset + new_pre_exit_edge.data.dst_subset = new_pre_exit_memlet_dst_subset + + # We now handle the MemletTree defined by this edge. + # The newly created edge, only handled the last collection step. + for producer_tree in state.memlet_tree(new_pre_exit_edge).traverse_children(include_self=False): + producer_edge = producer_tree.edge + + # In order to preserve the intrinsic direction of Memlets we only have to change + # the `.data` attribute of the producer Memlet if it refers to the old intermediate. + # If it refers to something different we keep it. Note that this case can only + # occur if the producer is an AccessNode. + if producer_edge.data.data == inter_name: + producer_edge.data.data = new_inter_name + + # Regardless of the intrinsic direction of the Memlet, the subset we care about + # is always `dst_subset`. + if is_scalar: + producer_edge.data.dst_subset = "0" + elif producer_edge.data.dst_subset is not None: + # Since we now write into a smaller memory patch, we must + # compensate for that. We do this by substracting where the write + # originally had begun. + producer_edge.data.dst_subset.offset(producer_offset, negative=True) + producer_edge.data.dst_subset.pop(squeezed_dims) + + # Now after we have handled the input of the new intermediate node, + # we must handle its output. For this we have to "inject" the newly + # created intermediate into the second map. We do this by finding + # the input connectors on the map entry, such that we know where we + # have to reroute inside the Map. + # NOTE: Assumes that map (if connected is the direct neighbour). + conn_names: Set[str] = set() + for inter_node_out_edge in state.out_edges(inter_node): + if inter_node_out_edge.dst == second_map_entry: + assert inter_node_out_edge.dst_conn.startswith("IN_") + conn_names.add(inter_node_out_edge.dst_conn) else: - for edge in graph.out_edges(node): - if edge.dst != second_entry: - do_not_erase.add(node) - break - - # Find permutation between first and second scopes - perm = self.find_permutation(first_entry.map, second_entry.map) - params_dict = {} - for index, param in enumerate(first_entry.map.params): - params_dict[param] = second_entry.map.params[perm[index]] - - # Replaces (in memlets and tasklet) the second scope map - # indices with the permuted first map indices. - # This works in two passes to avoid problems when e.g., exchanging two - # parameters (instead of replacing (j,i) and (i,j) to (j,j) and then - # i,i). 
- second_scope = graph.scope_subgraph(second_entry) - for firstp, secondp in params_dict.items(): - if firstp != secondp: - replace(second_scope, secondp, '__' + secondp + '_fused') - for firstp, secondp in params_dict.items(): - if firstp != secondp: - replace(second_scope, '__' + secondp + '_fused', firstp) - - # Isolate First exit node - ############################ - edges_to_remove = set() - nodes_to_remove = set() - for edge in graph.in_edges(first_exit): - tree = graph.memlet_tree(edge) - access_node = tree.root().edge.dst - if access_node not in do_not_erase: - out_edges = [e for e in graph.out_edges(access_node) if e.dst == second_entry] - # In this transformation, there can only be one edge to the - # second map - assert len(out_edges) == 1 - - # Get source connector to the second map - connector = out_edges[0].dst_conn[3:] - - new_dsts = [] - # Look at the second map entry out-edges to get the new - # destinations - for e in graph.out_edges(second_entry): - if e.src_conn and e.src_conn[4:] == connector: - new_dsts.append(e) - if not new_dsts: # Access node is not used in the second map - nodes_to_remove.add(access_node) - continue + # If we found another target than the second map entry from the + # intermediate node it means that the node _must_ survive, + # i.e. we are not in exclusive mode. + assert not is_exclusive_set + + # Now we will reroute the connections inside the second map, i.e. + # instead of consuming the old intermediate node, they will now + # consume the new intermediate node. + for in_conn_name in conn_names: + out_conn_name = "OUT_" + in_conn_name[3:] + + for inner_edge in state.out_edges_by_connector(second_map_entry, out_conn_name): + # As for the producer side, we now read from a smaller array, + # So we must offset them, we use the original edge for this. + assert inner_edge.data.src_subset is not None + consumer_offset = self.compute_offset_subset( + original_subset=inner_edge.data.src_subset, + intermediate_desc=inter_desc, + map_params=map_params, + producer_offset=producer_offset, + ) + + # Now create the memlet for the new consumer. To make sure that we get all attributes + # of the Memlet we make a deep copy of it. There is a tricky part here, we have to + # access `src_subset` however, this is only correctly set once it is put inside the + # SDFG. Furthermore, we have to make sure that the Memlet does not change its direction. + # i.e. that the association of `subset` and `other_subset` does not change. For this + # reason we only modify `.data` attribute of the Memlet if its name refers to the old + # intermediate. Furthermore, to play it safe, we only access the subset, `src_subset` + # after we have inserted it to the SDFG. + new_inner_memlet = copy.deepcopy(inner_edge.data) + if inner_edge.data.data == inter_name: + new_inner_memlet.data = new_inter_name + + # Now we replace the edge from the SDFG. + state.remove_edge(inner_edge) + new_inner_edge = state.add_edge( + new_inter_node, + None, + inner_edge.dst, + inner_edge.dst_conn, + new_inner_memlet, + ) + + # Now modifying the Memlet, we do it after the insertion to make + # sure that the Memlet was properly initialized. + if is_scalar: + new_inner_memlet.subset = "0" + elif new_inner_memlet.src_subset is not None: + # TODO(phimuell): Figuring out if `src_subset` is None is an error. + new_inner_memlet.src_subset.offset(consumer_offset, negative=True) + new_inner_memlet.src_subset.pop(squeezed_dims) + + # Now we have to make sure that all consumers are properly updated. 
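+                    # E.g. (illustrative): if the first map wrote `T[i, 2:6]`, the new
+                    #  intermediate `t` covers only those four elements, so a nested
+                    #  consumer Memlet `T[i, 3]` becomes a read of element `1` of `t`,
+                    #  after renaming, offsetting and popping any squeezed dimensions.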
+ for consumer_tree in state.memlet_tree(new_inner_edge).traverse_children(include_self=False): + consumer_edge = consumer_tree.edge + + # We only modify the data if the Memlet refers to the old intermediate data. + # We can not do this unconditionally, because it might change the intrinsic + # direction of a Memlet and then `src_subset` would at the next `try_initialize` + # be wrong. Note that this case only occurs if the destination is an AccessNode. + if consumer_edge.data.data == inter_name: + consumer_edge.data.data = new_inter_name + + # Now we have to adapt the subsets. + if is_scalar: + consumer_edge.data.src_subset = "0" + elif consumer_edge.data.src_subset is not None: + # TODO(phimuell): Figuring out if `src_subset` is None is an error. + consumer_edge.data.src_subset.offset(consumer_offset, negative=True) + consumer_edge.data.src_subset.pop(squeezed_dims) + + # The edge that leaves the second map entry was already deleted. We now delete + # the edges that connected the intermediate node with the second map entry. + for edge in list(state.in_edges_by_connector(second_map_entry, in_conn_name)): + assert edge.src == inter_node + state.remove_edge(edge) + second_map_entry.remove_in_connector(in_conn_name) + second_map_entry.remove_out_connector(out_conn_name) + + if is_exclusive_set: + # In exclusive mode the old intermediate node is no longer needed. + # This will also remove `out_edge` from the SDFG. + assert state.degree(inter_node) == 1 + state.remove_edge_and_connectors(out_edge) + state.remove_node(inter_node) + + state.remove_edge(pre_exit_edge) + first_map_exit.remove_in_connector(pre_exit_edge.dst_conn) + first_map_exit.remove_out_connector(out_edge.src_conn) + del sdfg.arrays[inter_name] - # Add a transient scalar/array - self.fuse_nodes(sdfg, graph, edge, new_dsts[0].dst, new_dsts[0].dst_conn, new_dsts[1:]) - - edges_to_remove.add(edge) - - # Remove transient node between the two maps - nodes_to_remove.add(access_node) - else: # The case where intermediate array node cannot be removed - # Node will become an output of the second map exit - out_e = tree.parent.edge - conn = second_exit.next_connector() - graph.add_edge( - second_exit, - 'OUT_' + conn, - out_e.dst, - out_e.dst_conn, - dcpy(out_e.data), + else: + # TODO(phimuell): Lift this restriction + assert pre_exit_edge.data.data == inter_name + + # This is the shared mode, so we have to recreate the intermediate + # node, but this time it is at the exit of the second map. + state.remove_edge(pre_exit_edge) + first_map_exit.remove_in_connector(pre_exit_edge.dst_conn) + + # This is the Memlet that goes from the map internal intermediate + # temporary node to the Map output. This will essentially restore + # or preserve the output for the intermediate node. It is important + # that we use the data that `preExitEdge` was used. 
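+            # Resulting dataflow for a shared intermediate (sketch):
+            #  ... -> t -> (fused MapExit) -> T -> (original downstream consumers),
+            #  i.e. `T` is re-created as a sink of the fused map's exit.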
+ final_pre_exit_memlet = copy.deepcopy(pre_exit_edge.data) + final_pre_exit_memlet.other_subset = subsets.Range.from_array(new_inter_desc) + + new_pre_exit_conn = second_map_exit.next_connector() + state.add_edge( + new_inter_node, + None, + second_map_exit, + "IN_" + new_pre_exit_conn, + final_pre_exit_memlet, ) - second_exit.add_out_connector('OUT_' + conn) - - graph.add_edge(edge.src, edge.src_conn, second_exit, 'IN_' + conn, dcpy(edge.data)) - second_exit.add_in_connector('IN_' + conn) - - edges_to_remove.add(out_e) - edges_to_remove.add(edge) - - # If the second map needs this node, link the connector - # that generated this to the place where it is needed, with a - # temp transient/scalar for memlet to be generated - for out_e in graph.out_edges(second_entry): - second_memlet_path = graph.memlet_path(out_e) - source_node = second_memlet_path[0].src - if source_node == access_node: - self.fuse_nodes(sdfg, graph, edge, out_e.dst, out_e.dst_conn) - - ### - # First scope exit is isolated and can now be safely removed - for e in edges_to_remove: - graph.remove_edge(e) - graph.remove_nodes_from(nodes_to_remove) - graph.remove_node(first_exit) - - # Isolate second_entry node - ########################### - for edge in graph.in_edges(second_entry): - tree = graph.memlet_tree(edge) - access_node = tree.root().edge.src - if access_node in intermediate_nodes: - # Already handled above, can be safely removed - graph.remove_edge(edge) - continue - - # This is an external input to the second map which will now go - # through the first map. - conn = first_entry.next_connector() - graph.add_edge(edge.src, edge.src_conn, first_entry, 'IN_' + conn, dcpy(edge.data)) - first_entry.add_in_connector('IN_' + conn) - graph.remove_edge(edge) - for out_enode in tree.children: - out_e = out_enode.edge - graph.add_edge( - first_entry, - 'OUT_' + conn, - out_e.dst, - out_e.dst_conn, - dcpy(out_e.data), + state.add_edge( + second_map_exit, + "OUT_" + new_pre_exit_conn, + inter_node, + out_edge.dst_conn, + copy.deepcopy(out_edge.data), ) - graph.remove_edge(out_e) - first_entry.add_out_connector('OUT_' + conn) - - # NOTE: Check the second MapEntry for output edges with empty memlets - for edge in graph.out_edges(second_entry): - if edge.data.is_empty(): - graph.remove_edge(edge) - graph.add_edge(first_entry, edge.src_conn, edge.dst, edge.dst_conn, edge.data) - - ### - # Second node is isolated and can now be safely removed - graph.remove_node(second_entry) - - # Fix scope exit to point to the right map - second_exit.map = first_entry.map - - def fuse_nodes(self, sdfg: SDFG, graph: SDFGState, edge, new_dst, new_dst_conn, other_edges=None): - """ Fuses two nodes via memlets and possibly transient arrays. """ - other_edges = other_edges or [] - memlet_path = graph.memlet_path(edge) - access_node = memlet_path[-1].dst - - local_name = "__s%d_n%d%s_n%d%s" % ( - self.state_id, - graph.node_id(edge.src), - edge.src_conn, - graph.node_id(edge.dst), - edge.dst_conn, - ) - # Add intermediate memory between subgraphs. - # If a scalar, uses direct connection. If an array, adds a transient node. - # NOTE: If any of the src/dst nodes is a nested SDFG, treat it as an array. 
- is_scalar = edge.data.subset.num_elements() == 1 - accesses = ( - [graph.memlet_path(e1)[0].src for e0 in graph.in_edges(access_node) for e1 in graph.memlet_tree(e0)] + - [graph.memlet_path(e1)[-1].dst for e0 in graph.out_edges(access_node) for e1 in graph.memlet_tree(e0)]) - if any(isinstance(a, nodes.NestedSDFG) for a in accesses): - is_scalar = False - if is_scalar: - local_name, _ = sdfg.add_scalar( - local_name, - dtype=access_node.desc(graph).dtype, - transient=True, - storage=dtypes.StorageType.Register, - find_new_name=True, + second_map_exit.add_in_connector("IN_" + new_pre_exit_conn) + second_map_exit.add_out_connector("OUT_" + new_pre_exit_conn) + + first_map_exit.remove_out_connector(out_edge.src_conn) + state.remove_edge(out_edge) + + + def compute_reduced_intermediate( + self, + producer_subset: subsets.Range, + inter_desc: dace.data.Data, + ) -> Tuple[Tuple[int, ...], Tuple[int, ...], List[int]]: + """Compute the size of the new (reduced) intermediate. + + `MapFusion` does not only fuses map, but, depending on the situation, also + eliminates intermediate arrays between the two maps. To transmit data between + the two maps a new, but much smaller intermediate is needed. + + :return: The function returns a tuple with three values with the following meaning: + * The raw shape of the reduced intermediate. + * The cleared shape of the reduced intermediate, essentially the raw shape + with all shape 1 dimensions removed. + * Which dimensions of the raw shape have been removed to get the cleared shape. + + :param producer_subset: The subset that was used to write into the intermediate. + :param inter_desc: The data descriptor for the intermediate. + """ + assert producer_subset is not None + + # Over approximation will leave us with some unneeded size one dimensions. + # If they are removed some dace transformations (especially auto optimization) + # will have problems. + new_inter_shape_raw = symbolic.overapproximate(producer_subset.size()) + inter_shape = inter_desc.shape + if not self.strict_dataflow: + squeezed_dims: List[int] = [] # These are the dimensions we removed. + new_inter_shape: List[int] = [] # This is the final shape of the new intermediate. + for dim, (proposed_dim_size, full_dim_size) in enumerate(zip(new_inter_shape_raw, inter_shape)): + if full_dim_size == 1: # Must be kept! + new_inter_shape.append(proposed_dim_size) + elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it. + squeezed_dims.append(dim) + else: + new_inter_shape.append(proposed_dim_size) + else: + squeezed_dims = [] + new_inter_shape = list(new_inter_shape_raw) + + return (tuple(new_inter_shape_raw), tuple(new_inter_shape), squeezed_dims) + + + def compute_offset_subset( + self, + original_subset: subsets.Range, + intermediate_desc: data.Data, + map_params: List[str], + producer_offset: Union[subsets.Range, None], + ) -> subsets.Range: + """Computes the memlet to correct read and writes of the intermediate. + + This is the value that must be substracted from the memlets to adjust, i.e + (`memlet_to_adjust(correction, negative=True)`). If `producer_offset` is + `None` then the function computes the correction that should be applied to + the producer memlets, i.e. the memlets of the tree converging at + `intermediate_node`. If `producer_offset` is given, it should be the output + of the previous call to this function, with `producer_offset=None`. In this + case the function computes the correction for the consumer side, i.e. 
the
+        memlet tree that originates at `intermediate_desc`.
+
+        :param original_subset: The original subset that was used to write into the
+            intermediate; it must already be renamed to the final map parameters.
+        :param intermediate_desc: The original intermediate data descriptor.
+        :param map_params: The parameters of the final map.
+        :param producer_offset: The correction that was applied to the producer side.
+        """
+        assert not isinstance(intermediate_desc, data.View)
+        final_offset: Optional[subsets.Range] = None
+        if isinstance(intermediate_desc, data.Scalar):
+            # If the intermediate was a scalar, then it will remain a scalar.
+            # Thus there is no correction that we must apply.
+            return subsets.Range.from_string("0")
+
+        elif isinstance(intermediate_desc, data.Array):
+            basic_offsets = original_subset.min_element()
+            offset_list = []
+            for d in range(original_subset.dims()):
+                d_range = subsets.Range([original_subset[d]])
+                if d_range.free_symbols.intersection(map_params):
+                    offset_list.append(d_range[0])
+                else:
+                    offset_list.append((basic_offsets[d], basic_offsets[d], 1))
+            final_offset = subsets.Range(offset_list)
+
+        else:
+            raise TypeError(f"Cannot compute the subset offset for '{type(intermediate_desc).__name__}'.")
+
+        if producer_offset is not None:
+            # Here we correct errors that the over-approximation (which can partially
+            # act as an under-approximation) might have introduced. Consider two maps;
+            # the first map only writes the subset `[:, 2:6]`, thus the new intermediate
+            # will have shape `(1, 4)`. Now also imagine that the second map only reads
+            # the elements `[:, 3]`. From this we see that we can only correct the
+            # consumer side if we also take the producer side into consideration!
+            # See also the `transformations/mapfusion_test.py::test_offset_correction_*`
+            # tests for more.
+            final_offset.offset(
+                final_offset.offset_new(
+                    producer_offset,
+                    negative=True,
+                ),
+                negative=True,
            )
-        edge.data.data = local_name
-        edge.data.subset = "0"
-
-        # If source of edge leads to multiple destinations, redirect all through an access node.
-        out_edges = list(graph.out_edges_by_connector(edge.src, edge.src_conn))
-        if len(out_edges) > 1:
-            local_node = graph.add_access(local_name)
-            src_connector = None
-
-            # Add edge that leads to transient node
-            graph.add_edge(edge.src, edge.src_conn, local_node, None, dcpy(edge.data))
-
-            for other_edge in out_edges:
-                if other_edge is not edge:
-                    graph.remove_edge(other_edge)
-                    mem = Memlet(data=local_name, other_subset=other_edge.data.dst_subset)
-                    graph.add_edge(local_node, src_connector, other_edge.dst, other_edge.dst_conn, mem)
+        return final_offset
+
+
+    def can_topologically_be_fused(
+        self,
+        first_map_entry: nodes.MapEntry,
+        second_map_entry: nodes.MapEntry,
+        graph: Union[dace.SDFGState, dace.SDFG],
+        sdfg: dace.SDFG,
+        permissive: bool = False,
+    ) -> Optional[Dict[str, str]]:
+        """Performs basic checks to determine whether the maps can be fused.
+
+        This function only checks constraints that are common to the serial and
+        parallel map fusion process, which include:
+        * The scope of the maps.
+        * The scheduling of the maps.
+        * The map parameters.
+
+        :return: If the maps can not be topologically fused the function returns `None`.
+            If they can be fused the function returns a `dict` that describes parameter
+            replacement, see `find_parameter_remapping()` for more.
+
+        :param first_map_entry: The entry of the first (in serial case the top) map.
+        :param second_map_entry: The entry of the second (in serial case the bottom) map.
+        :param graph: The SDFGState in which the maps are located.
+        :param sdfg: The SDFG itself.
+        :param permissive: Currently unused.
+        """
+        if self.only_inner_maps and self.only_toplevel_maps:
+            raise ValueError("Only one of `only_inner_maps` and `only_toplevel_maps` is allowed per MapFusion instance.")
+
+        # Ensure that both have the same schedule.
+        if first_map_entry.map.schedule != second_map_entry.map.schedule:
+            return None
+
+        # Fusing is only possible if the two entries are in the same scope.
+        scope = graph.scope_dict()
+        if scope[first_map_entry] != scope[second_map_entry]:
+            return None
+        elif self.only_inner_maps:
+            if scope[first_map_entry] is None:
+                return None
+        elif self.only_toplevel_maps:
+            if scope[first_map_entry] is not None:
+                return None
+
+        # We will now check if we can rename the Map parameters of the second Map
+        # such that they match those of the first Map.
+        param_repl = self.find_parameter_remapping(first_map=first_map_entry.map, second_map=second_map_entry.map)
+        return param_repl
+
+
+    def has_inner_read_write_dependency(
+        self,
+        first_map_entry: nodes.MapEntry,
+        second_map_entry: nodes.MapEntry,
+        state: SDFGState,
+        sdfg: SDFG,
+    ) -> bool:
+        """This function tests if there are dependencies inside the Maps.
+
+        The function will scan and analyze the bodies of the two Maps and look for
+        inconsistencies. To detect them it will examine all AccessNodes in the
+        bodies and apply the following rules:
+        * If an AccessNode refers to a View, it is ignored, because the source is
+            either on the outside, in which case `has_read_write_dependency()`
+            takes care of it, or the data source is inside the Map body itself.
+        * An inconsistency is detected if both bodies contain an AccessNode
+            that refers to the same data.
+        * An inconsistency is detected if there exists an AccessNode that refers
+            to non-transient data. This is an implementation detail and could be
+            lifted.
+
+        Note that some of the restrictions of this function could be relaxed by
+        performing more analysis.
+
+        :return: The function returns `True` if an inconsistency has been found.
+
+        :param first_map_entry: The entry node of the first map.
+        :param second_map_entry: The entry node of the second map.
+        :param state: The state on which we operate.
+        :param sdfg: The SDFG on which we operate.
+        """
+        first_map_body = state.scope_subgraph(first_map_entry, False, False)
+        second_map_body = state.scope_subgraph(second_map_entry, False, False)
+
+        # Find the data that is internally referenced. Because of the first rule above,
+        # we filter out all views here.
+        first_map_body_data, second_map_body_data = [
+            {
+                dnode.data
+                for dnode in map_body.nodes()
+                if isinstance(dnode, nodes.AccessNode) and not self.is_view(dnode, sdfg)
+            }
+            for map_body in [first_map_body, second_map_body]
+        ]
+
+        # If there is data that is referenced in both bodies, then we consider this
+        # an error; this is the second rule above.
+        if not first_map_body_data.isdisjoint(second_map_body_data):
+            return True
+
+        # We consider it a problem if any map refers to non-transient data.
+        # This is an implementation detail and could be dropped if we do further
+        # analysis.
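+        # As a purely illustrative (hypothetical) example: if one of the map
+        # bodies contained an AccessNode for a global array `G`, the fused map
+        # could change the order of the accesses to `G` relative to the other
+        # map, so the check below conservatively reports a dependency.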
+        if any(
+            not sdfg.arrays[data_name].transient
+            for data_name in first_map_body_data.union(second_map_body_data)
+        ):
+            return True
+
+        return False
+
+
+    def has_read_write_dependency(
+        self,
+        first_map_entry: nodes.MapEntry,
+        second_map_entry: nodes.MapEntry,
+        param_repl: Dict[str, str],
+        state: SDFGState,
+        sdfg: SDFG,
+    ) -> bool:
+        """Test if there is a read write dependency between the two maps to be fused.
+
+        The function checks three different things.
+        * The function will make sure that there is no read write dependency between
+            the input and output of the fused maps. For that it will inspect the
+            respective subsets of the inputs of the MapEntry of the first and the
+            outputs of the MapExit node of the second map.
+        * The second part partially checks the intermediate nodes; it mostly ensures
+            that there are no views and that they are not used as outputs of the
+            combined map. Note that it is allowed that an intermediate node is also
+            an input to the first map.
+        * In case an intermediate node is also used as an input node of the first map,
+            it is forbidden that the data is used as output of the second map, and the
+            function will do additional checks. This is needed as the partition function
+            only checks that the data consumption of the second map can be satisfied by
+            the data production of the first map; it ignores any potential reads made by
+            the first map's MapEntry.
+
+        :return: `True` if there is a conflict between the maps that can not be handled.
+            If there is no conflict or if the conflict can be handled, `False` is returned.
+
+        :param first_map_entry: The entry node of the first map.
+        :param second_map_entry: The entry node of the second map.
+        :param param_repl: Dict that describes how to rename the parameters of the second Map.
+        :param state: The state on which we operate.
+        :param sdfg: The SDFG on which we operate.
+        """
+        first_map_exit: nodes.MapExit = state.exit_node(first_map_entry)
+        second_map_exit: nodes.MapExit = state.exit_node(second_map_entry)
+
+        # Get the read and write sets of the different maps, note that Views
+        # are not resolved yet.
+        access_sets: List[Dict[str, nodes.AccessNode]] = []
+        for scope_node in [first_map_entry, first_map_exit, second_map_entry, second_map_exit]:
+            access_set: Set[nodes.AccessNode] = self.get_access_set(scope_node, state)
+            access_sets.append({node.data: node for node in access_set})
+            # If two different access nodes of the same scoping node refer to the
+            # same data, then we consider this a dependency we can not handle.
+            # It is only a problem for the intermediate nodes and might be possible
+            # to handle, but doing so is hard, so we just forbid it.
+            if len(access_set) != len(access_sets[-1]):
+                return True
+        read_map_1, write_map_1, read_map_2, write_map_2 = access_sets
+
+        # It might be possible that there are views, so we have to resolve them.
+        # We also already get the name of the data container.
+        # Note that `len(real_read_map_1) <= len(read_map_1)` holds because of Views.
+        resolved_sets: List[Set[str]] = []
+        for unresolved_set in [read_map_1, write_map_1, read_map_2, write_map_2]:
+            resolved_sets.append({
+                self.track_view(node, state, sdfg).data if self.is_view(node, sdfg) else node.data
+                for node in unresolved_set.values()
+            })
+            # If the resolved and unresolved names do not have the same length,
+            # then different views point to the same location, which we forbid.
+            if len(unresolved_set) != len(resolved_sets[-1]):
+                return True
+        real_read_map_1, real_write_map_1, real_read_map_2, real_write_map_2 = resolved_sets
+
+        # We do not allow that the first and second map each write to the same data.
+        # This essentially ensures that an intermediate can not be used as output of
+        # the second map at the same time. It is actually a stronger condition, as it
+        # does not take their roles into account.
+        if not real_write_map_1.isdisjoint(real_write_map_2):
+            return True
+
+        # These are the names (unresolved) and the access nodes of the data that is used
+        # to transmit information between the maps. The partition function ensures that
+        # these nodes are directly connected to the two maps.
+        exchange_names: Set[str] = set(write_map_1.keys()).intersection(read_map_2.keys())
+        exchange_nodes: Set[nodes.AccessNode] = set(write_map_1.values()).intersection(read_map_2.values())
+
+        # If the numbers are different, then the same data is accessed through different
+        # AccessNodes. We could analyse this, but we consider it a data race.
+        if len(exchange_names) != len(exchange_nodes):
+            return True
+        assert all(exchange_node.data in exchange_names for exchange_node in exchange_nodes)
+
+        # For simplicity we assume that the nodes used for exchange are not views.
+        if any(self.is_view(exchange_node, sdfg) for exchange_node in exchange_nodes):
+            return True
+
+        # These are the names of the nodes that are used as input of the first map and
+        # as output of the second map. We have to ensure that there is no data
+        # dependency between these nodes.
+        # NOTE: This set is not required to be empty. It might look as if this would
+        # create a data race, but it is safe. The reason is that all data has
+        # to pass through the intermediate we create, which separates the reads
+        # from the writes.
+        fused_inout_data_names: Set[str] = set(read_map_1.keys()).intersection(write_map_2.keys())
+
+        # If a data container is used as input and output then it can not be a view (simplicity).
+        if any(self.is_view(read_map_1[name], sdfg) for name in fused_inout_data_names):
+            return True
+
+        # A data container can not be used as output (of the second as well as the
+        # combined map) and as intermediate. If we allowed that, the map would
+        # have two output nodes: the original one and the node that is created
+        # because the intermediate is shared.
+        # TODO(phimuell): Handle this case.
+        if not fused_inout_data_names.isdisjoint(exchange_names):
+            return True
+
+        # While it is forbidden that a data container used as intermediate is also
+        # used as output of the second map, it is allowed that the data container
+        # is used as intermediate and as input of the first map. The partition only
+        # checks that the data dependencies are met, i.e. what is read by the second
+        # map is also computed (written to the intermediate); it does not take into
+        # account the first map's reads of the data container.
+        # To make an example: The partition function will make sure that if the
+        # second map reads index `i` from the intermediate, the first map writes
+        # to that index. But it will not care if the first map reads (through its
+        # MapEntry) index `i + 1`. In order to be valid we must ensure that the first
+        # map's reads and writes to the intermediate are pointwise.
+        # Note that we only have to make this check if it is also an intermediate node.
+        # Because if it is not read by the second map it is not a problem, as the node
+        # will end up as a pure output node anyway.
+        read_write_map_1 = set(read_map_1.keys()).intersection(write_map_1.keys())
+        datas_to_inspect = read_write_map_1.intersection(exchange_names)
+        for data_to_inspect in datas_to_inspect:
+
+            # Now get all subsets of the data container that the first map reads
+            # from or writes to and check if they are pointwise.
+            all_subsets: List[subsets.Subset] = []
+            all_subsets.extend(
+                self.find_subsets(
+                    node=read_map_1[data_to_inspect],
+                    scope_node=first_map_entry,
+                    state=state,
+                    sdfg=sdfg,
+                    param_repl=None,
+                ))
+            all_subsets.extend(
+                self.find_subsets(
+                    node=write_map_1[data_to_inspect],
+                    scope_node=first_map_exit,
+                    state=state,
+                    sdfg=sdfg,
+                    param_repl=None,
+                ))
+            if not self.test_if_subsets_are_point_wise(all_subsets):
+                return True
+
+        # If there is no intersection between the input and output data, then we
+        # have nothing to check.
+        if len(fused_inout_data_names) == 0:
+            return False
+
+        # Now we inspect if there is a read write dependency between data that is
+        # used as input and output of the fused map. There is no problem if they
+        # are pointwise, i.e. in each iteration the same locations are accessed.
+        # Essentially they all boil down to `a += 1`.
+        for inout_data_name in fused_inout_data_names:
+            all_subsets: List[subsets.Subset] = []
+            # The subsets that define reading are given by the first map's entry node.
+            all_subsets.extend(
+                self.find_subsets(
+                    node=read_map_1[inout_data_name],
+                    scope_node=first_map_entry,
+                    state=state,
+                    sdfg=sdfg,
+                    param_repl=None,
+                ))
+            # While the subsets defining writing are given by the second map's exit
+            # node; there we also have to apply renaming.
+            all_subsets.extend(
+                self.find_subsets(
+                    node=write_map_2[inout_data_name],
+                    scope_node=second_map_exit,
+                    state=state,
+                    sdfg=sdfg,
+                    param_repl=param_repl,
+                ))
+            # Now we can test if these subsets are point wise.
+            if not self.test_if_subsets_are_point_wise(all_subsets):
+                return True
+
+        # No read write dependency was found.
+        return False
+
+
+    def test_if_subsets_are_point_wise(self, subsets_to_check: List[subsets.Subset]) -> bool:
+        """Point wise means that they are all the same.
+
+        If a series of subsets are point wise it means that all Memlets access
+        the same data. This is an important property because the whole map fusion
+        is built upon it.
+        If the subsets originate from different maps, then they must have been
+        renamed first.
+
+        :param subsets_to_check: The list of subsets that should be checked.
+        """
+        assert len(subsets_to_check) > 1
+
+        # We will check everything against the master subset.
+        master_subset = subsets_to_check[0]
+        for ssidx in range(1, len(subsets_to_check)):
+            subset = subsets_to_check[ssidx]
+            if isinstance(subset, subsets.Indices):
+                subset = subsets.Range.from_indices(subset)
+                # Do we also need the reverse check? See below why.
+                if any(r != (0, 0, 1) for r in subset.offset_new(master_subset, negative=True)):
+                    return False
            else:
-            local_node = edge.src
-            src_connector = edge.src_conn
-
-        # update edge data in case source or destination is a scalar access node
-        test_data = [node.data for node in (edge.src, edge.dst) if isinstance(node, nodes.AccessNode)]
-        for new_data in test_data:
-            if isinstance(sdfg.arrays[new_data], data.Scalar):
-                edge.data.data = new_data
-
-        # If destination of edge leads to multiple destinations, redirect all through an access node.
-        if other_edges:
-            # NOTE: If a new local node was already created, reuse it.
-            if local_node == edge.src:
-                local_node_out = graph.add_access(local_name)
-                connector_out = None
+                # The original code used `Range.offset` here, but that one had trouble
+                # for `r1 = 'j, 0:10'` and `r2 = 'j, 0'`. The solution would be to test
+                # symmetrically, i.e. `r1 - r2` and `r2 - r1`. However, if we had
+                # `r2_1 = 'j, 0:10'`, this would be considered as failing, which is not
+                # what we want. Thus we will use symmetric cover.
+                if not master_subset.covers(subset):
+                    return False
+                if not subset.covers(master_subset):
+                    return False
+
+        # All subsets are equal to the master subset, thus they are equal to each other.
+        # This means that the data accesses described by this transformation are
+        # point wise.
+        return True
+
+
+    def is_shared_data(
+        self,
+        data: nodes.AccessNode,
+        sdfg: dace.SDFG,
+    ) -> bool:
+        """Tests if `data` is shared data and can not be removed from the SDFG.
+
+        Interstate data is used to transmit data; this includes:
+        * The data is referred to in multiple states.
+        * The data is referred to multiple times in the same state, i.e. either the
+            state has multiple access nodes for that data or an access node has an
+            out degree larger than one.
+        * The data is read inside interstate edges.
+
+        This definition is stricter than the one employed by `SDFG.shared_transients()`,
+        as it also includes usage within a state.
+
+        :param data: The access node of the transient that should be checked.
+        :param sdfg: The SDFG containing the array.
+
+        :note: The function computes this set once for every SDFG and then caches it.
+            There is no mechanism to detect if the cache must be evicted. However,
+            as long as no additional data is added to the SDFG, there is no problem.
+        :note: If `assume_always_shared` was set, then the function will always return `True`.
+        """
+        # This is the only point where we check for `assume_always_shared`.
+        if self.assume_always_shared:
+            return True
+
+        # Check if the SDFG is known; if not, scan it and compute the set.
+        if sdfg not in self._shared_data:
+            self._compute_shared_data_in(sdfg)
+        return data.data in self._shared_data[sdfg]
+
+
+    def _compute_shared_data_in(
+        self,
+        sdfg: dace.SDFG,
+    ) -> None:
+        """Updates the internal set of shared data/interstate data of `self` for `sdfg`.
+
+        See the documentation for `self.is_shared_data()` for a description.
+
+        :param sdfg: The SDFG for which the set of shared data should be computed.
+        """
+        # Shared data of this SDFG.
+        shared_data: Set[str] = set()
+
+        # All global data can not be removed, so it must always be shared.
+        for data_name, data_desc in sdfg.arrays.items():
+            if not data_desc.transient:
+                shared_data.add(data_name)
+            elif isinstance(data_desc, dace.data.Scalar):
+                shared_data.add(data_name)
+
+        # We go through all states and classify the nodes/data:
+        #   - Data is referred to in different states.
+        #   - The access node is a view (both have to survive).
+        #   - Transient sink or source node.
+        #   - The access node has output degree larger than 1 (input degrees larger
+        #       than one will always be classified as shared anyway).
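+        # As a (hypothetical) illustration: a transient that is written in one
+        # state and read again in a later state ends up in `shared_data`, while
+        # a transient that only connects two maps inside a single state does
+        # not, and may therefore be removed by the fusion.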
+        previously_seen_data: Set[str] = set()
+        for state in sdfg.states():
+            for access_node in state.data_nodes():
+
+                if access_node.data in shared_data:
+                    # The data was already classified to be shared data.
+                    pass
+
+                elif access_node.data in previously_seen_data:
+                    # We have seen this data before, either in this state or in
+                    # a previous one, but we did not classify it as shared back then.
+                    shared_data.add(access_node.data)
+
+                if state.in_degree(access_node) == 0:
+                    # (Transient) sink nodes are used in other states, or simplify
+                    # will get rid of them.
+                    shared_data.add(access_node.data)
+
+                elif state.out_degree(access_node) != 1:  # state.out_degree() == 0 or state.out_degree() > 1
+                    # The access node is either a source node (it is shared in another
+                    # state) or the node has a degree larger than one, so it is used
+                    # in this state somewhere else.
+                    shared_data.add(access_node.data)
+
+                elif self.is_view(node=access_node, sdfg=sdfg):
+                    # To ensure that the write to the view happens, both have to be shared.
+                    viewed_data: str = self.track_view(view=access_node, state=state, sdfg=sdfg).data
+                    shared_data.update([access_node.data, viewed_data])
+                    previously_seen_data.update([access_node.data, viewed_data])
+
                else:
-                local_node_out = local_node
-                connector_out = src_connector
-            graph.add_edge(local_node, src_connector, local_node_out, connector_out,
-                           Memlet.from_array(local_name, sdfg.arrays[local_name]))
-            graph.add_edge(local_node_out, connector_out, new_dst, new_dst_conn, dcpy(edge.data))
-            for e in other_edges:
-                graph.add_edge(local_node_out, connector_out, e.dst, e.dst_conn, dcpy(edge.data))
+                    # The node was not classified as shared data, so we record that
+                    # we saw it. Note that a node that was immediately classified
+                    # as shared will never be added to this set, but data
+                    # that was found twice will be inside this set.
+                    previously_seen_data.add(access_node.data)
+
+        # Now we collect all symbols that are read in interstate edges,
+        # because they might refer to data inside states and must be kept alive.
+        interstate_read_symbols: Set[str] = set()
+        for edge in sdfg.edges():
+            interstate_read_symbols.update(edge.data.read_symbols())
+        data_read_in_interstate_edges = interstate_read_symbols.intersection(previously_seen_data)
+
+        # Compute the final set of shared data and update the internal cache.
+        shared_data.update(data_read_in_interstate_edges)
+        self._shared_data[sdfg] = shared_data
+
+
+    def find_parameter_remapping(self, first_map: nodes.Map, second_map: nodes.Map) -> Optional[Dict[str, str]]:
+        """Computes the parameter remapping for the parameters of the _second_ map.
+
+        The returned `dict` maps the parameters of the second map (keys) to parameter
+        names of the first map (values). Because of how the replace function works,
+        the `dict` describes how to replace the parameters of the second map
+        with parameters of the first map.
+        Parameters that already have the correct name and a compatible range are not
+        included in the return value, thus the keys and values are always different.
+        If no renaming is needed at all, i.e. all parameters have the same name and range,
+        then the function returns an empty `dict`.
+        If no remapping exists, then the function will return `None`.
+
+        :param first_map: The first map (its parameters are used as replacement values).
+        :param second_map: The second map (its parameters will be replaced).
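+
+        A small, purely illustrative sketch of the expected behavior:
+        ```python
+        # first map:  params ['i', 'j'], both ranges 0:10
+        # second map: params ['k', 'j'], both ranges 0:10
+        # -> returns {'k': 'i'}; 'j' is omitted because it already matches.
+        ```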
+        """
+
+        # The parameter names.
+        first_params: List[str] = first_map.params
+        second_params: List[str] = second_map.params
+
+        if len(first_params) != len(second_params):
+            return None
+
+        # The ranges, to which we apply some post processing.
+        simp = lambda e: symbolic.simplify_ext(symbolic.simplify(e))
+        first_rngs: Dict[str, Tuple[Any, Any, Any]] = {
+            param: tuple(simp(r) for r in rng)
+            for param, rng in zip(first_params, first_map.range)
+        }
+        second_rngs: Dict[str, Tuple[Any, Any, Any]] = {
+            param: tuple(simp(r) for r in rng)
+            for param, rng in zip(second_params, second_map.range)
+        }
+
+        # Parameters of the second map that have not yet been matched to a parameter
+        # of the first map, and vice versa.
+        unmapped_second_params: Set[str] = set(second_params)
+        unused_first_params: Set[str] = set(first_params)
+
+        # This is the result (`second_param -> first_param`); note that if no renaming
+        # is needed then the parameter is not present in the mapping.
+        final_mapping: Dict[str, str] = {}
+
+        # First we identify the parameters that already have the correct name.
+        for param in set(first_params).intersection(second_params):
+            first_rng = first_rngs[param]
+            second_rng = second_rngs[param]
+
+            if first_rng == second_rng:
+                # They have the same name and the same range, so this is already a match.
+                # Because the names are already the same, we do not have to enter them
+                # in the `final_mapping`.
+                unmapped_second_params.discard(param)
+                unused_first_params.discard(param)
+
+        # Check if no remapping is needed.
+        if len(unmapped_second_params) == 0:
+            return {}
+
+        # Now we go through all the parameters that we have not mapped yet.
+        # All of them will result in a remapping.
+        for unmapped_second_param in unmapped_second_params:
+            second_rng = second_rngs[unmapped_second_param]
+            assert unmapped_second_param not in final_mapping
+
+            # Now search the not yet used parameters of the first map for a matching one.
+            for candidate_param in unused_first_params:
+                candidate_rng = first_rngs[candidate_param]
+                if candidate_rng == second_rng:
+                    final_mapping[unmapped_second_param] = candidate_param
+                    unused_first_params.discard(candidate_param)
+                    break
            else:
-        # Add edge that leads to the second node
-        graph.add_edge(local_node, src_connector, new_dst, new_dst_conn, dcpy(edge.data))
+                # We did not find a candidate, so the remapping does not exist.
+                return None
+
+        assert len(unused_first_params) == 0
+        assert len(final_mapping) == len(unmapped_second_params)
+        return final_mapping
+
+
+    def rename_map_parameters(
+        self,
+        first_map: nodes.Map,
+        second_map: nodes.Map,
+        second_map_entry: nodes.MapEntry,
+        state: SDFGState,
+    ) -> None:
+        """Replaces the map parameters of the second map with names from the first.
+
+        The replacement is done in a safe way, thus `{'i': 'j', 'j': 'i'}` is
+        handled correctly. The function assumes that a proper replacement exists.
+        The replacement is computed by calling `self.find_parameter_remapping()`.
+
+        :param first_map: The first map (these provide the final parameters).
+        :param second_map: The second map, whose parameters will be replaced.
+        :param second_map_entry: The entry node of the second map.
+        :param state: The SDFGState on which we operate.
+        """
+        # Compute the replacement dict.
+        repl_dict: Dict[str, str] = self.find_parameter_remapping(first_map=first_map, second_map=second_map)
+
+        if repl_dict is None:
+            raise RuntimeError("The replacement does not exist.")
+        if len(repl_dict) == 0:
+            return
+
+        second_map_scope = state.scope_subgraph(entry_node=second_map_entry)
+        # Why is this in `symbolic` and not in `replace`?
+        symbolic.safe_replace(
+            mapping=repl_dict,
+            replace_callback=second_map_scope.replace_dict,
+        )
+
+        # For some odd reason the replace function does not modify the range and
+        # parameters of the map, so we will do it the hard way.
+        second_map.params = copy.deepcopy(first_map.params)
+        second_map.range = copy.deepcopy(first_map.range)
+
+
+    def is_node_reachable_from(
+        self,
+        graph: Union[dace.SDFG, dace.SDFGState],
+        begin: nodes.Node,
+        end: nodes.Node,
+    ) -> bool:
+        """Test if the node `end` can be reached from `begin`.
+
+        Essentially the function starts a DFS at `begin`. If an edge is found that leads
+        to `end`, the function returns `True`. If the node is never found, `False` is
+        returned.
+
+        :param graph: The graph to operate on.
+        :param begin: The start of the DFS.
+        :param end: The node that should be located.
+        """
+
+        def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]:
+            return (edge.dst for edge in graph.out_edges(node))
+
+        to_visit: List[nodes.Node] = [begin]
+        seen: Set[nodes.Node] = set()
+
+        while len(to_visit) > 0:
+            node: nodes.Node = to_visit.pop()
+            if node == end:
+                return True
+            elif node not in seen:
+                to_visit.extend(next_nodes(node))
+                seen.add(node)
+
+        # We never found `end`.
+        return False
+
+
+    def _is_data_accessed_downstream(
+        self,
+        data: str,
+        graph: dace.SDFGState,
+        begin: nodes.Node,
+    ) -> bool:
+        """Tests if there is an AccessNode for `data` downstream of `begin`.
+
+        Essentially, this function starts a DFS at `begin` and checks every
+        AccessNode that is reachable from it. If it finds such a node, it will
+        check if it refers to `data` and, if so, return `True`.
+        If no such node is found, it will return `False`.
+        Note that the node `begin` itself is ignored.
+
+        :param data: The name of the data to look for.
+        :param graph: The graph to explore.
+        :param begin: The node to start exploration; the node itself is ignored.
+        """
+        def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]:
+            return (edge.dst for edge in graph.out_edges(node))
+
+        # The dataflow graph is acyclic, so we do not need to keep a list of
+        # what we have visited.
+        to_visit: List[nodes.Node] = list(next_nodes(begin))
+        while len(to_visit) > 0:
+            node = to_visit.pop()
+            if isinstance(node, nodes.AccessNode) and node.data == data:
+                return True
+            to_visit.extend(next_nodes(node))
+
+        return False
+
+
+    def get_access_set(
+        self,
+        scope_node: Union[nodes.MapEntry, nodes.MapExit],
+        state: SDFGState,
+    ) -> Set[nodes.AccessNode]:
+        """Computes the access set of a "scope node".
+
+        If `scope_node` is a `MapEntry` it will operate on the set of incoming edges
+        and if it is a `MapExit` on the set of outgoing edges. The function will
+        then determine all access nodes that have a connection through these edges
+        to the scope node (edges that do not lead to access nodes are ignored).
+        The function returns a set that contains all access nodes that were found.
+        Note that this set may also contain views.
+
+        :param scope_node: The scope node that should be evaluated.
+        :param state: The state in which we operate.
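+
+        As a (hypothetical) example: for a `MapEntry` with one incoming edge
+        from an AccessNode `A` and another incoming edge from a surrounding
+        `MapEntry`, only the node for `A` is part of the returned set.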
+        """
+        if isinstance(scope_node, nodes.MapEntry):
+            get_edges = lambda node: state.in_edges(node)
+            other_node = lambda e: e.src
        else:
-            local_name, _ = sdfg.add_transient(local_name,
-                                               symbolic.overapproximate(edge.data.subset.size()),
-                                               dtype=access_node.desc(graph).dtype,
-                                               find_new_name=True)
-            old_edge = dcpy(edge)
-            local_node = graph.add_access(local_name)
-            src_connector = None
-            edge.data.data = local_name
-            edge.data.subset = ",".join(["0:" + str(s) for s in edge.data.subset.size()])
-            # Add edge that leads to transient node
-            graph.add_edge(
-                edge.src,
-                edge.src_conn,
-                local_node,
-                None,
-                dcpy(edge.data),
-            )
+            get_edges = lambda node: state.out_edges(node)
+            other_node = lambda e: e.dst
+        access_set: Set[nodes.AccessNode] = {
+            node
+            for node in map(other_node, get_edges(scope_node)) if isinstance(node, nodes.AccessNode)
+        }

-        # Add edge that leads to the second node
-        graph.add_edge(local_node, src_connector, new_dst, new_dst_conn, dcpy(edge.data))
+        return access_set
+
+
+    def find_subsets(
+        self,
+        node: nodes.AccessNode,
+        scope_node: Union[nodes.MapExit, nodes.MapEntry],
+        state: SDFGState,
+        sdfg: SDFG,
+        param_repl: Optional[Dict[str, str]],
+    ) -> List[subsets.Subset]:
+        """Finds all subsets that access `node` within `scope_node`.
+
+        The function will not start a search for all consumers/producers.
+        Instead it will only inspect the edges that are immediately inside the
+        map scope.
+
+        :param node: The access node that should be examined.
+        :param scope_node: We are only interested in data that flows through this node.
+        :param state: The state in which we operate.
+        :param sdfg: The SDFG object.
+        :param param_repl: `dict` that describes the parameter renaming that should be
+            performed. Can be `None` to skip the processing.
+        """
+        # Whether the node is used for reading or for writing
+        # influences how we have to proceed.
+        if isinstance(scope_node, nodes.MapEntry):
+            outer_edges_to_inspect = [e for e in state.in_edges(scope_node) if e.src == node]
+            get_subset = lambda e: e.data.src_subset
+            get_inner_edges = lambda e: state.out_edges_by_connector(scope_node, "OUT_" + e.dst_conn[3:])
+        else:
+            outer_edges_to_inspect = [e for e in state.out_edges(scope_node) if e.dst == node]
+            get_subset = lambda e: e.data.dst_subset
+            get_inner_edges = lambda e: state.in_edges_by_connector(scope_node, "IN_" + e.src_conn[4:])
+
+        found_subsets: List[subsets.Subset] = []
+        for edge in outer_edges_to_inspect:
+            found_subsets.extend(get_subset(e) for e in get_inner_edges(edge))
+        assert len(found_subsets) > 0, "Could not find any subsets."
+        assert not any(subset is None for subset in found_subsets)
+
+        found_subsets = copy.deepcopy(found_subsets)
+        if param_repl:
+            for subset in found_subsets:
+                # The replace happens in place.
+                symbolic.safe_replace(param_repl, subset.replace)
+
+        return found_subsets
+
+
+    def is_view(
+        self,
+        node: Union[nodes.AccessNode, data.Data],
+        sdfg: SDFG,
+    ) -> bool:
+        """Tests if `node` points to a view or not."""
+        node_desc: data.Data = node if isinstance(node, data.Data) else node.desc(sdfg)
+        return isinstance(node_desc, data.View)
+
+
+    def track_view(
+        self,
+        view: nodes.AccessNode,
+        state: SDFGState,
+        sdfg: SDFG,
+    ) -> nodes.AccessNode:
+        """Find the original data of a View.
+
+        Given the View `view`, the function will trace the view back to the original
+        access node. For convenience, if `view` is not a `View`, the argument will be
+        returned.
+
+        :param view: The view that should be traced.
+        :param state: The state in which we operate.
+        :param sdfg: The SDFG on which we operate.
+        """
+        # Test if it is a view at all; if not, return the passed node as source.
+        if not self.is_view(view, sdfg):
+            return view

-        for e in other_edges:
-            graph.add_edge(local_node, src_connector, e.dst, e.dst_conn, dcpy(edge.data))

+        # This is the node that defines the view.
+        defining_node = dace.sdfg.utils.get_last_view_node(state, view)
+        assert isinstance(defining_node, nodes.AccessNode)
+        assert not self.is_view(defining_node, sdfg)
+        return defining_node

-        # Modify data and memlets on all surrounding edges to match array
-        for neighbor in graph.all_edges(local_node):
-            for e in graph.memlet_tree(neighbor):
-                if e.data.data == local_name:
-                    continue
-                e.data.data = local_name
-                e.data.subset.offset(old_edge.data.subset, negative=True)
diff --git a/tests/buffer_tiling_test.py b/tests/buffer_tiling_test.py
index 52477dcc72..03635a7721 100644
--- a/tests/buffer_tiling_test.py
+++ b/tests/buffer_tiling_test.py
@@ -78,6 +78,7 @@ def _semantic_eq(tile_sizes, program):
     count = sdfg.apply_transformations(BufferTiling, options={'tile_sizes': tile_sizes})
     assert count > 0

+    sdfg.validate()
     sdfg(w3=w3, w5=w5, A=A, B=B2, I=A.shape[0], J=A.shape[1])
     assert np.allclose(B1, B2)

diff --git a/tests/npbench/polybench/correlation_test.py b/tests/npbench/polybench/correlation_test.py
index d743ba528d..a5532cf829 100644
--- a/tests/npbench/polybench/correlation_test.py
+++ b/tests/npbench/polybench/correlation_test.py
@@ -83,7 +83,8 @@ def run_correlation(device_type: dace.dtypes.DeviceType):

     # Compute ground truth and validate result
     corr_ref = ground_truth(M, float_n_ref, data_ref)
-    assert np.allclose(corr, corr_ref)
+    diff = corr_ref - corr
+    assert np.abs(diff).max() <= 10e-10

     return sdfg

diff --git a/tests/npbench/polybench/heat_3d_test.py b/tests/npbench/polybench/heat_3d_test.py
index e058914fd3..75ad902c4b 100644
--- a/tests/npbench/polybench/heat_3d_test.py
+++ b/tests/npbench/polybench/heat_3d_test.py
@@ -64,10 +64,20 @@ def run_heat_3d(device_type: dace.dtypes.DeviceType):
     A_ref = np.copy(A)
     B_ref = np.copy(B)

+    def count_maps(sdfg: dc.SDFG) -> int:
+        nb_maps = 0
+        for node, _ in sdfg.all_nodes_recursive():
+            if isinstance(node, dc.sdfg.nodes.MapEntry):
+                nb_maps += 1
+        return nb_maps
+
     if device_type in {dace.dtypes.DeviceType.CPU, dace.dtypes.DeviceType.GPU}:
         # Parse the SDFG and apply auto-opt
         sdfg = heat_3d_kernel.to_sdfg()
+        initial_maps = count_maps(sdfg)
         sdfg = auto_optimize(sdfg, device_type)
+        after_maps = count_maps(sdfg)
+        assert after_maps < initial_maps, f"Expected fewer maps: initially {initial_maps}, but {after_maps} after optimization"
         sdfg(TSTEPS, A, B, N=N)
     elif device_type == dace.dtypes.DeviceType.FPGA:
         # Parse SDFG and apply FPGA friendly optimization
diff --git a/tests/npbench/polybench/jacobi_2d_test.py b/tests/npbench/polybench/jacobi_2d_test.py
index bc2d5a4f2b..61982c427f 100644
--- a/tests/npbench/polybench/jacobi_2d_test.py
+++ b/tests/npbench/polybench/jacobi_2d_test.py
@@ -47,6 +47,7 @@ def run_jacobi_2d(device_type: dace.dtypes.DeviceType):
         # Parse the SDFG and apply autopot
         sdfg = kernel.to_sdfg()
         sdfg = auto_optimize(sdfg, device_type)
+
         sdfg(A=A, B=B, TSTEPS=TSTEPS, N=N)
     elif device_type == dace.dtypes.DeviceType.FPGA:
diff --git a/tests/python_frontend/fields_and_global_arrays_test.py b/tests/python_frontend/fields_and_global_arrays_test.py
index b7f5e46ee9..03cb4c5915 100644
--- a/tests/python_frontend/fields_and_global_arrays_test.py
+++ b/tests/python_frontend/fields_and_global_arrays_test.py
@@ -585,7 +585,7 @@ def caller():

     # Ensure only three globals are created
     sdfg = caller.to_sdfg()
-    assert len([k for k in sdfg.arrays if '__g' in k]) == 3
+    assert len([k for k in sdfg.arrays if k.startswith('__g')]) == 3


 def test_two_inner_methods():
diff --git a/tests/transformations/apply_to_test.py b/tests/transformations/apply_to_test.py
index de542b758c..34ff114ac5 100644
--- a/tests/transformations/apply_to_test.py
+++ b/tests/transformations/apply_to_test.py
@@ -15,6 +15,7 @@ def dbladd(A: dace.float64[100, 100], B: dace.float64[100, 100]):
     dbl = B
     return A + dbl * B

+
 @dace.program
 def unfusable(A: dace.float64[100], B: dace.float64[100, 100]):
     """Test function of two maps that can not be fused."""
@@ -57,8 +58,12 @@ def test_applyto_pattern():
     transient = next(aname for aname, desc in sdfg.arrays.items() if desc.transient)
     access_node = next(n for n in state.nodes() if isinstance(n, dace.nodes.AccessNode) and n.data == transient)

-    assert MapFusion.can_be_applied_to(sdfg, first_map_exit=mult_exit, array=access_node, second_map_entry=add_entry)
-
+    assert MapFusion.can_be_applied_to(
+        sdfg,
+        first_map_exit=mult_exit,
+        array=access_node,
+        second_map_entry=add_entry
+    )
     MapFusion.apply_to(sdfg, first_map_exit=mult_exit, array=access_node, second_map_entry=add_entry)

     assert len([node for node in state.nodes() if isinstance(node, dace.nodes.MapEntry)]) == 1
diff --git a/tests/transformations/mapfusion_data_races_test.py b/tests/transformations/mapfusion_data_races_test.py
index e765ec6978..ff87fd61ec 100644
--- a/tests/transformations/mapfusion_data_races_test.py
+++ b/tests/transformations/mapfusion_data_races_test.py
@@ -41,6 +41,13 @@ def rw_data_race_3(A: dace.float64[20], B: dace.float64[20]):
     A[:10] += 3.0 * offset(A[:11])


+@dace.program
+def rw_data_race_4(A: dace.float64[20], B: dace.float64[20]):
+    # This is potentially fusable.
+    A += B
+    A *= 2.0
+
+
 def test_rw_data_race():
     sdfg = rw_data_race.to_sdfg(simplify=True)
     sdfg.apply_transformations_repeated(MapFusion)
@@ -50,8 +57,9 @@ def test_rw_data_race():

 def test_rw_data_race_2_mf():
     sdfg = rw_data_race_2.to_sdfg(simplify=True)
-    sdfg.apply_transformations_repeated(MapFusion)
+    nb_applied = sdfg.apply_transformations_repeated(MapFusion)
     map_entry_nodes = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, nodes.MapEntry)]
+    assert nb_applied > 0
     assert (len(map_entry_nodes) > 1)


@@ -69,8 +77,27 @@ def test_rw_data_race_3_sgf():
     assert (len(map_entry_nodes) > 1)


+def test_rw_data_race_3_mf():
+    sdfg = rw_data_race_3.to_sdfg(simplify=True)
+    nb_applied = sdfg.apply_transformations_repeated(MapFusion)
+    map_entry_nodes = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, nodes.MapEntry)]
+    assert (len(map_entry_nodes) > 1)
+    assert nb_applied > 0
+
+
+def test_rw_data_race_4_mf():
+    # It is technically possible to fuse this, because there is only a point wise dependency.
+    # However, it is very hard to detect and handle correctly.
+    sdfg = rw_data_race_4.to_sdfg(simplify=True)
+    sdfg.apply_transformations_repeated(MapFusion)
+    map_entry_nodes = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, nodes.MapEntry)]
+    assert (len(map_entry_nodes) >= 1)
+
+
 if __name__ == "__main__":
     test_rw_data_race()
-    test_rw_data_race_2_mf()
     test_rw_data_race_2_sgf()
+    test_rw_data_race_2_mf()
     test_rw_data_race_3_sgf()
+    test_rw_data_race_3_mf()
+    test_rw_data_race_4_mf()
diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py
index 724c8c97ee..1e7678c3aa 100644
--- a/tests/transformations/mapfusion_test.py
+++ b/tests/transformations/mapfusion_test.py
@@ -1,12 +1,120 @@
 # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
+from typing import Any, Union, Tuple, Optional
+
 import numpy as np
 import os
 import dace
-from dace.transformation.dataflow import MapFusion
+import copy
+import uuid
+import pytest
+
+from dace import SDFG, SDFGState
+from dace.sdfg import nodes
+from dace.transformation.dataflow import MapFusion, MapExpansion
+
+
+def count_node(sdfg: SDFG, node_type):
+    nb_nodes = 0
+    for rsdfg in sdfg.all_sdfgs_recursive():
+        for state in rsdfg.states():
+            for node in state.nodes():
+                if isinstance(node, node_type):
+                    nb_nodes += 1
+    return nb_nodes
+
+
+def apply_fusion(
+    sdfg: SDFG,
+    removed_maps: Union[int, None] = None,
+    final_maps: Union[int, None] = None,
+    unspecific: bool = False,
+    apply_once: bool = False,
+    strict_dataflow: bool = True,
+) -> SDFG:
+    """Applies the Map fusion transformation.
+
+    The function checks that the number of maps has been reduced; it is also possible
+    to specify the number of removed maps. It is also possible to specify the final
+    number of maps.
+    If `unspecific` is set to `True` then the function will just apply the
+    transformation and not check if maps were removed at all.
+    If `strict_dataflow` is set to `True`, the default, then the function will perform
+    the fusion in strict dataflow mode.
+    """
+    org_sdfg = copy.deepcopy(sdfg)
+    num_maps_before = None if unspecific else count_node(sdfg, nodes.MapEntry)
+
+    try:
+        with dace.config.temporary_config():
+            dace.Config.set("optimizer", "match_exception", value=True)
+            if apply_once:
+                sdfg.apply_transformations(
+                    MapFusion(strict_dataflow=strict_dataflow),
+                    validate=True,
+                    validate_all=True
+                )
+            else:
+                sdfg.apply_transformations_repeated(
+                    MapFusion(strict_dataflow=strict_dataflow),
+                    validate=True,
+                    validate_all=True
+                )
+    except Exception:
+        org_sdfg.view()
+        sdfg.view()
+        raise
+
+    if unspecific:
+        return sdfg
+
+    num_maps_after = count_node(sdfg, nodes.MapEntry)
+    has_processed = False
+    if removed_maps is not None:
+        has_processed = True
+        rm = num_maps_before - num_maps_after
+        if not (rm == removed_maps):
+            sdfg.view()
+        assert rm == removed_maps, f"Expected to remove {removed_maps} maps, but removed {rm}."
+    if final_maps is not None:
+        has_processed = True
+        if not (final_maps == num_maps_after):
+            sdfg.view()
+        assert final_maps == num_maps_after, f"Expected that only {final_maps} maps remain, but there are still {num_maps_after}."
+ if not has_processed: + if not (num_maps_after < num_maps_before): + sdfg.view() + assert num_maps_after < num_maps_before, f"Maps after: {num_maps_after}; Maps before: {num_maps_before}" + return sdfg @dace.program -def fusion(A: dace.float32[10, 20], B: dace.float32[10, 20], out: dace.float32[1]): +def fusion_simple(A: dace.float32[10, 20], B: dace.float32[10, 20], out: dace.float32[1]): + tmp = dace.define_local([10, 20], dtype=A.dtype) + tmp_2 = dace.define_local([10, 20], dtype=A.dtype) + for i, j in dace.map[0:10, 0:20]: + with dace.tasklet: + a << A[i, j] + b >> tmp[i, j] + + b = a * a + + for i, j in dace.map[0:10, 0:20]: + with dace.tasklet: + a << tmp[i, j] + b << B[i, j] + c >> tmp_2[i, j] + + c = a + b + + for i, j in dace.map[0:10, 0:20]: + with dace.tasklet: + a << tmp_2[i, j] + b >> out(1, lambda a, b: a + b)[0] + + b = a + + +@dace.program +def fusion_rename(A: dace.float32[10, 20], B: dace.float32[10, 20], out: dace.float32[1]): tmp = dace.define_local([10, 20], dtype=A.dtype) tmp_2 = dace.define_local([10, 20], dtype=A.dtype) for i, j in dace.map[0:10, 0:20]: @@ -66,12 +174,99 @@ def fusion_chain(A: dace.float32[10, 20], B: dace.float32[10, 20]): B[:] = tmp2 + 5 +@dace.program +def fusion_with_transient(A: dace.float64[2, 20]): + res = np.ndarray([2, 20], dace.float64) + for i in dace.map[0:20]: + for j in dace.map[0:2]: + with dace.tasklet: + a << A[j, i] + t >> res[j, i] + t = a * a + for i in dace.map[0:20]: + for j in dace.map[0:2]: + with dace.tasklet: + t << res[j, i] + o >> A[j, i] + o = t * 2 + + +@dace.program +def fusion_shared_output(A: dace.float32[10, 20], B: dace.float32[10, 20], C: dace.float32[10, 20]): + tmp = A + 3 + B[:] = tmp * 4 + C[:] = tmp / 6 + + +@dace.program +def fusion_indirect_access(A: dace.float32[100], B: dace.float32[100], idx: dace.int32[30], out: dace.float32[30]): + tmp = (A + B * 2) + 3 + out[:] = tmp[idx] + + +def make_interstate_transient_fusion_sdfg(): + sdfg = dace.SDFG("interstate_transient_fusion") + state1 = sdfg.add_state("state1", is_start_block=True) + state2 = sdfg.add_state_after(state1, "state2") + + for name in ["A", "B", "C", "D"]: + sdfg.add_array(name, shape=(20, 20), dtype=dace.float64, transient=False) + sdfg.arrays["B"].transient = True + + A1, B1, C1 = (state1.add_access(name) for name in ["A", "B", "C"]) + state1.add_mapped_tasklet( + "map_1_1", + map_ranges={"__i0": "0:20", "__i1": "0:20"}, + inputs={"__in1": dace.Memlet("A[__i0, __i1]")}, + code="__out = __in1 + 20", + outputs={"__out": dace.Memlet("B[__i0, __i1]")}, + input_nodes={"A": A1}, + output_nodes={"B": B1}, + external_edges=True, + ) + state1.add_mapped_tasklet( + "map_2_1", + map_ranges={"__i0": "0:20", "__i1": "0:20"}, + inputs={"__in1": dace.Memlet("B[__i0, __i1]")}, + code="__out = __in1 + 10", + outputs={"__out": dace.Memlet("C[__i0, __i1]")}, + input_nodes={"B": B1}, + output_nodes={"C": C1}, + external_edges=True, + ) + + B2, D2 = (state2.add_access(name) for name in ["B", "D"]) + state2.add_mapped_tasklet( + "map_1_2", + map_ranges={"__i0": "0:20", "__i1": "0:20"}, + inputs={"__in1": dace.Memlet("B[__i0, __i1]")}, + code="__out = __in1 + 6", + outputs={"__out": dace.Memlet("D[__i0, __i1]")}, + input_nodes={"B": B2}, + output_nodes={"D": D2}, + external_edges=True, + ) + + return sdfg, state1, state2 + + def test_fusion_simple(): - sdfg = fusion.to_sdfg() - sdfg.save(os.path.join('_dacegraphs', 'before1.sdfg')) - sdfg.simplify() - sdfg.apply_transformations_repeated(MapFusion) - sdfg.save(os.path.join('_dacegraphs', 'after1.sdfg')) + sdfg = 
fusion_simple.to_sdfg(simplify=True) + sdfg = apply_fusion(sdfg, final_maps=1) + + A = np.random.rand(10, 20).astype(np.float32) + B = np.random.rand(10, 20).astype(np.float32) + out = np.zeros(shape=1, dtype=np.float32) + sdfg(A=A, B=B, out=out) + + diff = abs(np.sum(A * A + B) - out) + print('Difference:', diff) + assert diff <= 1e-3 + + +def test_fusion_rename(): + sdfg = fusion_rename.to_sdfg(simplify=True) + sdfg = apply_fusion(sdfg, final_maps=1) A = np.random.rand(10, 20).astype(np.float32) B = np.random.rand(10, 20).astype(np.float32) @@ -83,20 +278,43 @@ def test_fusion_simple(): assert diff <= 1e-3 +def test_fusion_shared(): + sdfg = fusion_shared_output.to_sdfg(simplify=True) + sdfg = apply_fusion(sdfg) + + A = np.random.rand(10, 20).astype(np.float32) + B = np.random.rand(10, 20).astype(np.float32) + C = np.random.rand(10, 20).astype(np.float32) + + B_res = (A + 3) * 4 + C_res = (A + 3) / 6 + sdfg(A=A, B=B, C=C) + + assert np.allclose(B_res, B) + assert np.allclose(C_res, C) + + +def test_indirect_accesses(): + sdfg = fusion_indirect_access.to_sdfg(simplify=True) + sdfg = apply_fusion(sdfg, final_maps=2) + + A = np.random.rand(100).astype(np.float32) + B = np.random.rand(100).astype(np.float32) + idx = ((np.random.rand(30) * 100) % 100).astype(np.int32) + out = np.zeros(shape=30, dtype=np.float32) + + res = ((A + B * 2) + 3)[idx] + sdfg(A=A, B=B, idx=idx, out=out) + + assert np.allclose(res, out) + + def test_multiple_fusions(): - sdfg = multiple_fusions.to_sdfg() - num_nodes_before = len([node for state in sdfg.nodes() for node in state.nodes()]) + sdfg = multiple_fusions.to_sdfg(simplify=True) sdfg.save(os.path.join('_dacegraphs', 'before2.sdfg')) sdfg.simplify() - sdfg.apply_transformations_repeated(MapFusion) - sdfg.save(os.path.join('_dacegraphs', 'after2.sdfg')) - - num_nodes_after = len([node for state in sdfg.nodes() for node in state.nodes()]) - # Ensure that the number of nodes was reduced after transformation - if num_nodes_after >= num_nodes_before: - raise RuntimeError('SDFG was not properly transformed ' - '(nodes before: %d, after: %d)' % (num_nodes_before, num_nodes_after)) + sdfg = apply_fusion(sdfg) A = np.random.rand(10, 20).astype(np.float32) B = np.zeros_like(A) @@ -113,20 +331,9 @@ def test_multiple_fusions(): def test_fusion_chain(): - sdfg = fusion_chain.to_sdfg() - sdfg.save(os.path.join('_dacegraphs', 'before3.sdfg')) + sdfg = fusion_chain.to_sdfg(simplify=True) sdfg.simplify() - sdfg.apply_transformations(MapFusion) - num_nodes_before = len([node for state in sdfg.nodes() for node in state.nodes()]) - sdfg.apply_transformations(MapFusion) - sdfg.apply_transformations(MapFusion) - sdfg.save(os.path.join('_dacegraphs', 'after3.sdfg')) - - num_nodes_after = len([node for state in sdfg.nodes() for node in state.nodes()]) - # Ensure that the number of nodes was reduced after transformation - if num_nodes_after >= num_nodes_before: - raise RuntimeError('SDFG was not properly transformed ' - '(nodes before: %d, after: %d)' % (num_nodes_before, num_nodes_after)) + sdfg = apply_fusion(sdfg, final_maps=1) A = np.random.rand(10, 20).astype(np.float32) B = np.zeros_like(A) @@ -136,29 +343,14 @@ def test_fusion_chain(): assert diff <= 1e-4 -@dace.program -def fusion_with_transient(A: dace.float64[2, 20]): - res = np.ndarray([2, 20], dace.float64) - for i in dace.map[0:20]: - for j in dace.map[0:2]: - with dace.tasklet: - a << A[j, i] - t >> res[j, i] - t = a * a - for i in dace.map[0:20]: - for j in dace.map[0:2]: - with dace.tasklet: - t << res[j, i] - o >> A[j, 
i] - o = t * 2 - def test_fusion_with_transient(): A = np.random.rand(2, 20) expected = A * A * 2 - sdfg = fusion_with_transient.to_sdfg() + sdfg = fusion_with_transient.to_sdfg(simplify=True) sdfg.simplify() - sdfg.apply_transformations(MapFusion) + sdfg = apply_fusion(sdfg, removed_maps=2) + sdfg(A=A) assert np.allclose(A, expected) @@ -191,7 +383,7 @@ def build_sdfg(): return sdfg sdfg = build_sdfg() - sdfg.apply_transformations(MapFusion) + sdfg = apply_fusion(sdfg) A = np.random.rand(N, K) B = np.repeat(np.nan, N) @@ -217,10 +409,12 @@ def inverted_maps(A: dace.int32[10]): sdfg(A=val0) assert np.array_equal(val0, ref) - sdfg.apply_transformations(MapFusion) + # This can not be fused + apply_fusion(sdfg, removed_maps=0) + val1 = np.ndarray((10,), dtype=np.int32) sdfg(A=val1) - assert np.array_equal(val1, ref) + assert np.array_equal(val1, ref), f"REF: {ref}; VAL: {val1}" def test_fusion_with_empty_memlet(): @@ -240,8 +434,7 @@ def inner_product(A: dace.float32[N], B: dace.float32[N], out: dace.float32[1]): out[0] += lsum sdfg = inner_product.to_sdfg(simplify=True) - count = sdfg.apply_transformations_repeated(MapFusion) - assert count == 2 + apply_fusion(sdfg, removed_maps=2) A = np.arange(1024, dtype=np.float32) B = np.arange(1024, dtype=np.float32) @@ -253,9 +446,8 @@ def inner_product(A: dace.float32[N], B: dace.float32[N], out: dace.float32[1]): def test_fusion_with_nested_sdfg_0(): - @dace.program - def fusion_with_nested_sdfg_0(A: dace.int32[10], B: dace.int32[10], C: dace.int32[10]): - tmp = np.empty([10], dtype=np.int32) + def ref(A, B, C): + tmp = np.zeros_like(A) for i in dace.map[0:10]: if C[i] < 0: tmp[i] = B[i] - A[i] @@ -263,9 +455,130 @@ def fusion_with_nested_sdfg_0(A: dace.int32[10], B: dace.int32[10], C: dace.int3 tmp[i] = B[i] + A[i] for i in dace.map[0:10]: A[i] = tmp[i] * 2 - - sdfg = fusion_with_nested_sdfg_0.to_sdfg(simplify=True) - sdfg.apply_transformations(MapFusion) + + def _make_sdfg() -> dace.SDFG: + sdfg = SDFG("fusion_with_nested_sdfg_0") + state = sdfg.add_state(is_start_block=True) + + for name in "ABCT": + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["T"].transient = True + + me1, mx1 = state.add_map("first_map", ndrange={"__i0": "0:10"}) + nsdfg = state.add_nested_sdfg( + sdfg=_make_nested_sdfg(), + parent=sdfg, + inputs={"a", "b", "c"}, + outputs={"t"}, + symbol_mapping={}, + ) + + for name in "ABC": + state.add_edge( + state.add_access(name), None, + me1, "IN_" + name, + dace.Memlet(f"{name}[0:10]"), + ) + me1.add_in_connector("IN_" + name) + state.add_edge( + me1, "OUT_" + name, + nsdfg, name.lower(), + dace.Memlet(f"{name}[__i0]"), + ) + me1.add_out_connector("OUT_" + name) + state.add_edge( + nsdfg, "t", + mx1, "IN_T", + dace.Memlet("T[__i0]"), + ) + T = state.add_access("T") + state.add_edge( + mx1, "OUT_T", + T, None, + dace.Memlet("T[0:10]"), + ) + mx1.add_in_connector("IN_T") + mx1.add_out_connector("OUT_T") + + state.add_mapped_tasklet( + "comp2", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("T[__i0]")}, + code="__out = __in1 * 2", + outputs={"__out": dace.Memlet("A[__i0]")}, + input_nodes={T}, + external_edges=True, + ) + sdfg.validate() + return sdfg + + def _make_nested_sdfg() -> dace.SDFG: + sdfg = SDFG("Nested") + + for name in "abct": + sdfg.add_scalar( + name, + dtype=dace.float64, + transient=False, + ) + + state_head = sdfg.add_state("head_state", is_start_block=True) + state_if_guard = sdfg.add_state("state_if_guard") + sdfg.add_edge( + state_head, + 
state_if_guard, + dace.InterstateEdge( + condition="1", + assignments={"__tmp2": "c < 0.0"}, + ) + ) + + def _make_branch_tasklet( + state: dace.SDFGState, + code: str, + ) -> None: + tasklet = state.add_tasklet( + state.label + "_tasklet", + inputs={"__in1", "__in2"}, + code=code, + outputs={"__out"}, + ) + state.add_edge( + state.add_access("b"), None, + tasklet, "__in1", + dace.Memlet("b[0]"), + ) + state.add_edge( + state.add_access("a"), None, + tasklet, "__in2", + dace.Memlet("a[0]"), + ) + state.add_edge( + tasklet, "__out", + state.add_access("t"), None, + dace.Memlet("t[0]"), + ) + + state_trueb = sdfg.add_state("true_branch") + _make_branch_tasklet(state_trueb, "__out = __in1 - __in2") + state_falseb = sdfg.add_state("false_branch") + _make_branch_tasklet(state_falseb, "__out = __in1 + __in2") + state_if_end = sdfg.add_state("if_join") + + sdfg.add_edge(state_if_guard, state_trueb, dace.InterstateEdge(condition="__tmp2")) + sdfg.add_edge(state_if_guard, state_falseb, dace.InterstateEdge(condition="not __tmp2")) + sdfg.add_edge(state_falseb, state_if_end, dace.InterstateEdge()) + sdfg.add_edge(state_trueb, state_if_end, dace.InterstateEdge()) + sdfg.validate() + return sdfg + + sdfg = _make_sdfg() + apply_fusion(sdfg) for sd in sdfg.all_sdfgs_recursive(): if sd is not sdfg: @@ -277,8 +590,25 @@ def fusion_with_nested_sdfg_0(A: dace.int32[10], B: dace.int32[10], C: dace.int3 assert isinstance(dst, dace.nodes.AccessNode) + args_ref = { + 'A': np.array(np.random.rand(10), dtype=np.float64, copy=True), + 'B': np.array(np.random.rand(10), dtype=np.float64, copy=True), + 'C': np.array(np.random.rand(10) - 0.5, dtype=np.float64, copy=True), + } + args_res = copy.deepcopy(args_ref) + + ref(**args_ref) + sdfg(**args_res) + for arg in args_ref.keys(): + arg_ref = args_ref[arg] + arg_res = args_res[arg] + assert np.allclose(arg_ref, arg_res), f"Failed in {arg}" + + def test_fusion_with_nested_sdfg_1(): - + + # As a side effect this test also ensures that dynamic consumer edges, does not + # impact fusing, i.e. allow that fusion can take place. @dace.program def fusion_with_nested_sdfg_1(A: dace.int32[10], B: dace.int32[10], C: dace.int32[10]): tmp = np.empty([10], dtype=np.int32) @@ -295,7 +625,7 @@ def fusion_with_nested_sdfg_1(A: dace.int32[10], B: dace.int32[10], C: dace.int3 B[i] = tmp[i] * 2 sdfg = fusion_with_nested_sdfg_1.to_sdfg(simplify=True) - sdfg.apply_transformations(MapFusion) + apply_fusion(sdfg) if len(sdfg.states()) != 1: return @@ -310,13 +640,1406 @@ def fusion_with_nested_sdfg_1(A: dace.int32[10], B: dace.int32[10], C: dace.int3 assert isinstance(src, dace.nodes.AccessNode) +def test_interstate_fusion(): + """Transient between two maps is used in another state and must become shared. 
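+
+    Because `B` is read again in the second state, the fusion must not delete it;
+    instead the fused map keeps writing `B` as an additional output (this is what
+    `is_shared_data()` classifies as shared).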
+ """ + sdfg, state1, state2 = make_interstate_transient_fusion_sdfg() + + A = np.random.rand(20, 20) + C = np.random.rand(20, 20) + D = np.random.rand(20, 20) + + ref_C = A + 30 + ref_D = A + 26 + + apply_fusion(sdfg, removed_maps=1) + assert sdfg.number_of_nodes() == 2 + assert len([node for node in state1.data_nodes() if node.data == "B"]) == 1 + + sdfg(A=A, C=C, D=D) + + assert np.allclose(C, ref_C) + assert np.allclose(D, ref_D) + + +def test_fuse_indirect_accesses(): + + @dace.program(auto_optimize=False) + def inner_product( + A: dace.float32[20], + B: dace.float32[20], + idx: dace.int32[20], + out: dace.float32[20], + ): + tmp1 = np.empty_like(A) + tmp2 = np.empty_like(A) + for i in dace.map[0:20]: + tmp1[i] = A[i] * B[i] + for i in dace.map[0:20]: + tmp2[i] = tmp1[i] + A[i] + for i in dace.map[0:20]: + with dace.tasklet: + __arr << tmp2(1)[:] + __idx << idx[i] + __out >> out[i] + __out = __arr[__idx] + + sdfg = inner_product.to_sdfg(simplify=True) + assert sdfg.number_of_nodes() == 1 + assert count_node(sdfg, nodes.MapEntry) == 3 + + apply_fusion(sdfg, final_maps=2) + + # The last map, with the indirect access, can not be fused, so check that. + state = next(iter(sdfg.nodes())) + assert len(list(state.sink_nodes())) == 1 + out_node = next(iter(state.sink_nodes())) + assert out_node.data == "out" + assert state.in_degree(out_node) == 1 + + # Now find the last map and the indirect access Tasklet + last_map_exit = next(iter(state.in_edges(out_node))).src + last_map_entry = state.entry_node(last_map_exit) + assert isinstance(last_map_exit, nodes.MapExit) + assert state.in_degree(last_map_exit) == 1 + + indirect_access_tasklet = next(iter(state.in_edges(last_map_exit))).src + assert isinstance(indirect_access_tasklet, nodes.Tasklet) + assert indirect_access_tasklet.code == "__out = __arr[__idx]" # TODO: Regex with connectors + + # The tasklet can only be connected to a map entry. + assert all(in_edge.src is last_map_entry for in_edge in state.in_edges(indirect_access_tasklet)) + + +def make_correction_offset_sdfg( + range_read: bool, + second_read_start: int, +) -> SDFG: + """Make the SDFGs for the `test_offset_correction_*` tests. + + Args: + range_read: If `True` then a range is read in the second map. + if `False` then only a scalar is read. + second_read_start: Where the second map should start reading. 
+    """
+    sdfg = SDFG("offset_correction_test")
+    state = sdfg.add_state(is_start_block=True)
+    shapes = {
+        "A": (20, 10),
+        "B": (20, 8),
+        "C": (20, 2) if range_read else (20, 1),
+    }
+    descs = {}
+    for name, shape in shapes.items():
+        _, desc = sdfg.add_array(name, shape, dace.float64, transient=False)
+        descs[name] = desc
+    sdfg.arrays["B"].transient = True
+    A, B, C = (state.add_access(name) for name in sorted(shapes.keys()))
+
+    state.add_mapped_tasklet(
+        "first_map",
+        map_ranges={"i": "0:20", "j": "2:8"},
+        inputs={"__in1": dace.Memlet("A[i, j]")},
+        code="__out = __in1 + 1.0",
+        outputs={"__out": dace.Memlet("B[i, j]")},
+        input_nodes={"A": A},
+        output_nodes={"B": B},
+        external_edges=True,
+    )
+    state.add_mapped_tasklet(
+        "second_map",
+        map_ranges=(
+            {"i": "0:20", "k": "0:2"}
+            if range_read
+            else {"i": "0:20"}
+        ),
+        inputs={"__in1": dace.Memlet(f"B[i, {second_read_start}{'+k' if range_read else ''}]")},
+        code="__out = __in1",
+        outputs={"__out": dace.Memlet(f"C[i, {'k' if range_read else '0'}]")},
+        input_nodes={"B": B},
+        output_nodes={"C": C},
+        external_edges=True,
+    )
+    sdfg.validate()
+    assert sdfg.apply_transformations_repeated(MapExpansion, validate_all=True) > 0
+    return sdfg
+
+
+def test_offset_correction_range_read():
+
+    np.random.seed(42)
+    A = np.random.rand(20, 10)
+    C = np.zeros((20, 2))
+    exp = (A + 1.0)[:, 3:5].copy()
+
+    sdfg = make_correction_offset_sdfg(range_read=True, second_read_start=3)
+
+    sdfg(A=A, C=C)
+    assert np.allclose(C, exp)
+    C[:] = 0.0
+
+    apply_fusion(sdfg)
+
+    sdfg(A=A, C=C)
+    assert np.allclose(C, exp)
+
+
+def test_offset_correction_scalar_read():
+
+    np.random.seed(42)
+    A = np.random.rand(20, 10)
+    C = np.zeros((20, 1))
+    exp = (A + 1.0)[:, 3].copy().reshape((-1, 1))
+
+    sdfg = make_correction_offset_sdfg(range_read=False, second_read_start=3)
+
+    sdfg(A=A, C=C)
+    assert np.allclose(C, exp)
+    C[:] = 0.0
+
+    apply_fusion(sdfg)
+
+    sdfg(A=A, C=C)
+    assert np.allclose(C, exp)
+
+
+def test_offset_correction_empty():
+
+    # Because the second map starts reading from column 1, while the first map
+    # only starts writing from column 2, the read is not fully covered by the
+    # producer, so the maps cannot be fused.
+    # NOTE: This computation is useless.
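+    # Concretely, the ranges built by `make_correction_offset_sdfg` are:
+    #     first map:  writes B[:, 2:8]
+    #     second map: reads  B[:, 1:3]   (second_read_start=1, k in 0:2)
+    # Column 1 is never produced, hence `removed_maps=0` below.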
+    sdfg = make_correction_offset_sdfg(range_read=True, second_read_start=1)
+
+    apply_fusion(sdfg, removed_maps=0)
+
+
+def test_different_offsets():
+
+    def expected(A, B):
+        N, M = A.shape
+        return (A + 1) + B[1:(N+1), 2:(M+2)]
+
+    def _make_sdfg(N: int, M: int) -> dace.SDFG:
+        sdfg = dace.SDFG("test_different_access")
+        names = ["A", "B", "__tmp", "__return"]
+        def_shape = (N, M)
+        sshape = {"B": (N+1, M+2), "__tmp": (N+1, M+1)}
+        for name in names:
+            sdfg.add_array(
+                name,
+                shape=sshape.get(name, def_shape),
+                dtype=dace.float64,
+                transient=False,
+            )
+        sdfg.arrays["__tmp"].transient = True
+
+        state = sdfg.add_state(is_start_block=True)
+        A, B, _tmp, _return = (state.add_access(name) for name in names)
+
+        state.add_mapped_tasklet(
+            "comp1",
+            map_ranges={"__i0": f"0:{N}", "__i1": f"0:{M}"},
+            inputs={"__in": dace.Memlet("A[__i0, __i1]")},
+            code="__out = __in + 1.0",
+            outputs={"__out": dace.Memlet("__tmp[__i0 + 1, __i1 + 1]")},
+            input_nodes={A},
+            output_nodes={_tmp},
+            external_edges=True,
+        )
+        state.add_mapped_tasklet(
+            "comp2",
+            map_ranges={"__i0": f"0:{N}", "__i1": f"0:{M}"},
+            inputs={
+                "__in1": dace.Memlet("__tmp[__i0 + 1, __i1 + 1]"),
+                "__in2": dace.Memlet("B[__i0 + 1, __i1 + 2]"),
+            },
+            code="__out = __in1 + __in2",
+            outputs={"__out": dace.Memlet("__return[__i0, __i1]")},
+            input_nodes={_tmp, B},
+            output_nodes={_return},
+            external_edges=True,
+        )
+
+        sdfg.validate()
+        return sdfg
+
+    N, M = 14, 17
+    sdfg = _make_sdfg(N, M)
+    apply_fusion(sdfg, final_maps=1)
+
+    A = np.array(np.random.rand(N, M), dtype=np.float64, copy=True)
+    B = np.array(np.random.rand(N + 1, M + 2), dtype=np.float64, copy=True)
+
+    ref = expected(A, B)
+    res = sdfg(A=A, B=B)
+    assert np.allclose(ref, res)
+
+
+def _make_strict_dataflow_sdfg_pointwise(
+    input_data: str = "A",
+    intermediate_data: str = "T",
+    output_data: Optional[str] = None,
+    input_read: str = "__i0",
+    output_write: Optional[str] = None,
+) -> Tuple[dace.SDFG, dace.SDFGState]:
+    """
+    Creates the SDFG for the strict dataflow tests.
+
+    The SDFG reads and writes `A`, but the access is pointwise, thus the Maps can
+    be fused. Furthermore, this particular SDFG guarantees that no data race occurs.
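+
+    Via `input_read` and `output_write` the read access of the first map and the
+    write access of the second map can be changed; the tests below use this to
+    create non-pointwise variants.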
+    """
+    if output_data is None:
+        output_data = input_data
+    if output_write is None:
+        output_write = input_read
+
+    sdfg = dace.SDFG(f"strict_dataflow_sdfg_pointwise_{str(uuid.uuid1()).replace('-', '_')}")
+    state = sdfg.add_state(is_start_block=True)
+    for name in {input_data, intermediate_data, output_data}:
+        sdfg.add_array(
+            name,
+            shape=(10,),
+            dtype=dace.float64,
+            transient=False,
+        )
+
+    if intermediate_data not in {input_data, output_data}:
+        sdfg.arrays[intermediate_data].transient = True
+
+    input_node, intermediate_node, output_node = (state.add_access(name) for name in [input_data, intermediate_data, output_data])
+
+    state.add_mapped_tasklet(
+        "first_comp",
+        map_ranges={"__i0": "0:10"},
+        inputs={"__in1": dace.Memlet(f"{input_data}[{input_read}]")},
+        code="__out = __in1 + 2.0",
+        outputs={"__out": dace.Memlet(f"{intermediate_data}[__i0]")},
+        input_nodes={input_node},
+        output_nodes={intermediate_node},
+        external_edges=True,
+    )
+    state.add_mapped_tasklet(
+        "second_comp",
+        map_ranges={"__i1": "0:10"},
+        inputs={"__in1": dace.Memlet(f"{intermediate_data}[__i1]")},
+        code="__out = __in1 + 3.0",
+        outputs={"__out": dace.Memlet(f"{output_data}[{output_write}]")},
+        input_nodes={intermediate_node},
+        output_nodes={output_node},
+        external_edges=True,
+    )
+    sdfg.validate()
+    return sdfg, state
+
+
+def test_fusion_strict_dataflow_pointwise():
+    sdfg, state = _make_strict_dataflow_sdfg_pointwise(input_data="A")
+
+    # If strict dataflow is disabled, the transformation is able to fuse the maps.
+    apply_fusion(sdfg, removed_maps=1, strict_dataflow=False)
+
+
+def test_fusion_strict_dataflow_not_pointwise():
+    sdfg, state = _make_strict_dataflow_sdfg_pointwise(
+        input_data="A",
+        input_read="__i0",
+        output_write="9 - __i0",
+    )
+
+    # Because the dependency is not pointwise, even disabling strict dataflow
+    # will not make the fusion possible.
+    apply_fusion(sdfg, removed_maps=0, strict_dataflow=False)
+
+
+def test_fusion_dataflow_intermediate():
+    sdfg, _ = _make_strict_dataflow_sdfg_pointwise(
+        input_data="A",
+        intermediate_data="O",
+        output_data="O",
+    )
+    apply_fusion(sdfg, removed_maps=0, strict_dataflow=True)
+
+    # Because the intermediate is also an output of the second map, it is not
+    # possible to fuse, even without strict dataflow mode.
+    apply_fusion(sdfg, removed_maps=0, strict_dataflow=False)
+
+
+def test_fusion_dataflow_intermediate_2():
+    # The transformation applies for two reasons: first, reading and writing `A`
+    # is pointwise; furthermore, there is no further access to `A` after the
+    # intermediate node. Note that if the second map also had an output
+    # referring to `A`, then the transformation would not apply regardless
+    # of the strict dataflow mode.
+    sdfg, state = _make_strict_dataflow_sdfg_pointwise(
+        input_data="A",
+        intermediate_data="A",
+        output_data="O",
+    )
+    apply_fusion(sdfg, removed_maps=1, strict_dataflow=True)
+    map_exit = next(iter(node for node in state.nodes() if isinstance(node, nodes.MapExit)))
+    assert state.out_degree(map_exit) == 2
+    assert {"A", "O"} == {edge.dst.data for edge in state.out_edges(map_exit) if isinstance(edge.dst, nodes.AccessNode)}
+
+
+def test_fusion_dataflow_intermediate_3():
+    # This is exactly the same situation as in `test_fusion_dataflow_intermediate_2()`,
+    # with the exception that now the access to `A` is no longer pointwise, thus
+    # the transformation does not apply. Note that this SDFG is wrong; it is only
+    # here to show that the case is detected.
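+    # Concretely, the first map computes `A[__i0] = A[9 - __i0] + 2.0`, so in a
+    # fused map one iteration would read elements of `A` that other iterations
+    # overwrite.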
+    sdfg, state = _make_strict_dataflow_sdfg_pointwise(
+        input_data="A",
+        intermediate_data="A",
+        output_data="O",
+        input_read="9 - __i0",
+        output_write="__i0",
+    )
+    apply_fusion(sdfg, removed_maps=0, strict_dataflow=True)
+
+
+def test_fusion_dataflow_intermediate_downstream():
+    # Because the intermediate `T` is used downstream again,
+    # the transformation cannot apply.
+    sdfg, state = _make_strict_dataflow_sdfg_pointwise(
+        input_data="A",
+        intermediate_data="T",
+        output_data="output_1",
+    )
+    sdfg.arrays["output_1"].transient = False
+    sdfg.arrays["T"].transient = True
+    output_1 = next(iter(dnode for dnode in state.sink_nodes()))
+    assert isinstance(output_1, nodes.AccessNode) and output_1.data == "output_1"
+
+    # Make the real output node.
+    sdfg.arrays["O"] = sdfg.arrays["A"].clone()
+    state.add_mapped_tasklet(
+        "downstream_computation",
+        map_ranges={"__i0": "0:10"},
+        inputs={"__in1": dace.Memlet("output_1[__i0]")},
+        code="__out = __in1 + 10.0",
+        outputs={"__out": dace.Memlet("T[__i0]")},
+        input_nodes={output_1},
+        external_edges=True,
+    )
+
+    # Make another state where `T` is written back, such that it is not dead dataflow.
+    state2 = sdfg.add_state_after(state)
+    sdfg.add_datadesc("output_2", sdfg.arrays["output_1"].clone())
+    state2.add_nedge(
+        state2.add_access("T"),
+        state2.add_access("output_2"),
+        sdfg.make_array_memlet("T"),
+    )
+    sdfg.validate()
+
+    apply_fusion(sdfg, removed_maps=0, strict_dataflow=True)
+
+    # However, without strict dataflow, the merge is possible.
+    apply_fusion(sdfg, removed_maps=1, strict_dataflow=False)
+    assert state.in_degree(output_1) == 1
+    assert state.out_degree(output_1) == 1
+    assert all(isinstance(edge.src, nodes.MapExit) for edge in state.in_edges(output_1))
+    assert all(isinstance(edge.dst, nodes.MapEntry) for edge in state.out_edges(output_1))
+
+    upper_map_exit = next(iter(edge.src for edge in state.in_edges(output_1)))
+    assert isinstance(upper_map_exit, nodes.MapExit)
+    assert state.out_degree(upper_map_exit) == 2
+    assert {"T", "output_1"} == {edge.dst.data for edge in state.out_edges(upper_map_exit) if isinstance(edge.dst, nodes.AccessNode)}
+
+
+def test_fusion_non_strict_dataflow_implicit_dependency():
+    """
+    This test checks that the fusion respects implicit dependencies given by access nodes.
+
+    It simulates a situation that could arise if non-strict dataflow is enabled
+    and ensures that the fusion does not continue fusing in this situation.
+    """
+    sdfg = dace.SDFG("fusion_strict_dataflow_implicit_dependency_sdfg")
+    state = sdfg.add_state(is_start_block=True)
+    names = ["A", "B", "T1", "T2", "C"]
+    for name in names:
+        sdfg.add_array(
+            name,
+            shape=(10,),
+            dtype=dace.float64,
+            transient=False,
+        )
+    sdfg.arrays["T1"].transient = True
+    sdfg.arrays["T2"].transient = True
+
+    me, mx = state.add_map(
+        "first_map",
+        ndrange={"__i0": "0:10"}
+    )
+    tskl1 = state.add_tasklet(
+        "tskl1",
+        inputs={"__in1", "__in2"},
+        code="__out = __in1 * __in2",
+        outputs={"__out"}
+    )
+    tskl2 = state.add_tasklet(
+        "tskl2",
+        inputs={"__in1", "__in2"},
+        code="__out = (__in1 + __in2) / 2",
+        outputs={"__out"}
+    )
+    A, B, T1, T2 = (state.add_access(name) for name in names[:-1])
+
+    state.add_edge(A, None, me, "IN_A", dace.Memlet("A[0:10]"))
+    state.add_edge(B, None, me, "IN_B", dace.Memlet("B[0:10]"))
+    me.add_in_connector("IN_A")
+    me.add_in_connector("IN_B")
+
+    state.add_edge(me, "OUT_A", tskl1, "__in1", dace.Memlet("A[__i0]"))
+    state.add_edge(me, "OUT_B", tskl1, "__in2", dace.Memlet("B[__i0]"))
+    state.add_edge(me, "OUT_A", tskl2, "__in1", dace.Memlet("A[__i0]"))
+    state.add_edge(me, "OUT_B", tskl2, "__in2", dace.Memlet("B[__i0]"))
+    me.add_out_connector("OUT_A")
+    me.add_out_connector("OUT_B")
+
+    state.add_edge(tskl1, "__out", mx, "IN_T1", dace.Memlet("T1[__i0]"))
+    state.add_edge(tskl2, "__out", mx, "IN_T2", dace.Memlet("T2[__i0]"))
+    mx.add_in_connector("IN_T1")
+    mx.add_in_connector("IN_T2")
+
+    state.add_edge(mx, "OUT_T1", T1, None, dace.Memlet("T1[0:10]"))
+    state.add_edge(mx, "OUT_T2", T2, None, dace.Memlet("T2[0:10]"))
+    mx.add_out_connector("OUT_T1")
+    mx.add_out_connector("OUT_T2")
+
+    state.add_mapped_tasklet(
+        "second_map",
+        map_ranges={"__i0": "0:10"},
+        inputs={"__in1": dace.Memlet("T1[__i0]")},
+        code="if __in1 < 0.5:\n\t__out = 100.",
+        outputs={"__out": dace.Memlet("T2[__i0]", dynamic=True)},
+        input_nodes={T1},
+        external_edges=True,
+    )
+
+    state2 = sdfg.add_state_after(state)
+    state2.add_edge(
+        state2.add_access("T2"),
+        None,
+        state2.add_access("C"),
+        None,
+        dace.Memlet("T2[0:10] -> [0:10]"),
+    )
+    sdfg.validate()
+
+    apply_fusion(sdfg, removed_maps=0, strict_dataflow=False)
+
+
+def _make_inner_conflict_shared_scalar(
+    has_conflict: bool,
+) -> dace.SDFG:
+    """Generate the SDFG for tests with the inner dependency.
+
+    If `has_conflict` is `True` then the same transient scalar is used inside both
+    Map bodies. Therefore, `MapFusion` should not be able to fuse them.
+    If `has_conflict` is `False` then different scalars are used, which allows
+    fusing the two maps.
+    """
+    sdfg = dace.SDFG(
+        "inner_map_dependency_sdfg"
+        if has_conflict
+        else "inner_map_dependency_resolved_sdfg"
+    )
+    state = sdfg.add_state(is_start_block=True)
+
+    name_arrays = ["A", "T", "C"]
+    for aname in name_arrays:
+        sdfg.add_array(
+            aname,
+            shape=(10,),
+            dtype=dace.float64,
+            transient=False,
+        )
+    sdfg.arrays["T"].transient = True
+
+    name_scalars = ["s", "s"] if has_conflict else ["s1", "s2"]
+    for sname in set(name_scalars):
+        sdfg.add_scalar(
+            sname,
+            dtype=dace.float64,
+            transient=True,
+        )
+    A, T, C = (state.add_access(aname) for aname in name_arrays)
+    s1, s2 = (state.add_access(sname) for sname in name_scalars)
+
+    me1, mx1 = state.add_map(
+        "map_1",
+        ndrange={"__i0": "0:10"},
+    )
+    tsklt1 = state.add_tasklet(
+        "tskl1",
+        inputs={"__in1"},
+        outputs={"__out"},
+        code="__out = __in1 + 1.0",
+    )
+
+    # Create the first map series.
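+    # The value read from `A` passes through the transient scalar before it
+    # reaches the tasklet; if both maps reuse the same scalar, fusing them
+    # would merge two unrelated uses of that scalar into one map body.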
+    state.add_edge(
+        A, None,
+        me1, "IN_A",
+        dace.Memlet("A[0:10]")
+    )
+    me1.add_in_connector("IN_A")
+    state.add_edge(
+        me1, "OUT_A",
+        s1, None,
+        dace.Memlet("A[__i0] -> [0]")
+    )
+    me1.add_out_connector("OUT_A")
+    state.add_edge(
+        s1, None,
+        tsklt1, "__in1",
+        dace.Memlet(f"{s1.data}[0]")
+    )
+    state.add_edge(
+        tsklt1, "__out",
+        mx1, "IN_T",
+        dace.Memlet("T[__i0]")
+    )
+    mx1.add_in_connector("IN_T")
+    state.add_edge(
+        mx1, "OUT_T",
+        T, None,
+        dace.Memlet("T[0:10]")
+    )
+    mx1.add_out_connector("OUT_T")
+
+    # Create the second map.
+    me2, mx2 = state.add_map(
+        "map_2",
+        ndrange={"__i0": "0:10"},
+    )
+    tsklt2 = state.add_tasklet(
+        "tskl2",
+        inputs={"__in1"},
+        outputs={"__out"},
+        code="__out = __in1 + 3.0",
+    )
+
+    state.add_edge(
+        T, None,
+        me2, "IN_T",
+        dace.Memlet("T[0:10]")
+    )
+    me2.add_in_connector("IN_T")
+    state.add_edge(
+        me2, "OUT_T",
+        s2, None,
+        dace.Memlet("T[__i0]")
+    )
+    me2.add_out_connector("OUT_T")
+    state.add_edge(
+        s2, None,
+        tsklt2, "__in1",
+        dace.Memlet(f"{s2.data}[0]")
+    )
+    state.add_edge(
+        tsklt2, "__out",
+        mx2, "IN_C",
+        dace.Memlet("C[__i0]")
+    )
+    mx2.add_in_connector("IN_C")
+    state.add_edge(
+        mx2, "OUT_C",
+        C, None,
+        dace.Memlet("C[0:10]")
+    )
+    mx2.add_out_connector("OUT_C")
+    sdfg.validate()
+    return sdfg
+
+
+def test_inner_map_dependency():
+    # Because the same scalar is used in both map bodies, the maps cannot be fused.
+    sdfg = _make_inner_conflict_shared_scalar(has_conflict=True)
+    apply_fusion(sdfg, removed_maps=0, final_maps=2)
+
+
+def test_inner_map_dependency_resolved():
+    # Because the scalars are different, the maps can be fused.
+    sdfg = _make_inner_conflict_shared_scalar(has_conflict=False)
+    apply_fusion(sdfg, removed_maps=1, final_maps=1)
+
+
+def _impl_fusion_intermediate_different_access(
+    modified_shape: bool,
+    traditional_memlet_direction: bool
+):
+    def ref(A, B):
+        T = np.zeros((A.shape[0] + 1, 2))
+        for i in range(A.shape[0]):
+            T[i + 1, 0] = A[i] * 2
+            T[i + 1, 1] = A[i] / 2
+        for j in range(A.shape[0]):
+            B[j] = np.sin(T[j+1, 1])
+
+    sdfg = dace.SDFG("fusion_intermediate_different_access_sdfg")
+    state = sdfg.add_state(is_start_block=True)
+    for name in "AB":
+        sdfg.add_array(
+            name,
+            shape=(10,),
+            dtype=dace.float64,
+            transient=False,
+        )
+    sdfg.add_array(
+        "T",
+        shape=(11, 2),
+        dtype=dace.float64,
+        transient=True,
+    )
+
+    # For this intermediate, which essentially represents `[A[i] * 2, A[i] / 2]` in
+    # the reference above, there are two important remarks:
+    # - It exists because otherwise one data stream, i.e. `T[i + 1, 0]`, would be
+    #   dead dataflow, which the transformation currently cannot handle.
+    # - The strange shape is there because the transformation cannot handle this
+    #   case directly. This is a limitation of the implementation.
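+    # Depending on `modified_shape`, `temp` below is declared as `(1, 2)` or
+    # `(2,)`; in both cases it holds the pair `[A[i] * 2, A[i] / 2]` that a
+    # single iteration of the first map produces.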
+ sdfg.add_array( + "temp", + shape=( + (1, 2,) + if modified_shape + else (2,) + ), + dtype=dace.float64, + transient=True, + ) + + A, B, T, temp = (state.add_access(name) for name in ["A", "B", "T", "temp"]) + + me1, mx1 = state.add_map( + "first_map", + ndrange={"__i0": "0:10"}, + ) + + state.add_edge( + A, None, + me1, "IN_A", + dace.Memlet("A[0:10]") + ) + me1.add_in_connector("IN_A") + me1.add_out_connector("OUT_A") + + tsklt1_1 = state.add_tasklet( + "tsklt1_1", + inputs={"__in1"}, + outputs={"__out"}, + code="__out = __in1 * 2.0", + ) + state.add_edge( + me1, "OUT_A", + tsklt1_1, "__in1", + dace.Memlet("A[__i0]") + ) + state.add_edge( + tsklt1_1, "__out", + temp, None, + dace.Memlet( + "temp[0, 0]" + if modified_shape + else "temp[0]" + ) + ) + + tsklt1_2 = state.add_tasklet( + "tsklt1_2", + inputs={"__in1"}, + outputs={"__out"}, + code="__out = __in1 / 2.0", + ) + state.add_edge( + me1, "OUT_A", + tsklt1_2, "__in1", + dace.Memlet("A[__i0]") + ) + state.add_edge( + tsklt1_2, "__out", + temp, None, + dace.Memlet( + "temp[0, 1]" + if modified_shape + else "temp[1]" + ) + ) + + temp_subset = ( + "0, 0:2" + if modified_shape + else "0:2" + ) + T_subset = "__i0 + 1, 0:2" + + if traditional_memlet_direction: + mem_data = "T" + mem_subset = T_subset + mem_other_subset = temp_subset + else: + mem_data = "temp" + mem_subset = temp_subset + mem_other_subset = T_subset + + state.add_edge( + temp, None, + mx1, "IN_temp", + dace.Memlet(f"{mem_data}[{mem_subset}] -> [{mem_other_subset}]") + ) + state.add_edge( + mx1, "OUT_temp", + T, None, + dace.Memlet("T[1:11, 0:2]") + ) + mx1.add_in_connector("IN_temp") + mx1.add_out_connector("OUT_temp") + + state.add_mapped_tasklet( + "comp2", + map_ranges={"__i1": "0:10"}, + inputs={"__in1": dace.Memlet("T[__i1 + 1, 1]")}, + code="__out = math.sin(__in1)", + outputs={"__out": dace.Memlet("B[__i1]")}, + input_nodes={T}, + output_nodes={B}, + external_edges=True, + ) + sdfg.validate() + + apply_fusion(sdfg, removed_maps=1, final_maps=1) + + args_ref = { + 'A': np.array(np.random.rand(10), dtype=np.float64, copy=True), + 'B': np.array(np.random.rand(10), dtype=np.float64, copy=True), + } + args_res = copy.deepcopy(args_ref) + + ref(**args_ref) + sdfg(**args_res) + for arg in args_ref.keys(): + arg_ref = args_ref[arg] + arg_res = args_res[arg] + assert np.allclose(arg_ref, arg_res) + + +def test_fusion_intermediate_different_access(): + _impl_fusion_intermediate_different_access(modified_shape=False, traditional_memlet_direction=False) + + +def test_fusion_intermediate_different_access_2(): + _impl_fusion_intermediate_different_access(modified_shape=False, traditional_memlet_direction=True) + + +def test_fusion_intermediate_different_access_mod_shape(): + _impl_fusion_intermediate_different_access(modified_shape=True, traditional_memlet_direction=False) + + +def test_fusion_intermediate_different_access_mod_shape_2(): + _impl_fusion_intermediate_different_access(modified_shape=True, traditional_memlet_direction=True) + + +@pytest.mark.skip(reason="This feature is not yet fully supported.") +def test_fusion_multiple_producers_consumers(): + """Multiple producer and consumer nodes. + + This test is very similar to the `test_fusion_intermediate_different_access()` + and `test_fusion_intermediate_different_access_mod_shape()` test, with the + exception that now full data is used in the second map. + However, currently `MapFusion` only supports a single producer, thus this test can + not run. 
+ """ + def ref(A, B): + T = np.zeros((A.shape[0], 2)) + for i in range(A.shape[0]): + T[i, 0] = A[i] * 2 + T[i, 1] = A[i] / 2 + for j in range(A.shape[0]): + B[j] = np.sin(T[j, 1]) + np.cos(T[j, 0]) + + sdfg = dace.SDFG("fusion_multiple_producers_consumers_sdfg") + state = sdfg.add_state(is_start_block=True) + for name in "AB": + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + sdfg.add_array( + "T", + shape=(10, 2), + dtype=dace.float64, + transient=True, + ) + + A, B, T = (state.add_access(name) for name in ["A", "B", "T"]) + + me1, mx1 = state.add_map( + "first_map", + ndrange={"__i0": "0:10"}, + ) + + state.add_edge( + A, None, + me1, "IN_A", + dace.Memlet("A[0:10]") + ) + me1.add_in_connector("IN_A") + me1.add_out_connector("OUT_A") + + tsklt1_1 = state.add_tasklet( + "tsklt1_1", + inputs={"__in1"}, + outputs={"__out"}, + code="__out = __in1 * 2.0", + ) + state.add_edge( + me1, "OUT_A", + tsklt1_1, "__in1", + dace.Memlet("A[__i0]") + ) + state.add_edge( + tsklt1_1, "__out", + mx1, "IN_T", + dace.Memlet("T[__i0, 0]") + ) + + tsklt1_2 = state.add_tasklet( + "tsklt1_2", + inputs={"__in1"}, + outputs={"__out"}, + code="__out = __in1 / 2.0", + ) + state.add_edge( + me1, "OUT_A", + tsklt1_2, "__in1", + dace.Memlet("A[__i0]") + ) + state.add_edge( + tsklt1_2, "__out", + mx1, "IN_T", + dace.Memlet("T[__i0, 1]") + ) + mx1.add_in_connector("IN_T") + + state.add_edge( + mx1, "OUT_T", + T, None, + dace.Memlet("T[0:10, 0:2]"), + ) + mx1.add_out_connector("OUT_T") + + state.add_mapped_tasklet( + "comp2", + map_ranges={"__i1": "0:10"}, + inputs={ + "__in1": dace.Memlet("T[__i1, 1]"), + "__in2": dace.Memlet("T[__i1, 0]"), + }, + code="__out = math.sin(__in1) + math.cos(__in2)", + outputs={"__out": dace.Memlet("B[__i1]")}, + input_nodes={T}, + output_nodes={B}, + external_edges=True, + ) + sdfg.validate() + + apply_fusion(sdfg, removed_maps=1, final_maps=1) + + args_ref = { + 'A': np.array(np.random.rand(10), dtype=np.float64, copy=True), + 'B': np.array(np.random.rand(10), dtype=np.float64, copy=True), + } + args_res = copy.deepcopy(args_ref) + + ref(**args_ref) + sdfg(**args_res) + for arg in args_ref.keys(): + arg_ref = args_ref[arg] + arg_res = args_res[arg] + assert np.allclose(arg_ref, arg_res) + + +def test_fusion_multiple_consumers(): + """The intermediate is consumed multiple times in the second map. + """ + def ref(A, B, C): + T = np.zeros_like(A) + for i in range(A.shape[0]): + T[i] = np.sin(A[i] * 2) + for j in range(A.shape[0]): + B[j] = T[j] * 3. + C[j] = T[j] - 1. 
+ + sdfg = dace.SDFG("fusion_multiple_consumers_sdfg") + state = sdfg.add_state(is_start_block=True) + for name in "ABCT": + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["T"].transient = True + + A, B, C, T = (state.add_access(name) for name in ["A", "B", "C", "T"]) + + state.add_mapped_tasklet( + "comp1", + map_ranges={"__i1": "0:10"}, + inputs={ + "__in1": dace.Memlet("A[__i1]"), + }, + code="__out = math.sin(2 * __in1)", + outputs={"__out": dace.Memlet("T[__i1]")}, + input_nodes={A}, + output_nodes={T}, + external_edges=True, + ) + + me2, mx2 = state.add_map( + "second_map", + ndrange={"__i0": "0:10"}, + ) + + state.add_edge( + T, None, + me2, "IN_T", + dace.Memlet("T[0:10]", volume=20) + ) + me2.add_in_connector("IN_T") + me2.add_out_connector("OUT_T") + + tsklt2_1 = state.add_tasklet( + "tsklt2_1", + inputs={"__in1"}, + outputs={"__out"}, + code="__out = __in1 * 3.0", + ) + state.add_edge( + me2, "OUT_T", + tsklt2_1, "__in1", + dace.Memlet("T[__i0]") + ) + state.add_edge( + tsklt2_1, "__out", + mx2, "IN_B", + dace.Memlet("B[__i0]") + ) + + tsklt2_2 = state.add_tasklet( + "tsklt2_2", + inputs={"__in1"}, + outputs={"__out"}, + code="__out = __in1 - 1.0", + ) + state.add_edge( + me2, "OUT_T", + tsklt2_2, "__in1", + dace.Memlet("T[__i0]") + ) + state.add_edge( + tsklt2_2, "__out", + mx2, "IN_C", + dace.Memlet("C[__i0]") + ) + mx2.add_in_connector("IN_B") + mx2.add_in_connector("IN_C") + + state.add_edge( + mx2, "OUT_B", + B, None, + dace.Memlet("B[0:10]"), + ) + state.add_edge( + mx2, "OUT_C", + C, None, + dace.Memlet("C[0:10]"), + ) + mx2.add_out_connector("OUT_B") + mx2.add_out_connector("OUT_C") + sdfg.validate() + + apply_fusion(sdfg, removed_maps=1, final_maps=1) + + args_ref = { + 'A': np.array(np.random.rand(10), dtype=np.float64, copy=True), + 'B': np.array(np.random.rand(10), dtype=np.float64, copy=True), + 'C': np.array(np.random.rand(10), dtype=np.float64, copy=True), + } + args_res = copy.deepcopy(args_ref) + + ref(**args_ref) + sdfg(**args_res) + for arg in args_ref.keys(): + arg_ref = args_ref[arg] + arg_res = args_res[arg] + assert np.allclose(arg_ref, arg_res) + + +def test_fusion_different_global_accesses(): + + def ref(A, B): + T = np.zeros_like(A) + for i in range(10): + T[i] = A[i] - B[i + 1] + for i in range(10): + A[i] = np.sin(T[i]) + B[i + 1] = np.cos(T[i]) + + sdfg = dace.SDFG("fusion_different_global_accesses_sdfg") + state = sdfg.add_state(is_start_block=True) + for name in "ABT": + sdfg.add_array( + name, + shape=(11,), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["T"].transient = True + T = state.add_access("T") + + state.add_mapped_tasklet( + "first_comp", + map_ranges={"__i0": "0:10"}, + inputs={ + "__in1": dace.Memlet("A[__i0]"), + "__in2": dace.Memlet("B[__i0 + 1]") + }, + code="__out = __in1 - __in2", + outputs={"__out": dace.Memlet("T[__i0]")}, + output_nodes={T}, + external_edges=True, + ) + state.add_mapped_tasklet( + "second_comp", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("T[__i0]")}, + code="__out1 = math.sin(__in1)\n__out2 = math.cos(__in1)", + outputs={ + "__out1": dace.Memlet("A[__i0]"), + "__out2": dace.Memlet("B[__i0 + 1]"), + }, + input_nodes={T}, + external_edges=True, + ) + sdfg.validate() + + apply_fusion(sdfg, removed_maps=1, final_maps=1) + + args_ref = { + 'A': np.array(np.random.rand(11), dtype=np.float64, copy=True), + 'B': np.array(np.random.rand(11), dtype=np.float64, copy=True), + } + args_res = copy.deepcopy(args_ref) + + ref(**args_ref) + 
sdfg(**args_res)
+    for arg in args_ref.keys():
+        arg_ref = args_ref[arg]
+        arg_res = args_res[arg]
+        assert np.allclose(arg_ref, arg_res)
+
+
+def test_fusion_dynamic_producer():
+
+    def ref(A, B):
+        for i in range(10):
+            if B[i] < 0.5:
+                A[i] = 0.0
+        for i in range(10):
+            B[i] = np.sin(A[i])
+
+    sdfg = dace.SDFG("fusion_dynamic_producer_sdfg")
+    state = sdfg.add_state(is_start_block=True)
+    for name in "AB":
+        sdfg.add_array(
+            name,
+            shape=(10,),
+            dtype=dace.float64,
+            transient=False,
+        )
+    B_top, B_bottom, A = (state.add_access(name) for name in "BBA")
+
+    state.add_mapped_tasklet(
+        "comp1",
+        map_ranges={"__i0": "0:10"},
+        inputs={"__in1": dace.Memlet("B[__i0]")},
+        code="if __in1 < 0.5:\n\t__out = 0.0",
+        outputs={"__out": dace.Memlet("A[__i0]", dynamic=True)},
+        input_nodes={B_top},
+        output_nodes={A},
+        external_edges=True,
+    )
+    state.add_mapped_tasklet(
+        "comp2",
+        map_ranges={"__i0": "0:10"},
+        inputs={"__in1": dace.Memlet("A[__i0]")},
+        code="__out = math.sin(__in1)",
+        outputs={"__out": dace.Memlet("B[__i0]")},
+        input_nodes={A},
+        output_nodes={B_bottom},
+        external_edges=True,
+    )
+    sdfg.validate()
+
+    # Whether dynamic Memlets can be handled is left unspecified, so we pass
+    # `unspecific=True`, i.e. only validation checks are performed. Afterwards
+    # we verify numerically that the transformation did the right thing.
+    apply_fusion(sdfg, unspecific=True)
+
+    args_ref = {
+        'A': np.array(np.random.rand(10), dtype=np.float64, copy=True),
+        'B': np.array(np.random.rand(10), dtype=np.float64, copy=True),
+    }
+    args_res = copy.deepcopy(args_ref)
+
+    ref(**args_ref)
+    csdfg = sdfg.compile()
+    csdfg(**args_res)
+    for arg in args_ref.keys():
+        arg_ref = args_ref[arg]
+        arg_res = args_res[arg]
+        assert np.allclose(arg_ref, arg_res)
+
+
+def test_fusion_intrinsic_memlet_direction():
+
+    def ref(A, B):
+        T = A + 10.0
+        B[:] = np.sin(T)
+
+    sdfg = dace.SDFG("fusion_intrinsic_memlet_direction_sdfg")
+    state = sdfg.add_state(is_start_block=True)
+
+    for name in "ATB":
+        sdfg.add_array(
+            name,
+            shape=(10, 11, 12),
+            dtype=dace.float64,
+            transient=False,
+        )
+    sdfg.arrays["T"].transient = True
+
+    for num in "12":
+        sdfg.add_scalar(
+            "t" + num,
+            dtype=dace.float64,
+            transient=True,
+        )
+
+    A, T, B, t1, t2 = (state.add_access(name) for name in ["A", "T", "B", "t1", "t2"])
+
+    tsklt1, me1, mx1 = state.add_mapped_tasklet(
+        "comp1",
+        map_ranges={
+            "__i1": "0:10",
+            "__i2": "0:11",
+            "__i3": "0:12",
+        },
+        inputs={"__in1": dace.Memlet("A[__i1, __i2, __i3]")},
+        code="__out = __in1 + 10.0",
+        outputs={"__out": dace.Memlet("T[__i1, __i2, __i3]")},
+        input_nodes={A},
+        output_nodes={T},
+        external_edges=True,
+    )
+
+    tsklt2, me2, mx2 = state.add_mapped_tasklet(
+        "comp2",
+        map_ranges={
+            "__i1": "0:10",
+            "__i2": "0:11",
+            "__i3": "0:12",
+        },
+        inputs={"__in1": dace.Memlet("T[__i1, __i2, __i3]")},
+        code="__out = math.sin(__in1)",
+        outputs={"__out": dace.Memlet("B[__i1, __i2, __i3]")},
+        input_nodes={T},
+        output_nodes={B},
+        external_edges=True,
+    )
+
+    for me in [me1, me2]:
+        dace.transformation.dataflow.MapExpansion.apply_to(
+            sdfg,
+            options={"inner_schedule": dace.ScheduleType.Default},
+            map_entry=me,
+        )
+
+    # Now add a transient scalar at the output of `tsklt1`.
+    tsklt1_oedge = next(iter(state.out_edges(tsklt1)))
+    me1_inner = tsklt1_oedge.dst
+    state.add_edge(
+        tsklt1, "__out",
+        t1, None,
+        dace.Memlet("t1[0]"),
+    )
+    state.add_edge(
+        t1, None,
+        me1_inner, tsklt1_oedge.dst_conn,
+        dace.Memlet("t1[0] -> [__i1, __i2, __i3]"),
+    )
+    state.remove_edge(tsklt1_oedge)
+    tsklt1_oedge = None
+
+    # Now add a transient scalar in front of `tsklt2`.
+    tsklt2_iedge = next(iter(state.in_edges(tsklt2)))
+    me2_inner = tsklt2_iedge.src
+    state.add_edge(
+        me2_inner, tsklt2_iedge.src_conn,
+        t2, None,
+        dace.Memlet("t2[0] -> [__i1, __i2, __i3]"),
+    )
+    state.add_edge(
+        t2, None,
+        tsklt2, "__in1",
+        dace.Memlet("t2[0]"),
+    )
+    state.remove_edge(tsklt2_iedge)
+    tsklt2_iedge = None
+    sdfg.validate()
+
+    # By specifying `apply_once` we only perform one fusion, which will eliminate `T`.
+    # This is not efficient; we do it to make sure that the update of the Memlets
+    # has worked.
+    apply_fusion(sdfg, apply_once=True)
+
+    for edge in state.edges():
+        # There should be no edge that references `T`.
+        assert edge.data.data != "T"
+
+        # If an edge is connected to `t1` or `t2`, then its data should refer to it;
+        # no other Memlet shall refer to them.
+        for t in [t1, t2]:
+            if edge.src is t or edge.dst is t:
+                assert edge.data.data == t.data
+            else:
+                assert edge.data.data != t.data
+
+    args_ref = {
+        'A': np.array(np.random.rand(10, 11, 12), dtype=np.float64, copy=True),
+        'B': np.array(np.random.rand(10, 11, 12), dtype=np.float64, copy=True),
+    }
+    args_res = copy.deepcopy(args_ref)
+
+    ref(**args_ref)
+    sdfg(**args_res)
+    for arg in args_ref.keys():
+        arg_ref = args_ref[arg]
+        arg_res = args_res[arg]
+        assert np.allclose(arg_ref, arg_res)
+
+
+def _make_possible_cycle_if_fused_sdfg() -> Tuple[dace.SDFG, nodes.MapExit, nodes.AccessNode, nodes.MapEntry]:
+    """Generate an SDFG in which fusing the two candidate maps would create a cycle.
+
+    Essentially tests whether MapFusion detects this special case.
+    """
+    sdfg = dace.SDFG("possible_cycle_if_fused_sdfg")
+    state = sdfg.add_state(is_start_block=True)
+
+    names = ["A", "B", "T", "U", "V"]
+    for name in names:
+        sdfg.add_array(
+            name,
+            shape=(10,),
+            dtype=dace.float64,
+            transient=True,
+        )
+    sdfg.arrays["A"].transient = False
+    sdfg.arrays["B"].transient = False
+
+    A, B, T, U, V = (state.add_access(name) for name in names)
+
+    _, _, first_map_exit = state.add_mapped_tasklet(
+        "map1",
+        map_ranges={"__i": "0:10"},
+        inputs={"__in": dace.Memlet("A[__i]")},
+        code="__out1 = __in + 10\n__out2 = __in - 10",
+        outputs={
+            "__out1": dace.Memlet("T[__i]"),
+            "__out2": dace.Memlet("U[__i]"),
+        },
+        input_nodes={A},
+        output_nodes={T, U},
+        external_edges=True,
+    )
+
+    state.add_mapped_tasklet(
+        "map2",
+        map_ranges={"__i": "0:10"},
+        inputs={"__in": dace.Memlet("U[__i]")},
+        code="__out = math.sin(__in)",
+        outputs={"__out": dace.Memlet("V[__i]")},
+        input_nodes={U},
+        output_nodes={V},
+        external_edges=True,
+    )
+
+    _, second_map_entry, _ = state.add_mapped_tasklet(
+        "map3",
+        map_ranges={"__i": "0:10"},
+        inputs={
+            "__in1": dace.Memlet("T[__i]"),
+            "__in2": dace.Memlet("V[__i]"),
+        },
+        code="__out = __in1 + __in2",
+        outputs={"__out": dace.Memlet("B[__i]")},
+        input_nodes={T, V},
+        output_nodes={B},
+        external_edges=True,
+    )
+    sdfg.validate()
+
+    return sdfg, first_map_exit, T, second_map_entry
+
+
+def test_possible_cycle_if_fused_sdfg():
+    sdfg, first_map_exit, array, second_map_entry = _make_possible_cycle_if_fused_sdfg()
+
+    would_transformation_apply = MapFusion.can_be_applied_to(
+        sdfg,
+        first_map_exit=first_map_exit,
+        array=array,
+        second_map_entry=second_map_entry,
+    )
+    assert not would_transformation_apply
+
+
 if __name__ == '__main__':
+    test_fusion_intrinsic_memlet_direction()
+    test_fusion_dynamic_producer()
+    test_fusion_different_global_accesses()
+    test_fusion_multiple_consumers()
+    test_fusion_intermediate_different_access()
+    test_fusion_intermediate_different_access_mod_shape()
+    test_fusion_non_strict_dataflow_implicit_dependency()
+    test_fusion_strict_dataflow_pointwise()
+    test_fusion_strict_dataflow_not_pointwise()
+    test_fusion_dataflow_intermediate()
+    test_fusion_dataflow_intermediate_2()
+    test_fusion_dataflow_intermediate_downstream()
+    test_indirect_accesses()
+    test_fusion_shared()
+    test_fusion_with_transient()
+    test_fusion_rename()
     test_fusion_simple()
     test_multiple_fusions()
     test_fusion_chain()
-    test_fusion_with_transient()
     test_fusion_with_transient_scalar()
     test_fusion_with_inverted_indices()
     test_fusion_with_empty_memlet()
     test_fusion_with_nested_sdfg_0()
+    test_interstate_fusion()
     test_fusion_with_nested_sdfg_1()
+    test_fuse_indirect_accesses()
+    test_offset_correction_range_read()
+    test_offset_correction_scalar_read()
+    test_offset_correction_empty()
+    test_different_offsets()
+    test_inner_map_dependency()
+    test_inner_map_dependency_resolved()
+    test_possible_cycle_if_fused_sdfg()
diff --git a/tests/transformations/warp_tiling_test.py b/tests/transformations/warp_tiling_test.py
index 7c75d08878..4ca424a1c1 100644
--- a/tests/transformations/warp_tiling_test.py
+++ b/tests/transformations/warp_tiling_test.py
@@ -38,19 +38,27 @@ def test_warp_softmax(vector_length=1):
     sdfg = softmax_fwd.to_sdfg(simplify=True)
 
     # Apply transformations
-    sdfg.apply_transformations_repeated(ReduceExpansion)
+    sdfg.apply_transformations_repeated(ReduceExpansion, validate_all=True)
     MultiExpansion.apply_to(sdfg, sdfg.node(0).nodes())
     SubgraphFusion.apply_to(sdfg, sdfg.node(0).nodes())
sdfg.expand_library_nodes() sdfg.simplify() - sdfg.apply_transformations_repeated([TrivialMapElimination, MapFusion]) - sdfg.apply_transformations(GPUTransformSDFG) + sdfg.apply_transformations_repeated([TrivialMapElimination, MapFusion], validate_all=True) + sdfg.apply_transformations(GPUTransformSDFG, validate_all=True) assert sdfg.apply_transformations(WarpTiling) == 1 - sdfg.apply_transformations_repeated([HoistState, InlineSDFG, StateFusion]) - sdfg.apply_transformations_repeated([TrivialMapElimination, MapFusion]) + sdfg.apply_transformations_repeated([HoistState, InlineSDFG, StateFusion], validate_all=True) + sdfg.apply_transformations_repeated([TrivialMapElimination, MapFusion], validate_all=True) if vector_length != 1: sdfg.apply_transformations_repeated( - Vectorization, dict(vector_len=vector_length, preamble=False, postamble=False, strided_map=False)) + Vectorization, + dict( + vector_len=vector_length, + preamble=False, + postamble=False, + strided_map=False + ), + validate_all=True + ) sdfg.specialize(dict(dn1=2, dn2=16, dn3=128, dr=128)) # Check validity