From a3842f97ebba4a0bec8c2783c02c129acc244524 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 17 Jan 2025 11:35:07 +0100 Subject: [PATCH] Added more tests to the map fusion and refined some others. --- tests/transformations/mapfusion_test.py | 71 +++++++++++++++++++++---- 1 file changed, 60 insertions(+), 11 deletions(-) diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index ed7e5666ca..1e7678c3aa 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -965,7 +965,11 @@ def test_fusion_dataflow_intermediate(): def test_fusion_dataflow_intermediate_2(): - # Because `A` is not also output transformation applies. + # The transformation applies for two reasons, first reading and writing `A` + # is pointwise. Furthermore, there is no further access to `A` after the + # intermediate node. Note that if the second map would also have an output + # that refers to `A` then the transformation would not apply regardless + # of the strict dataflow mode. sdfg, state = _make_strict_dataflow_sdfg_pointwise( input_data="A", intermediate_data="A", @@ -977,6 +981,21 @@ def test_fusion_dataflow_intermediate_2(): assert {"A", "O"} == {edge.dst.data for edge in state.out_edges(map_exit) if isinstance(edge.dst, nodes.AccessNode)} +def test_fusion_dataflow_intermediate_3(): + # This is exactly the same situation as in `test_fusion_dataflow_intermediate_2()` + # with the exception that now the access to `A` is no longer pointwise, thus + # the transformation does not apply. Note that this SDFG is wrong, it is only + # here to show that the case is detected. + sdfg, state = _make_strict_dataflow_sdfg_pointwise( + input_data="A", + intermediate_data="A", + output_data="O", + input_read="9 - __i0", + output_write="__i0", + ) + apply_fusion(sdfg, removed_maps=0, strict_dataflow=True) + + def test_fusion_dataflow_intermediate_downstream(): # Because the intermediate `T` is used downstream again, # the transformation can not apply. @@ -1001,10 +1020,18 @@ def test_fusion_dataflow_intermediate_downstream(): input_nodes={output_1}, external_edges=True, ) + + # Make another state where `T` is written back, such that it is not dead data flow. + state2 = sdfg.add_state_after(state) + sdfg.add_datadesc("output_2", sdfg.arrays["output_1"].clone()) + state2.add_nedge( + state2.add_access("T"), + state2.add_access("output_2"), + sdfg.make_array_memlet("T"), + ) sdfg.validate() apply_fusion(sdfg, removed_maps=0, strict_dataflow=True) - sdfg.view() # However without strict dataflow, the merge is possible. apply_fusion(sdfg, removed_maps=1, strict_dataflow=False) @@ -1238,8 +1265,10 @@ def test_inner_map_dependency_resolved(): apply_fusion(sdfg, removed_maps=1, final_maps=1) -def _impl_fusion_intermediate_different_access(modified_shape: bool): - +def _impl_fusion_intermediate_different_access( + modified_shape: bool, + traditional_memlet_direction: bool +): def ref(A, B): T = np.zeros((A.shape[0] + 1, 2)) for i in range(A.shape[0]): @@ -1338,14 +1367,26 @@ def ref(A, B): ) ) + temp_subset = ( + "0, 0:2" + if modified_shape + else "0:2" + ) + T_subset = "__i0 + 1, 0:2" + + if traditional_memlet_direction: + mem_data = "T" + mem_subset = T_subset + mem_other_subset = temp_subset + else: + mem_data = "temp" + mem_subset = temp_subset + mem_other_subset = T_subset + state.add_edge( temp, None, mx1, "IN_temp", - dace.Memlet( - "temp[0, 0:2] -> [__i0 + 1, 0:2]" - if modified_shape - else "temp[0:2] -> [__i0 + 1, 0:2]" - ) + dace.Memlet(f"{mem_data}[{mem_subset}] -> [{mem_other_subset}]") ) state.add_edge( mx1, "OUT_temp", @@ -1384,11 +1425,19 @@ def ref(A, B): def test_fusion_intermediate_different_access(): - _impl_fusion_intermediate_different_access(modified_shape=False) + _impl_fusion_intermediate_different_access(modified_shape=False, traditional_memlet_direction=False) + + +def test_fusion_intermediate_different_access_2(): + _impl_fusion_intermediate_different_access(modified_shape=False, traditional_memlet_direction=True) def test_fusion_intermediate_different_access_mod_shape(): - _impl_fusion_intermediate_different_access(modified_shape=True) + _impl_fusion_intermediate_different_access(modified_shape=True, traditional_memlet_direction=False) + + +def test_fusion_intermediate_different_access_mod_shape_2(): + _impl_fusion_intermediate_different_access(modified_shape=True, traditional_memlet_direction=True) @pytest.mark.skip(reason="This feature is not yet fully supported.")