Skip to content

Commit

Permalink
Added a new test for the map fusion.
Browse files Browse the repository at this point in the history
It mainly verifies that fusion stop at a certain invalid point.
  • Loading branch information
philip-paul-mueller committed Dec 5, 2024
1 parent a412394 commit fa21bd3
Showing 1 changed file with 89 additions and 0 deletions.
89 changes: 89 additions & 0 deletions tests/transformations/mapfusion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,7 +889,96 @@ def test_fusion_dataflow_intermediate_downstream():
assert {"T", "output_1"} == {edge.dst.data for edge in state.out_edges(upper_map_exit) if isinstance(edge.dst, nodes.AccessNode)}


def test_fusion_non_strict_dataflow_implicit_dependency():
"""
This test simulates if the fusion respect implicit dependencies, given by access nodes.
This test simulates a situation that could arise if non strict dataflow is enabled.
The test ensures that the fusion does not continue fusing in this situation.
"""
sdfg = dace.SDFG("fusion_strict_dataflow_implicit_dependency_sdfg")
state = sdfg.add_state(is_start_block=True)
names = ["A", "B", "T1", "T2", "C"]
for name in names:
sdfg.add_array(
name,
shape=(10,),
dtype=dace.float64,
transient=False,
)
sdfg.arrays["T1"].transient = True
sdfg.arrays["T2"].transient = True

me, mx = state.add_map(
"first_map",
ndrange={"__i0": "0:10"}
)
tskl1 = state.add_tasklet(
"tskl1",
inputs={"__in1", "__in2"},
code="__out = __in1 * __in2",
outputs={"__out"}
)
tskl2 = state.add_tasklet(
"tskl2",
inputs={"__in1", "__in2"},
code="__out = (__in1 + __in2) / 2",
outputs={"__out"}
)
A, B, T1, T2 = (state.add_access(name) for name in names[:-1])

state.add_edge(A, None, me, "IN_A", dace.Memlet("A[0:10]"))
state.add_edge(B, None, me, "IN_B", dace.Memlet("B[0:10]"))
me.add_in_connector("IN_A")
me.add_in_connector("IN_B")

state.add_edge(me, "OUT_A", tskl1, "__in1", dace.Memlet("A[__i0]"))
state.add_edge(me, "OUT_B", tskl1, "__in2", dace.Memlet("B[__i0]"))
state.add_edge(me, "OUT_A", tskl2, "__in1", dace.Memlet("A[__i0]"))
state.add_edge(me, "OUT_B", tskl2, "__in2", dace.Memlet("B[__i0]"))
me.add_out_connector("OUT_A")
me.add_out_connector("OUT_B")

state.add_edge(tskl1, "__out", mx, "IN_T1", dace.Memlet("T1[__i0]"))
state.add_edge(tskl2, "__out", mx, "IN_T2", dace.Memlet("T2[__i0]"))
mx.add_in_connector("IN_T1")
mx.add_in_connector("IN_T2")

state.add_edge(mx, "OUT_T1", T1, None, dace.Memlet("T1[0:10]"))
state.add_edge(mx, "OUT_T2", T2, None, dace.Memlet("T2[0:10]"))
mx.add_out_connector("OUT_T1")
mx.add_out_connector("OUT_T2")

state.add_mapped_tasklet(
"second_map",
map_ranges={"__in0": "0:10"},
inputs={"__in1": dace.Memlet("T1[__i0]")},
code="if __in1 < 0.5:\n\t__out = 100.",
outputs={"__out": dace.Memlet("T2[__i0]", dynamic=True)},
input_nodes={T1},
external_edges=True,
)

state2 = sdfg.add_state_after(state)
state2.add_edge(
state2.add_access("T2"),
None,
state2.add_access("C"),
None,
dace.Memlet("T2[0:10] -> [0:10]"),
)
sdfg.validate()

count = sdfg.apply_transformations_repeated(
MapFusion(strict_dataflow=False),
validate=True,
validate_all=True,
)
assert count == 0


if __name__ == '__main__':
test_fusion_non_strict_dataflow_implicit_dependency()
test_fusion_strict_dataflow_pointwise()
test_fusion_strict_dataflow_not_pointwise()
test_fusion_dataflow_intermediate()
Expand Down

0 comments on commit fa21bd3

Please sign in to comment.