Skip to content

Commit

Permalink
Updated InlineMultistateSDFG (#1689)
Browse files Browse the repository at this point in the history
The `can_be_applied()` function did not consider the symbol map when the
shape of the arrays were compared. This commit fixes this behaiour by
first appling a replacing step before the comparisson.

Furthermore, the commit removes all the commented out code.
  • Loading branch information
philip-paul-mueller authored Oct 17, 2024
1 parent 073b613 commit 653ec33
Showing 1 changed file with 7 additions and 216 deletions.
223 changes: 7 additions & 216 deletions dace/transformation/interstate/multistate_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from dace.sdfg.graph import MultiConnectorEdge
from dace.sdfg import InterstateEdge, SDFG, SDFGState
from dace.sdfg import utils as sdutil, infer_types
from dace.sdfg.replace import replace_datadesc_names
from dace.sdfg.replace import replace_datadesc_names, replace_properties_dict
from dace.transformation import transformation, helpers
from dace.properties import make_properties
from dace import data
Expand Down Expand Up @@ -103,7 +103,10 @@ def can_be_applied(self, state: SDFGState, expr_index, sdfg, permissive=False):
if isinstance(outer_desc, data.View):
return False

inner_desc = nested_sdfg.sdfg.arrays[edge.dst_conn]
# We can not compare shapes directly, we have to consider the symbol map
# for that. Clone the descriptor because the operation is inplace.
inner_desc = nested_sdfg.sdfg.arrays[edge.dst_conn].clone()
symbolic.safe_replace(nested_sdfg.symbol_mapping, lambda m: replace_properties_dict(inner_desc, m))
if (outer_desc.shape != inner_desc.shape or outer_desc.strides != inner_desc.strides):
return False

Expand All @@ -121,7 +124,8 @@ def can_be_applied(self, state: SDFGState, expr_index, sdfg, permissive=False):
if isinstance(outer_desc, data.View):
return False

inner_desc = nested_sdfg.sdfg.arrays[edge.src_conn]
inner_desc = nested_sdfg.sdfg.arrays[edge.src_conn].clone()
symbolic.safe_replace(nested_sdfg.symbol_mapping, lambda m: replace_properties_dict(inner_desc, m))
if (outer_desc.shape != inner_desc.shape or outer_desc.strides != inner_desc.strides):
return False

Expand Down Expand Up @@ -208,27 +212,6 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG):
#######################################################
# Collect and modify access nodes as necessary

# Access nodes that need to be reshaped
# reshapes: Set(str) = set()
# for aname, array in nsdfg.arrays.items():
# if array.transient:
# continue
# edge = None
# if aname in inputs:
# edge = inputs[aname]
# if len(array.shape) > len(edge.data.subset):
# reshapes.add(aname)
# continue
# if aname in outputs:
# edge = outputs[aname]
# if len(array.shape) > len(edge.data.subset):
# reshapes.add(aname)
# continue
# if edge is not None and not InlineMultistateSDFG._check_strides(
# array.strides, sdfg.arrays[edge.data.data].strides,
# edge.data, nsdfg_node):
# reshapes.add(aname)

# Mapping from nested transient name to top-level name
transients: Dict[str, str] = {}

Expand Down Expand Up @@ -281,50 +264,6 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG):

symbolic.safe_replace(repldict, lambda m: replace_datadesc_names(nsdfg, m), value_as_string=True)

# Add views whenever reshapes are necessary
# for dname in reshapes:
# desc = nsdfg.arrays[dname]
# # To avoid potential confusion, rename protected __return keyword
# if dname.startswith('__return'):
# newname = f'{nsdfg.name}_ret{dname[8:]}'
# else:
# newname = dname
# newname, _ = sdfg.add_view(newname,
# desc.shape,
# desc.dtype,
# storage=desc.storage,
# strides=desc.strides,
# offset=desc.offset,
# debuginfo=desc.debuginfo,
# allow_conflicts=desc.allow_conflicts,
# total_size=desc.total_size,
# alignment=desc.alignment,
# may_alias=desc.may_alias,
# find_new_name=True)
# repldict[dname] = newname

# Add extra access nodes for out/in view nodes
# inv_reshapes = {repldict[r]: r for r in reshapes}
# for nstate in nsdfg.nodes():
# for node in nstate.nodes():
# if isinstance(node,
# nodes.AccessNode) and node.data in inv_reshapes:
# if nstate.in_degree(node) > 0 and nstate.out_degree(
# node) > 0:
# # Such a node has to be in the output set
# edge = outputs[inv_reshapes[node.data]]

# # Redirect outgoing edges through access node
# out_edges = list(nstate.out_edges(node))
# anode = nstate.add_access(edge.data.data)
# vnode = nstate.add_access(node.data)
# nstate.add_nedge(node, anode, edge.data)
# nstate.add_nedge(anode, vnode, edge.data)
# for e in out_edges:
# nstate.remove_edge(e)
# nstate.add_edge(vnode, e.src_conn, e.dst,
# e.dst_conn, e.data)

# Make unique names for states
statenames = set(s.label for s in sdfg.nodes())
for nstate in nsdfg.nodes():
Expand Down Expand Up @@ -364,46 +303,6 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG):
sdfg.start_state = sdfg.node_id(source)

# TODO: Modify memlets by offsetting
# If both source and sink nodes are inputs/outputs, reconnect once
# edges_to_ignore = self._modify_access_to_access(new_incoming_edges,
# nsdfg, nstate, state,
# orig_data)

