Skip to content

Commit

Permalink
Included the changes suggested by Philip Schaad.
Browse files Browse the repository at this point in the history
This includes
- The renaming to `SingleUseData`.
- The correct scanning of InterstateEdges; previously, nested edges were not considered.
- Inclusion of the conditions, such as the conditions used by `ConditionalBlock`, in the scan; they are now handled in the same way as interstate edges.

There is also a new test for the last case.
  • Loading branch information
philip-paul-mueller committed Jan 28, 2025
1 parent d2d9f50 commit ea01d03
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 72 deletions.
64 changes: 36 additions & 28 deletions dace/transformation/passes/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,15 +335,19 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[SDFGState]]]:

@properties.make_properties
@transformation.explicit_cf_compatible
class FindExclusiveData(ppl.Pass):
class FindSingleUseData(ppl.Pass):
"""
For each SDFG find all data descriptors that are referenced in exactly one location.
This means that for every data descriptor there exists exactly one AccessNode that
refers to that data. In addition to this the following rules applies as well:
- If the data is read by at least one interstate edge it will not be classified as exclusive.
- If there is no reference to a data descriptor, i.e. it exists inside `SDFG.arrays`
but there is no AccessNode, then it is _not_ classified as exclusive.
In addition to the requirement that there exists exactly one AccessNode that
refers to a data descriptor, the following conditions have to be met as well:
- The data is not read on an interstate edge.
- The data is not accessed in the branch condition, loop condition, etc. of
control flow regions.
- There must be at least one AccessNode that refers to the data. I.e. if it exists
inside `SDFG.arrays` but there is no AccessNode, then it is _not_ included.
It is also important to note that the degree of the AccessNodes is ignored.
"""

CATEGORY: str = 'Analysis'
Expand All @@ -353,49 +357,53 @@ def modifies(self) -> ppl.Modifies:

def should_reapply(self, modified: ppl.Modifies) -> bool:
    """Tell the pipeline whether this analysis has to be computed again.

    :param modified: Flags describing what earlier passes have modified.
    :return: `True` if `modified` intersects `ppl.Modifies.AccessNodes` or
             `ppl.Modifies.CFG`, `False` otherwise.
    """
    # Join the relevant categories with `|` so that a change to either of them
    # triggers a re-run. Chaining them with `&` would only be non-empty if
    # `modified`, `AccessNodes` and `CFG` all had a bit in common.
    relevant = ppl.Modifies.AccessNodes | ppl.Modifies.CFG
    return bool(modified & relevant)

def apply_pass(self, sdfg: SDFG, _) -> Dict[SDFG, Set[str]]:
"""
:return: A dictionary mapping SDFGs to a `set` of strings containing the name
of the data descriptors that are exclusively used.
of the data descriptors that are only used once.
"""
# pschaad: Should we index on cfg or the SDFG itself.
# TODO(pschaad): Should we index on cfg or the SDFG itself.
exclusive_data: Dict[SDFG, Set[str]] = {}
for nsdfg in sdfg.all_sdfgs_recursive():
exclusive_data[nsdfg] = self._find_exclusive_data_in_sdfg(nsdfg)
exclusive_data[nsdfg] = self._find_single_use_data_in_sdfg(nsdfg)
return exclusive_data

