From 017d58ffc20786941faefce5b4e8e953faa6a688 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 22 Jan 2025 09:17:59 +0100 Subject: [PATCH] Apply liftstructviews to meta codes in control flow regions (#1885) --- dace/codegen/targets/cpp.py | 3 +- .../passes/lift_struct_views.py | 89 +++++++++++-------- tests/passes/lift_struct_views_test.py | 38 ++++++++ 3 files changed, 94 insertions(+), 36 deletions(-) diff --git a/dace/codegen/targets/cpp.py b/dace/codegen/targets/cpp.py index 89239abcb3..2e97fc5e76 100644 --- a/dace/codegen/targets/cpp.py +++ b/dace/codegen/targets/cpp.py @@ -1063,7 +1063,8 @@ def _Name(self, t: ast.Name): # Replace values with their code-generated names (for example, persistent arrays) desc = self.sdfg.arrays[t.id] - self.write(ptr(t.id, desc, self.sdfg, self.codegen)) + ref = '' if not isinstance(desc, data.View) else '*' + self.write(ref + ptr(t.id, desc, self.sdfg, self.codegen)) def _Attribute(self, t: ast.Attribute): from dace.frontend.python.astutils import rname diff --git a/dace/transformation/passes/lift_struct_views.py b/dace/transformation/passes/lift_struct_views.py index 6744161000..8962d8539e 100644 --- a/dace/transformation/passes/lift_struct_views.py +++ b/dace/transformation/passes/lift_struct_views.py @@ -1,14 +1,14 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import ast from collections import defaultdict -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Optional, Set, Tuple, Union from dace import SDFG, Memlet, SDFGState from dace.frontend.python import astutils from dace.properties import CodeBlock from dace.sdfg import nodes as nd from dace.sdfg.graph import Edge, MultiConnectorEdge -from dace.sdfg.sdfg import InterstateEdge, memlets_in_ast +from dace.sdfg.sdfg import InterstateEdge from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion from dace.transformation import pass_pipeline as ppl from dace import data as dt @@ -16,11 +16,10 @@ import sys -if sys.version_info >= (3, 8): - from typing import Literal - dirtype = Literal['in', 'out'] -else: - dirtype = "Literal['in', 'out']" +from typing import Literal + +from tests.npbench.misc.stockham_fft_test import R +dirtype = Literal['in', 'out'] class RecodeAttributeNodes(ast.NodeTransformer): @@ -193,21 +192,21 @@ def visit_Attribute(self, node: ast.Attribute) -> Any: class InterstateEdgeRecoder(ast.NodeTransformer): sdfg: SDFG - edge: Edge[InterstateEdge] + element: Union[Edge[InterstateEdge], Tuple[ControlFlowBlock, CodeBlock]] data_name: str data: Union[dt.Structure, dt.ContainerArray] views_constructed: Set[str] - isedge_lifting_state_dict: Dict[InterstateEdge, SDFGState] + _lifting_state: SDFGState - def __init__(self, sdfg: SDFG, edge: Edge[InterstateEdge], data_name: str, - data: Union[dt.Structure, dt.ContainerArray], - isedge_lifting_state_dict: Dict[InterstateEdge, SDFGState]): + def __init__(self, sdfg: SDFG, element: Union[Edge[InterstateEdge], Tuple[ControlFlowBlock, CodeBlock]], + data_name: str, data: Union[dt.Structure, dt.ContainerArray], + lifting_state: Optional[SDFGState] = None): self.sdfg = sdfg - self.edge = edge + self.element = element self.data_name = data_name self.data = data self.views_constructed = set() - self.isedge_lifting_state_dict = isedge_lifting_state_dict + self._lifting_state = lifting_state def _handle_simple_name_access(self, node: ast.Attribute) -> Any: struct: dt.Structure = self.data @@ -283,25 +282,25 @@ def _handle_sliced_access(self, node: ast.Attribute, val: ast.Subscript) -> Any: return self.generic_visit(replacement) def _get_or_create_lifting_state(self) -> Tuple[SDFGState, nd.AccessNode]: - # Add a state for lifting before the edge, if there isn't one that was created already. - if self.edge.data in self.isedge_lifting_state_dict: - lift_state = self.isedge_lifting_state_dict[self.edge.data] - else: - pre_node: ControlFlowBlock = self.edge.src - lift_state = pre_node.parent_graph.add_state_after(pre_node, self.data_name + '_lifting') - self.isedge_lifting_state_dict[self.edge.data] = lift_state + # Add a state for lifting before the access, if there isn't one that was created already. + if self._lifting_state is None: + if isinstance(self.element, Edge): + pre_node: ControlFlowBlock = self.element.src + self._lifting_state = pre_node.parent_graph.add_state_after(pre_node, self.data_name + '_lifting') + else: + self._lifting_state = self.element[0].parent_graph.add_state_before(self.element[0]) # Add a node for the original data container so the view can be connected to it. This may already be a view from # a previous iteration of lifting, but in that case it is already correctly connected to a root data container. data_node = None - for dn in lift_state.data_nodes(): + for dn in self._lifting_state.data_nodes(): if dn.data == self.data_name: data_node = dn break if data_node is None: - data_node = lift_state.add_access(self.data_name) + data_node = self._lifting_state.add_access(self.data_name) - return lift_state, data_node + return self._lifting_state, data_node def visit_Attribute(self, node: ast.Attribute) -> Any: if not node.value: @@ -360,8 +359,6 @@ class LiftStructViews(ppl.Pass): CATEGORY: str = 'Optimization Preparation' - _isedge_lifting_state_dict: Dict[InterstateEdge, SDFGState] = dict() - def modifies(self) -> ppl.Modifies: return ppl.Modifies.Descriptors | ppl.Modifies.AccessNodes | ppl.Modifies.Tasklets | ppl.Modifies.Memlets @@ -371,6 +368,26 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: def depends_on(self): return {} + def _lift_control_flow_region_access(self, cfg: ControlFlowRegion, result: Dict[str, Set[str]]) -> bool: + lifted_something = False + lifting_state = None + for code_block in cfg.get_meta_codeblocks(): + codes = code_block.code if isinstance(code_block.code, list) else [code_block.code] + for code in codes: + for data in _data_containers_in_ast(code, cfg.sdfg.arrays.keys()): + if '.' in data: + continue + container = cfg.sdfg.arrays[data] + if isinstance(container, (dt.Structure, dt.ContainerArray)): + if lifting_state is None: + lifting_state = cfg.parent_graph.add_state_before(cfg) + visitor = InterstateEdgeRecoder(cfg.sdfg, (cfg, code_block), data, container, lifting_state) + visitor.visit(code) + if visitor.views_constructed: + result[data].update(visitor.views_constructed) + lifted_something = True + return lifted_something + def _lift_isedge(self, cfg: ControlFlowRegion, edge: Edge[InterstateEdge], result: Dict[str, Set[str]]) -> bool: lifted_something = False for k in edge.data.assignments.keys(): @@ -383,12 +400,13 @@ def _lift_isedge(self, cfg: ControlFlowRegion, edge: Edge[InterstateEdge], resul continue container = cfg.sdfg.arrays[data] if isinstance(container, (dt.Structure, dt.ContainerArray)): - visitor = InterstateEdgeRecoder(cfg.sdfg, edge, data, container, self._isedge_lifting_state_dict) + visitor = InterstateEdgeRecoder(cfg.sdfg, edge, data, container) new_code = visitor.visit(assignment_ast) edge.data.assignments[k] = astutils.unparse(new_code) assignment_ast = new_code - result[data].update(visitor.views_constructed) - lifted_something = True + if visitor.views_constructed: + result[data].update(visitor.views_constructed) + lifted_something = True if not edge.data.is_unconditional(): condition_ast = edge.data.condition.code[0] data_in_edge = _data_containers_in_ast(condition_ast, cfg.sdfg.arrays.keys()) @@ -397,12 +415,13 @@ def _lift_isedge(self, cfg: ControlFlowRegion, edge: Edge[InterstateEdge], resul continue container = cfg.sdfg.arrays[data] if isinstance(container, (dt.Structure, dt.ContainerArray)): - visitor = InterstateEdgeRecoder(cfg.sdfg, edge, data, container, self._isedge_lifting_state_dict) + visitor = InterstateEdgeRecoder(cfg.sdfg, edge, data, container) new_code = visitor.visit(condition_ast) edge.data.condition = CodeBlock([new_code]) condition_ast = new_code - result[data].update(visitor.views_constructed) - lifted_something = True + if visitor.views_constructed: + result[data].update(visitor.views_constructed) + lifted_something = True return lifted_something def _lift_tasklet(self, state: SDFGState, data_node: nd.AccessNode, tasklet: nd.Tasklet, @@ -416,11 +435,9 @@ def _lift_tasklet(self, state: SDFGState, data_node: nd.AccessNode, tasklet: nd. # Perform lifting. code_list = tasklet.code.code if isinstance(tasklet.code.code, list) else [tasklet.code.code] - new_code_list = [] for code in code_list: visitor = RecodeAttributeNodes(state, data_node, connector, data, tasklet, edge.data, direction) - new_code = visitor.visit(code) - new_code_list.append(new_code) + visitor.visit(code) new_names.update(visitor.views_constructed) # Clean up by removing the lifted connector and connected edges. @@ -471,6 +488,8 @@ def apply_pass(self, sdfg: SDFG, _) -> Optional[Dict[str, Set[str]]]: lifted_something_this_round = True for edge in cfg.edges(): lifted_something_this_round |= self._lift_isedge(cfg, edge, result) + + lifted_something_this_round |= self._lift_control_flow_region_access(cfg, result) if not lifted_something_this_round: break else: diff --git a/tests/passes/lift_struct_views_test.py b/tests/passes/lift_struct_views_test.py index 71f19215b5..533fd24eac 100644 --- a/tests/passes/lift_struct_views_test.py +++ b/tests/passes/lift_struct_views_test.py @@ -1,7 +1,9 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Tests the LiftStructViews pass. """ +import numpy as np import dace +from dace.sdfg.state import LoopRegion from dace.transformation.pass_pipeline import FixedPointPipeline from dace.transformation.passes.lift_struct_views import LiftStructViews @@ -166,9 +168,45 @@ def test_sliced_multi_tasklet_access_to_cont_array(): assert sdfg.is_valid() +def test_lift_in_loop_meta_code(): + sdfg = dace.SDFG('lift_in_loop_meta_code') + bounds_struct = dace.data.Structure({ + 'start': dace.data.Scalar(dace.int32), + 'end': dace.data.Scalar(dace.int32), + 'step': dace.data.Scalar(dace.int32), + }) + sdfg.add_datadesc('bounds', bounds_struct) + sdfg.add_array('A', (20,), dace.int32) + loop = LoopRegion('loop', 'i < bounds.end', 'i', 'i = bounds.start', 'i = i + bounds.step') + sdfg.add_node(loop, is_start_block=True) + state = loop.add_state('state', is_start_block=True) + a_write = state.add_access('A') + t1 = state.add_tasklet('t1', {}, {'o1'}, 'o1 = 1') + state.add_edge(t1, 'o1', a_write, None, dace.Memlet('A[i]')) + + assert len(sdfg.nodes()) == 1 + assert len(sdfg.arrays) == 2 + assert sdfg.is_valid() + + FixedPointPipeline([LiftStructViews()]).apply_pass(sdfg, {}) + + assert len(sdfg.nodes()) == 2 + assert len(sdfg.arrays) == 5 + assert sdfg.is_valid() + + a = np.zeros((20,), np.int32) + valid = np.full((20,), 1, np.int32) + inpBounds = bounds_struct.dtype._typeclass.as_ctypes()(start=0, end=20, step=1) + func = sdfg.compile() + func(A=a, bounds=inpBounds) + + assert np.allclose(a, valid) + + if __name__ == '__main__': test_simple_tasklet_access() test_sliced_tasklet_access() test_sliced_multi_tasklet_access() test_tasklet_access_to_cont_array() test_sliced_multi_tasklet_access_to_cont_array() + test_lift_in_loop_meta_code()