# source_to_outer = {n: e.src for n, e in new_incoming_edges.items()}
# sink_to_outer = {n: e.dst for n, e in new_outgoing_edges.items()}
# # If a source/sink node is one of the inputs/outputs, reconnect it,
# # replacing memlets in outgoing/incoming paths
# modified_edges = set()
# modified_edges |= self._modify_memlet_path(new_incoming_edges, nstate,
# state, sink_to_outer, True,
# edges_to_ignore)
# modified_edges |= self._modify_memlet_path(new_outgoing_edges, nstate,
# state, source_to_outer,
# False, edges_to_ignore)

# # Reshape: add connections to viewed data
# self._modify_reshape_data(reshapes, repldict, inputs, nstate, state,
# True)
# self._modify_reshape_data(reshapes, repldict, outputs, nstate, state,
# False)

# Modify all other internal edges pertaining to input/output nodes
# for nstate in nsdfg.nodes():
# for node in nstate.nodes():
# if isinstance(node, nodes.AccessNode):
# if node.data in input_set or node.data in output_set:
# if node.data in input_set:
# outer_edge = inputs[input_set[node.data]]
# else:
# outer_edge = outputs[output_set[node.data]]

# for edge in state.all_edges(node):
# if (edge not in modified_edges
# and edge.data.data == node.data):
# for e in state.memlet_tree(edge):
# if e.data.data == node.data:
# e._data = helpers.unsqueeze_memlet(
# e.data, outer_edge.data)

# Replace nested SDFG parents with new SDFG
for nstate in nsdfg.nodes():
Expand All @@ -420,111 +319,3 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG):
sdfg._cfg_list = sdfg.reset_cfg_list()

return nsdfg.nodes()

# def _modify_access_to_access(
# self,
# input_edges: Dict[nodes.Node, MultiConnectorEdge],
# nsdfg: SDFG,
# nstate: SDFGState,
# state: SDFGState,
# orig_data: Dict[Union[nodes.AccessNode, MultiConnectorEdge], str],
# ) -> Set[MultiConnectorEdge]:
# """
# Deals with access->access edges where both sides are non-transient.
# """
# result = set()
# for node, top_edge in input_edges.items():
# for inner_edge in nstate.out_edges(node):
# if inner_edge.dst not in orig_data:
# continue
# inner_data = orig_data[inner_edge.dst]
# if (isinstance(inner_edge.dst, nodes.AccessNode)
# and not nsdfg.arrays[inner_data].transient):
# matching_edge: MultiConnectorEdge = next(
# state.out_edges_by_connector(top_edge.dst, inner_data))
# # Create memlet by unsqueezing both w.r.t. src and dst
# # subsets
# in_memlet = helpers.unsqueeze_memlet(
# inner_edge.data, top_edge.data)
# out_memlet = helpers.unsqueeze_memlet(
# inner_edge.data, matching_edge.data)
# new_memlet = in_memlet
# new_memlet.other_subset = out_memlet.subset

# # Connect with new edge
# state.add_edge(top_edge.src, top_edge.src_conn,
# matching_edge.dst, matching_edge.dst_conn,
# new_memlet)
# result.add(inner_edge)

# return result

# def _modify_memlet_path(
# self,
# new_edges: Dict[nodes.Node, MultiConnectorEdge],
# nstate: SDFGState,
# state: SDFGState,
# inner_to_outer: Dict[nodes.Node, MultiConnectorEdge],
# inputs: bool,
# edges_to_ignore: Set[MultiConnectorEdge],
# ) -> Set[MultiConnectorEdge]:
# """ Modifies memlet paths in an inlined SDFG. Returns set of modified
# edges.
# """
# result = set()
# for node, top_edge in new_edges.items():
# inner_edges = (nstate.out_edges(node)
# if inputs else nstate.in_edges(node))
# for inner_edge in inner_edges:
# if inner_edge in edges_to_ignore:
# continue
# new_memlet = helpers.unsqueeze_memlet(inner_edge.data,
# top_edge.data)
# if inputs:
# if inner_edge.dst in inner_to_outer:
# dst = inner_to_outer[inner_edge.dst]
# else:
# dst = inner_edge.dst

# new_edge = state.add_edge(top_edge.src, top_edge.src_conn,
# dst, inner_edge.dst_conn,
# new_memlet)
# mtree = state.memlet_tree(new_edge)
# else:
# if inner_edge.src in inner_to_outer:
# # don't add edges twice
# continue

# new_edge = state.add_edge(inner_edge.src,
# inner_edge.src_conn, top_edge.dst,
# top_edge.dst_conn, new_memlet)
# mtree = state.memlet_tree(new_edge)

# # Modify all memlets going forward/backward
# def traverse(mtree_node):
# result.add(mtree_node.edge)
# mtree_node.edge._data = helpers.unsqueeze_memlet(
# mtree_node.edge.data, top_edge.data)
# for child in mtree_node.children:
# traverse(child)

# for child in mtree.children:
# traverse(child)

# return result

# def _modify_reshape_data(self, reshapes: Set[str], repldict: Dict[str, str],
# new_edges: Dict[str, MultiConnectorEdge],
# nstate: SDFGState, state: SDFGState, inputs: bool):
# anodes = nstate.source_nodes() if inputs else nstate.sink_nodes()
# reshp = {repldict[r]: r for r in reshapes}
# for node in anodes:
# if not isinstance(node, nodes.AccessNode):
# continue
# if node.data not in reshp:
# continue
# edge = new_edges[reshp[node.data]]
# if inputs:
# state.add_edge(edge.src, edge.src_conn, node, None, edge.data)
# else:
# state.add_edge(node, None, edge.dst, edge.dst_conn, edge.data)

0 comments on commit 653ec33

Please sign in to comment.