Skip to content

Commit

Permalink
Apply liftstructviews to meta codes in control flow regions (#1885)
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad authored Jan 22, 2025
1 parent 1c2d7b5 commit 017d58f
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 36 deletions.
3 changes: 2 additions & 1 deletion dace/codegen/targets/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1063,7 +1063,8 @@ def _Name(self, t: ast.Name):

# Replace values with their code-generated names (for example, persistent arrays)
desc = self.sdfg.arrays[t.id]
self.write(ptr(t.id, desc, self.sdfg, self.codegen))
ref = '' if not isinstance(desc, data.View) else '*'
self.write(ref + ptr(t.id, desc, self.sdfg, self.codegen))

def _Attribute(self, t: ast.Attribute):
from dace.frontend.python.astutils import rname
Expand Down
89 changes: 54 additions & 35 deletions dace/transformation/passes/lift_struct_views.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,25 @@
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
import ast
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Dict, Optional, Set, Tuple, Union

from dace import SDFG, Memlet, SDFGState
from dace.frontend.python import astutils
from dace.properties import CodeBlock
from dace.sdfg import nodes as nd
from dace.sdfg.graph import Edge, MultiConnectorEdge
from dace.sdfg.sdfg import InterstateEdge, memlets_in_ast
from dace.sdfg.sdfg import InterstateEdge
from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion
from dace.transformation import pass_pipeline as ppl
from dace import data as dt
from dace import dtypes


import sys
if sys.version_info >= (3, 8):
from typing import Literal
dirtype = Literal['in', 'out']
else:
dirtype = "Literal['in', 'out']"
from typing import Literal

from tests.npbench.misc.stockham_fft_test import R
dirtype = Literal['in', 'out']

class RecodeAttributeNodes(ast.NodeTransformer):

Expand Down Expand Up @@ -193,21 +192,21 @@ def visit_Attribute(self, node: ast.Attribute) -> Any:
class InterstateEdgeRecoder(ast.NodeTransformer):

sdfg: SDFG
edge: Edge[InterstateEdge]
element: Union[Edge[InterstateEdge], Tuple[ControlFlowBlock, CodeBlock]]
data_name: str
data: Union[dt.Structure, dt.ContainerArray]
views_constructed: Set[str]
isedge_lifting_state_dict: Dict[InterstateEdge, SDFGState]
_lifting_state: SDFGState

def __init__(self, sdfg: SDFG, edge: Edge[InterstateEdge], data_name: str,
data: Union[dt.Structure, dt.ContainerArray],
isedge_lifting_state_dict: Dict[InterstateEdge, SDFGState]):
def __init__(self, sdfg: SDFG, element: Union[Edge[InterstateEdge], Tuple[ControlFlowBlock, CodeBlock]],
data_name: str, data: Union[dt.Structure, dt.ContainerArray],
lifting_state: Optional[SDFGState] = None):
self.sdfg = sdfg
self.edge = edge
self.element = element
self.data_name = data_name
self.data = data
self.views_constructed = set()
self.isedge_lifting_state_dict = isedge_lifting_state_dict
self._lifting_state = lifting_state

def _handle_simple_name_access(self, node: ast.Attribute) -> Any:
struct: dt.Structure = self.data
Expand Down Expand Up @@ -283,25 +282,25 @@ def _handle_sliced_access(self, node: ast.Attribute, val: ast.Subscript) -> Any:
return self.generic_visit(replacement)

def _get_or_create_lifting_state(self) -> Tuple[SDFGState, nd.AccessNode]:
# Add a state for lifting before the edge, if there isn't one that was created already.
if self.edge.data in self.isedge_lifting_state_dict:
lift_state = self.isedge_lifting_state_dict[self.edge.data]
else:
pre_node: ControlFlowBlock = self.edge.src
lift_state = pre_node.parent_graph.add_state_after(pre_node, self.data_name + '_lifting')
self.isedge_lifting_state_dict[self.edge.data] = lift_state
# Add a state for lifting before the access, if there isn't one that was created already.
if self._lifting_state is None:
if isinstance(self.element, Edge):
pre_node: ControlFlowBlock = self.element.src
self._lifting_state = pre_node.parent_graph.add_state_after(pre_node, self.data_name + '_lifting')
else:
self._lifting_state = self.element[0].parent_graph.add_state_before(self.element[0])

# Add a node for the original data container so the view can be connected to it. This may already be a view from
# a previous iteration of lifting, but in that case it is already correctly connected to a root data container.
data_node = None
for dn in lift_state.data_nodes():
for dn in self._lifting_state.data_nodes():
if dn.data == self.data_name:
data_node = dn
break
if data_node is None:
data_node = lift_state.add_access(self.data_name)
data_node = self._lifting_state.add_access(self.data_name)

return lift_state, data_node
return self._lifting_state, data_node

def visit_Attribute(self, node: ast.Attribute) -> Any:
if not node.value:
Expand Down Expand Up @@ -360,8 +359,6 @@ class LiftStructViews(ppl.Pass):

CATEGORY: str = 'Optimization Preparation'

_isedge_lifting_state_dict: Dict[InterstateEdge, SDFGState] = dict()

def modifies(self) -> ppl.Modifies:
return ppl.Modifies.Descriptors | ppl.Modifies.AccessNodes | ppl.Modifies.Tasklets | ppl.Modifies.Memlets

Expand All @@ -371,6 +368,26 @@ def should_reapply(self, modified: ppl.Modifies) -> bool:
def depends_on(self):
return {}

def _lift_control_flow_region_access(self, cfg: ControlFlowRegion, result: Dict[str, Set[str]]) -> bool:
lifted_something = False
lifting_state = None
for code_block in cfg.get_meta_codeblocks():
codes = code_block.code if isinstance(code_block.code, list) else [code_block.code]
for code in codes:
for data in _data_containers_in_ast(code, cfg.sdfg.arrays.keys()):
if '.' in data:
continue
container = cfg.sdfg.arrays[data]
if isinstance(container, (dt.Structure, dt.ContainerArray)):
if lifting_state is None:
lifting_state = cfg.parent_graph.add_state_before(cfg)
visitor = InterstateEdgeRecoder(cfg.sdfg, (cfg, code_block), data, container, lifting_state)
visitor.visit(code)
if visitor.views_constructed:
result[data].update(visitor.views_constructed)
lifted_something = True
return lifted_something

def _lift_isedge(self, cfg: ControlFlowRegion, edge: Edge[InterstateEdge], result: Dict[str, Set[str]]) -> bool:
lifted_something = False
for k in edge.data.assignments.keys():
Expand All @@ -383,12 +400,13 @@ def _lift_isedge(self, cfg: ControlFlowRegion, edge: Edge[InterstateEdge], resul
continue
container = cfg.sdfg.arrays[data]
if isinstance(container, (dt.Structure, dt.ContainerArray)):
visitor = InterstateEdgeRecoder(cfg.sdfg, edge, data, container, self._isedge_lifting_state_dict)
visitor = InterstateEdgeRecoder(cfg.sdfg, edge, data, container)
new_code = visitor.visit(assignment_ast)
edge.data.assignments[k] = astutils.unparse(new_code)
assignment_ast = new_code
result[data].update(visitor.views_constructed)
lifted_something = True
if visitor.views_constructed:
result[data].update(visitor.views_constructed)
lifted_something = True
if not edge.data.is_unconditional():
condition_ast = edge.data.condition.code[0]
data_in_edge = _data_containers_in_ast(condition_ast, cfg.sdfg.arrays.keys())
Expand All @@ -397,12 +415,13 @@ def _lift_isedge(self, cfg: ControlFlowRegion, edge: Edge[InterstateEdge], resul
continue
container = cfg.sdfg.arrays[data]
if isinstance(container, (dt.Structure, dt.ContainerArray)):
visitor = InterstateEdgeRecoder(cfg.sdfg, edge, data, container, self._isedge_lifting_state_dict)
visitor = InterstateEdgeRecoder(cfg.sdfg, edge, data, container)
new_code = visitor.visit(condition_ast)
edge.data.condition = CodeBlock([new_code])
condition_ast = new_code
result[data].update(visitor.views_constructed)
lifted_something = True
if visitor.views_constructed:
result[data].update(visitor.views_constructed)
lifted_something = True
return lifted_something

def _lift_tasklet(self, state: SDFGState, data_node: nd.AccessNode, tasklet: nd.Tasklet,
Expand All @@ -416,11 +435,9 @@ def _lift_tasklet(self, state: SDFGState, data_node: nd.AccessNode, tasklet: nd.

# Perform lifting.
code_list = tasklet.code.code if isinstance(tasklet.code.code, list) else [tasklet.code.code]
new_code_list = []
for code in code_list:
visitor = RecodeAttributeNodes(state, data_node, connector, data, tasklet, edge.data, direction)
new_code = visitor.visit(code)
new_code_list.append(new_code)
visitor.visit(code)
new_names.update(visitor.views_constructed)

# Clean up by removing the lifted connector and connected edges.
Expand Down Expand Up @@ -471,6 +488,8 @@ def apply_pass(self, sdfg: SDFG, _) -> Optional[Dict[str, Set[str]]]:
lifted_something_this_round = True
for edge in cfg.edges():
lifted_something_this_round |= self._lift_isedge(cfg, edge, result)

lifted_something_this_round |= self._lift_control_flow_region_access(cfg, result)
if not lifted_something_this_round:
break
else:
Expand Down
38 changes: 38 additions & 0 deletions tests/passes/lift_struct_views_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
""" Tests the LiftStructViews pass. """

import numpy as np
import dace
from dace.sdfg.state import LoopRegion
from dace.transformation.pass_pipeline import FixedPointPipeline
from dace.transformation.passes.lift_struct_views import LiftStructViews

Expand Down Expand Up @@ -166,9 +168,45 @@ def test_sliced_multi_tasklet_access_to_cont_array():
assert sdfg.is_valid()


def test_lift_in_loop_meta_code():
sdfg = dace.SDFG('lift_in_loop_meta_code')
bounds_struct = dace.data.Structure({
'start': dace.data.Scalar(dace.int32),
'end': dace.data.Scalar(dace.int32),
'step': dace.data.Scalar(dace.int32),
})
sdfg.add_datadesc('bounds', bounds_struct)
sdfg.add_array('A', (20,), dace.int32)
loop = LoopRegion('loop', 'i < bounds.end', 'i', 'i = bounds.start', 'i = i + bounds.step')
sdfg.add_node(loop, is_start_block=True)
state = loop.add_state('state', is_start_block=True)
a_write = state.add_access('A')
t1 = state.add_tasklet('t1', {}, {'o1'}, 'o1 = 1')
state.add_edge(t1, 'o1', a_write, None, dace.Memlet('A[i]'))

assert len(sdfg.nodes()) == 1
assert len(sdfg.arrays) == 2
assert sdfg.is_valid()

FixedPointPipeline([LiftStructViews()]).apply_pass(sdfg, {})

assert len(sdfg.nodes()) == 2
assert len(sdfg.arrays) == 5
assert sdfg.is_valid()

a = np.zeros((20,), np.int32)
valid = np.full((20,), 1, np.int32)
inpBounds = bounds_struct.dtype._typeclass.as_ctypes()(start=0, end=20, step=1)
func = sdfg.compile()
func(A=a, bounds=inpBounds)

assert np.allclose(a, valid)


if __name__ == '__main__':
test_simple_tasklet_access()
test_sliced_tasklet_access()
test_sliced_multi_tasklet_access()
test_tasklet_access_to_cont_array()
test_sliced_multi_tasklet_access_to_cont_array()
test_lift_in_loop_meta_code()

0 comments on commit 017d58f

Please sign in to comment.