def _find_exclusive_data_in_sdfg(self, sdfg: SDFG) -> Set[str]:
"""Scans an SDFG and computes the exclusive data for that SDFG.
def _find_single_use_data_in_sdfg(self, sdfg: SDFG) -> Set[str]:
"""Scans an SDFG and computes the data that is only used once in the SDFG.
This function only scans `sdfg` and does not go into nested ones.
The rules used to classify data descriptors are outlined above. The function
will not scan nested SDFGs.
:return: The set of data descriptors that have exclusive access.
:return: The set of data descriptors that are used once in the SDFG.
"""
# Data descriptor that are classified, up to now, as exclusive.
# We add and data that we do not know to it the first time we seen it
# and might remove it if we found another reference.
exclusive_data: Set[str] = set()
# If we encounter a data descriptor for the first time we immediately
# classify it as single use. We will undo this decision as soon as
# we learn that it is used somewhere else.
single_use_data: Set[str] = set()
previously_seen: Set[str] = set()

for state in sdfg.states():
for dnode in state.data_nodes():
data_name: str = dnode.data
if data_name in exclusive_data:
exclusive_data.discard(data_name) # Classified too early; Undo
if data_name in single_use_data:
single_use_data.discard(data_name) # Classified too early -> Undo
elif data_name not in previously_seen:
exclusive_data.add(data_name) # Never seen; Assume it is exclusive.
single_use_data.add(data_name) # Never seen -> Assume single use
previously_seen.add(data_name)

# Compute the set of all data that is accessed, i.e. read, by the edges.
interstate_read_symbols: Set[str] = set()
for edge in sdfg.edges():
interstate_read_symbols.update(edge.data.free_symbols)
# By definition, data that is referenced by interstate edges is not single
# use data, so remove it as well.
for edge in sdfg.all_interstate_edges():
single_use_data.difference_update(edge.data.free_symbols)

# By definition, data that is referenced by the conditions (branching condition,
# loop condition, ...) is not single use data, so remove that as well.
for cfr in sdfg.all_control_flow_regions():
single_use_data.difference_update(cfr.used_symbols(all_symbols=True, with_contents=False))

# Enforces the first rule, "if data is accessed by an interstate edge it will
# not be classified as exclusive".
return exclusive_data.difference(interstate_read_symbols)
return single_use_data


@properties.make_properties
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved.
from typing import Dict, Set, Tuple
import dace
from dace.transformation.passes.analysis import FindExclusiveData
from dace.transformation.passes.analysis import FindSingleUseData

def perform_scan(sdfg: dace.SDFG) -> Dict[dace.SDFG, Set[str]]:
scanner = FindExclusiveData()
scanner = FindSingleUseData()
return scanner.apply_pass(sdfg, None)


def _make_all_exclusive_data_but_one_unused_sdfg() -> dace.SDFG:
sdfg = dace.SDFG('all_exclusive_data_but_one_unused_sdfg')
def _make_all_single_use_data_but_one_unused_sdfg() -> dace.SDFG:
sdfg = dace.SDFG('all_single_use_data_but_one_unused_sdfg')
state1 = sdfg.add_state(is_start_block=True)
state2 = sdfg.add_state_after(state1)

Expand All @@ -35,18 +35,19 @@ def _make_all_exclusive_data_but_one_unused_sdfg() -> dace.SDFG:
return sdfg


def test_all_exclusive_data_but_one_unused():
sdfg = _make_all_exclusive_data_but_one_unused_sdfg()
def test_all_single_use_data_but_one_unused():
sdfg = _make_all_single_use_data_but_one_unused_sdfg()
assert len(sdfg.arrays) == 5

# Because it is not used `e` is not considered to be exclusively used.
# This is a matter of definition.
expected_exclusive_set = {aname for aname in sdfg.arrays.keys() if aname != 'e'}
# Because `e` is not used inside the SDFG, it is not included in the returned set,
# all other descriptors are included because they appear once.
expected_single_use_set = {aname for aname in sdfg.arrays.keys() if aname != 'e'}

exclusive_set = perform_scan(sdfg)
assert len(exclusive_set[sdfg]) == 4
single_use_set = perform_scan(sdfg)

assert exclusive_set[sdfg] == expected_exclusive_set
assert len(single_use_set[sdfg]) == 4
assert len(single_use_set) == 1
assert single_use_set[sdfg] == expected_single_use_set


def _make_multiple_access_same_state_sdfg() -> dace.SDFG:
Expand Down Expand Up @@ -79,12 +80,13 @@ def test_multiple_access_same_state():
sdfg = _make_multiple_access_same_state_sdfg()
assert len(sdfg.arrays) == 3

# `a` is not exclusive because there exists multiple access nodes in a single
# state for `a`.
expected_exclusive_set = {aname for aname in sdfg.arrays.keys() if aname != 'a'}
exclusive_set = perform_scan(sdfg)
assert len(exclusive_set[sdfg]) == 2
assert expected_exclusive_set == exclusive_set[sdfg]
# `a` is not single use data because there are multiple access nodes for it
# in a single state.
expected_single_use_set = {aname for aname in sdfg.arrays.keys() if aname != 'a'}
single_use_set = perform_scan(sdfg)
assert len(single_use_set) == 1
assert len(single_use_set[sdfg]) == 2
assert expected_single_use_set == single_use_set[sdfg]


def _make_multiple_single_access_node_same_state_sdfg() -> dace.SDFG:
Expand Down Expand Up @@ -119,13 +121,14 @@ def test_multiple_single_access_node_same_state_sdfg() -> dace.SDFG:
sdfg = _make_multiple_single_access_node_same_state_sdfg()
assert len(sdfg.arrays) == 3

# Unlike `test_multiple_access_same_state()` here `a` is included in the exclusive
# Unlike `test_multiple_access_same_state()` here `a` is included in the single use
# set, because, there is only a single AccessNode, that is used multiple times,
# i.e. has an output degree larger than one.
expected_exclusive_set = sdfg.arrays.keys()
exclusive_set = perform_scan(sdfg)
assert len(exclusive_set[sdfg]) == 3
assert expected_exclusive_set == exclusive_set[sdfg]
expected_single_use_set = sdfg.arrays.keys()
single_use_set = perform_scan(sdfg)
assert len(single_use_set) == 1
assert len(single_use_set[sdfg]) == 3
assert expected_single_use_set == single_use_set[sdfg]


def _make_multiple_access_different_states_sdfg() -> dace.SDFG:
Expand Down Expand Up @@ -161,11 +164,12 @@ def test_multiple_access_different_states():
sdfg = _make_multiple_access_different_states_sdfg()
assert len(sdfg.arrays) == 3

# `a` is not included in the exclusive set, because it is used in two different states.
exclusive_set = perform_scan(sdfg)
expected_exclusive_set = {aname for aname in sdfg.arrays.keys() if aname != 'a'}
assert len(exclusive_set[sdfg]) == 2
assert expected_exclusive_set == exclusive_set[sdfg]
# `a` is not included in the single use set, because it is used in two different states.
single_use_set = perform_scan(sdfg)
expected_single_use_set = {aname for aname in sdfg.arrays.keys() if aname != 'a'}
assert len(single_use_set) == 1
assert len(single_use_set[sdfg]) == 2
assert expected_single_use_set == single_use_set[sdfg]


def _make_access_only_on_interstate_edge_sdfg() -> dace.SDFG:
Expand Down Expand Up @@ -201,12 +205,13 @@ def test_access_only_on_interstate_edge():
sdfg = _make_access_only_on_interstate_edge_sdfg()
assert len(sdfg.arrays) == 5

# `e` is only accessed on the interstate edge. So it is technically an exclusive
# data. But by definition we handle this case as non exclusive.
expected_exclusive_set = {aname for aname in sdfg.arrays.keys() if aname != 'e'}
exclusive_set = perform_scan(sdfg)
assert len(exclusive_set[sdfg]) == 4
assert exclusive_set[sdfg] == expected_exclusive_set
# `e` is only accessed on the interstate edge. So it is technically single use
# data. But by definition we handle this case as not single use.
expected_single_use_set = {aname for aname in sdfg.arrays.keys() if aname != 'e'}
single_use_set = perform_scan(sdfg)
assert len(single_use_set) == 1
assert len(single_use_set[sdfg]) == 4
assert single_use_set[sdfg] == expected_single_use_set


def _make_additional_access_on_interstate_edge_sdfg() -> dace.SDFG:
Expand Down Expand Up @@ -248,11 +253,13 @@ def test_additional_access_on_interstate_edge():
sdfg = _make_additional_access_on_interstate_edge_sdfg()
assert len(sdfg.arrays) == 6

# As in `test_access_only_on_interstate_edge` `e` is not part of the exclusive set.
expected_exclusive_set = {aname for aname in sdfg.arrays.keys() if aname != 'e'}
exclusive_set = perform_scan(sdfg)
assert len(exclusive_set[sdfg]) == 5
assert exclusive_set[sdfg] == expected_exclusive_set
# There is one AccessNode for `e`, but as in `test_access_only_on_interstate_edge`
# `e` is also used on the interstate edge, so it is not included.
expected_single_use_set = {aname for aname in sdfg.arrays.keys() if aname != 'e'}
single_use_set = perform_scan(sdfg)
assert len(single_use_set) == 1
assert len(single_use_set[sdfg]) == 5
assert single_use_set[sdfg] == expected_single_use_set


def _make_access_nested_nsdfg() -> dace.SDFG:
Expand Down Expand Up @@ -319,18 +326,109 @@ def test_access_nested_sdfg():
sdfg, nested_sdfg = _make_access_nested_sdfg()
assert all(len(nsdfg.arrays) == 2 for nsdfg in [sdfg, nested_sdfg])

# In both SDFGs all data descriptors are exclusive.
expected_exclusive_set = {'a', 'b'}
exclusive_sets = perform_scan(sdfg)
# In the top and the nested SDFG `a` and `b` are both used once, so for
# both they are included in the single use set.
# Essentially tests if there is separation between the two.
expected_single_use_set = {'a', 'b'}
single_use_sets = perform_scan(sdfg)

assert all(exclusive_sets[nsdfg] == expected_exclusive_set for nsdfg in [sdfg, nested_sdfg])
assert len(single_use_sets) == 2
assert all(single_use_sets[nsdfg] == expected_single_use_set for nsdfg in [sdfg, nested_sdfg])


def _make_conditional_block_sdfg() -> dace.SDFG:
    """Build an SDFG whose `ConditionalBlock` branches on the scalar `cond2`.

    The entry state holds an edge from AccessNode `a` to AccessNode `b` and a
    tasklet that reads `cond` and writes `cond2` (`__out = not __in`). The
    conditional block that follows has two branches: the one guarded by
    `cond2` connects `b` to `c`, the one guarded by `not (cond2)` connects
    `b` to `d`.

    :return: The validated SDFG.
    """
    sdfg = dace.SDFG("conditional_block_sdfg")

    # Every descriptor is first registered as a non-transient scalar; `b` and
    # `cond2` are switched to transient afterwards.
    for name in ["a", "b", "c", "d", "cond", "cond2"]:
        scalar_type = dace.bool_ if name.startswith("cond") else dace.float64
        sdfg.add_scalar(name, dtype=scalar_type, transient=False)
    for transient_name in ("b", "cond2"):
        sdfg.arrays[transient_name].transient = True

    # Entry state: `a -> b` plus the tasklet computing `cond2` from `cond`.
    entry_state = sdfg.add_state("entry", is_start_block=True)
    a_node = entry_state.add_access("a")
    b_node = entry_state.add_access("b")
    entry_state.add_nedge(a_node, b_node, sdfg.make_array_memlet("a"))

    cond_tasklet: dace.nodes.Tasklet = entry_state.add_tasklet(
        "cond_processing",
        inputs={"__in"},
        code="__out = not __in",
        outputs={"__out"},
    )
    cond_node = entry_state.add_access("cond")
    entry_state.add_edge(cond_node, None, cond_tasklet, "__in", dace.Memlet("cond[0]"))
    cond2_node = entry_state.add_access("cond2")
    entry_state.add_edge(cond_tasklet, "__out", cond2_node, None, dace.Memlet("cond2[0]"))

    # The conditional block directly follows the entry state.
    if_region = dace.sdfg.state.ConditionalBlock("if")
    sdfg.add_node(if_region)
    sdfg.add_edge(entry_state, if_region, dace.InterstateEdge())

    # (region label, state label, destination data, memlet data, condition)
    branch_specs = [
        ("then_body", "true_branch", "c", "b", "cond2"),
        ("else_body", "false_branch", "d", "d", "not (cond2)"),
    ]
    for region_label, state_label, target, memlet_data, condition in branch_specs:
        body = dace.sdfg.state.ControlFlowRegion(region_label, sdfg=sdfg)
        branch_state = body.add_state(state_label, is_start_block=True)
        src_node = branch_state.add_access("b")
        dst_node = branch_state.add_access(target)
        branch_state.add_nedge(src_node, dst_node, sdfg.make_array_memlet(memlet_data))
        if_region.add_branch(dace.sdfg.state.CodeBlock(condition), body)

    sdfg.validate()
    return sdfg


def test_conditional_block():
    """Data used in a branch condition must not be reported as single use."""
    sdfg = _make_conditional_block_sdfg()

    # `b` is not single use data, because three AccessNodes refer to it (one in
    # the entry state and one in each branch).
    # `cond2` is not single use data either: there is exactly one AccessNode
    # for it, but it is also used in the condition expressions.
    excluded = ["b", "cond2"]
    expected_single_use_set = {a for a in sdfg.arrays.keys() if a not in excluded}

    single_use_set = perform_scan(sdfg)

    # Only one SDFG is scanned, so the result has a single entry.
    assert len(single_use_set) == 1
    assert single_use_set[sdfg] == expected_single_use_set


if __name__ == '__main__':
test_all_exclusive_data_but_one_unused()
test_all_single_use_data_but_one_unused()
test_multiple_access_same_state()
test_multiple_single_access_node_same_state_sdfg()
test_multiple_access_different_states()
test_access_only_on_interstate_edge()
test_additional_access_on_interstate_edge()
test_access_nested_sdfg()
test_conditional_block()

0 comments on commit ea01d03

Please sign in to comment.