Skip to content

Commit

Permalink
Added more tests to the map fusion and refined some others.
Browse files Browse the repository at this point in the history
  • Loading branch information
philip-paul-mueller committed Jan 17, 2025
1 parent c53f939 commit a3842f9
Showing 1 changed file with 60 additions and 11 deletions.
71 changes: 60 additions & 11 deletions tests/transformations/mapfusion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,11 @@ def test_fusion_dataflow_intermediate():


def test_fusion_dataflow_intermediate_2():
# Because `A` is not also output transformation applies.
# The transformation applies for two reasons, first reading and writing `A`
# is pointwise. Furthermore, there is no further access to `A` after the
# intermediate node. Note that if the second map would also have an output
# that refers to `A` then the transformation would not apply regardless
# of the strict dataflow mode.
sdfg, state = _make_strict_dataflow_sdfg_pointwise(
input_data="A",
intermediate_data="A",
Expand All @@ -977,6 +981,21 @@ def test_fusion_dataflow_intermediate_2():
assert {"A", "O"} == {edge.dst.data for edge in state.out_edges(map_exit) if isinstance(edge.dst, nodes.AccessNode)}


def test_fusion_dataflow_intermediate_3():
# This is exactly the same situation as in `test_fusion_dataflow_intermediate_2()`
# with the exception that now the access to `A` is no longer pointwise, thus
# the transformation does not apply. Note that this SDFG is wrong, it is only
# here to show that the case is detected.
sdfg, state = _make_strict_dataflow_sdfg_pointwise(
input_data="A",
intermediate_data="A",
output_data="O",
input_read="9 - __i0",
output_write="__i0",
)
apply_fusion(sdfg, removed_maps=0, strict_dataflow=True)


def test_fusion_dataflow_intermediate_downstream():
# Because the intermediate `T` is used downstream again,
# the transformation can not apply.
Expand All @@ -1001,10 +1020,18 @@ def test_fusion_dataflow_intermediate_downstream():
input_nodes={output_1},
external_edges=True,
)

# Make another state where `T` is written back, such that it is not dead data flow.
state2 = sdfg.add_state_after(state)
sdfg.add_datadesc("output_2", sdfg.arrays["output_1"].clone())
state2.add_nedge(
state2.add_access("T"),
state2.add_access("output_2"),
sdfg.make_array_memlet("T"),
)
sdfg.validate()

apply_fusion(sdfg, removed_maps=0, strict_dataflow=True)
sdfg.view()

# However without strict dataflow, the merge is possible.
apply_fusion(sdfg, removed_maps=1, strict_dataflow=False)
Expand Down Expand Up @@ -1238,8 +1265,10 @@ def test_inner_map_dependency_resolved():
apply_fusion(sdfg, removed_maps=1, final_maps=1)


def _impl_fusion_intermediate_different_access(modified_shape: bool):

def _impl_fusion_intermediate_different_access(
modified_shape: bool,
traditional_memlet_direction: bool
):
def ref(A, B):
T = np.zeros((A.shape[0] + 1, 2))
for i in range(A.shape[0]):
Expand Down Expand Up @@ -1338,14 +1367,26 @@ def ref(A, B):
)
)

temp_subset = (
"0, 0:2"
if modified_shape
else "0:2"
)
T_subset = "__i0 + 1, 0:2"

if traditional_memlet_direction:
mem_data = "T"
mem_subset = T_subset
mem_other_subset = temp_subset
else:
mem_data = "temp"
mem_subset = temp_subset
mem_other_subset = T_subset

state.add_edge(
temp, None,
mx1, "IN_temp",
dace.Memlet(
"temp[0, 0:2] -> [__i0 + 1, 0:2]"
if modified_shape
else "temp[0:2] -> [__i0 + 1, 0:2]"
)
dace.Memlet(f"{mem_data}[{mem_subset}] -> [{mem_other_subset}]")
)
state.add_edge(
mx1, "OUT_temp",
Expand Down Expand Up @@ -1384,11 +1425,19 @@ def ref(A, B):


def test_fusion_intermediate_different_access():
_impl_fusion_intermediate_different_access(modified_shape=False)
_impl_fusion_intermediate_different_access(modified_shape=False, traditional_memlet_direction=False)


def test_fusion_intermediate_different_access_2():
_impl_fusion_intermediate_different_access(modified_shape=False, traditional_memlet_direction=True)


def test_fusion_intermediate_different_access_mod_shape():
_impl_fusion_intermediate_different_access(modified_shape=True)
_impl_fusion_intermediate_different_access(modified_shape=True, traditional_memlet_direction=False)


def test_fusion_intermediate_different_access_mod_shape_2():
_impl_fusion_intermediate_different_access(modified_shape=True, traditional_memlet_direction=True)


@pytest.mark.skip(reason="This feature is not yet fully supported.")
Expand Down

0 comments on commit a3842f9

Please sign in to comment.