Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Complete Transition to Control Flow Regions #1676

Merged
merged 128 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
128 commits
Select commit Hold shift + click to select a range
036e407
Add data dependency analyses
phschaad Sep 16, 2024
4aa13ed
Fix type
phschaad Sep 16, 2024
a228f34
Fix types
phschaad Sep 16, 2024
10c3b6c
Update tests
phschaad Sep 16, 2024
77ca17f
Fixes
phschaad Sep 16, 2024
4e6035d
Fix tests
phschaad Sep 17, 2024
b61a283
Add tests
phschaad Sep 17, 2024
05b1c28
Add loop lifting capabilities
phschaad Sep 18, 2024
f08d95e
Adjust loop detection to LLVM canonical semantics
phschaad Sep 19, 2024
1d90346
Test fix
phschaad Sep 19, 2024
6b5ef0c
Remove unnecessary imports
phschaad Sep 19, 2024
ba5ccdf
Merge branch 'master' into cf_block_data_deps
phschaad Sep 30, 2024
23af038
Improved loop detection
phschaad Oct 2, 2024
3fbe26b
Loop detection and lifting fixes
phschaad Oct 2, 2024
2bd5d00
Work on conditional lifting
phschaad Oct 2, 2024
dc640a9
Improve conditional block interface
phschaad Oct 7, 2024
e14e7a4
Merge branch 'cf_block_data_deps' into users/phschaad/adapt_passes
phschaad Oct 7, 2024
02d53fc
Remove files from other PR
phschaad Oct 7, 2024
338db80
Merge branch 'cf_block_data_deps' into users/phschaad/adapt_passes
phschaad Oct 7, 2024
074a990
Add back missing file
phschaad Oct 7, 2024
cd8a258
Adapt dead state elimination
phschaad Oct 7, 2024
f77655f
Adapt DeadDataflowElimination
phschaad Oct 7, 2024
49475ec
Fixes
phschaad Oct 8, 2024
094c896
Merge branch 'cf_block_data_deps' into users/phschaad/adapt_passes
phschaad Oct 8, 2024
2a74901
Bugfix
phschaad Oct 8, 2024
35cfded
Merge branch 'cf_block_data_deps' into users/phschaad/adapt_passes
phschaad Oct 8, 2024
8127900
Adapt trivial loop elimination and state elimination
phschaad Oct 8, 2024
04f41c2
gdapt passes:
phschaad Oct 8, 2024
852719d
Adapt loop to map
phschaad Oct 9, 2024
352171a
Fixes in loop to map
phschaad Oct 9, 2024
bbddc88
Finished LoopToMap
phschaad Oct 9, 2024
5ef1f46
Adapt loop unrolling
phschaad Oct 9, 2024
2d26e0f
Adapt loop peeling and unrolling
phschaad Oct 10, 2024
145c0ea
Added tests for the `_read_and_write_sets()`.
philip-paul-mueller Oct 11, 2024
38e748b
Added the fix from my MapFusion PR.
philip-paul-mueller Oct 11, 2024
3748c03
Now made `read_and_write_sets()` fully adhere to their own definition.
philip-paul-mueller Oct 11, 2024
6d83976
Update refine nested access
phschaad Oct 11, 2024
bb4e9c8
Fixes
phschaad Oct 11, 2024
bd76961
Adjust SDFG nesting
phschaad Oct 11, 2024
3ab4bf3
Updated a test for the `PruneConnectors` transformation.
philip-paul-mueller Oct 11, 2024
26e4ff0
Adapt more passes and add conditional pruning pass
phschaad Oct 11, 2024
b4feddf
Added code to `test_more_than_a_map` to ensure that the transformatio…
philip-paul-mueller Oct 11, 2024
324fa34
More fixes
phschaad Oct 11, 2024
103b4e5
Merge remote-tracking branch 'origin/master' into users/phschaad/adap…
phschaad Oct 13, 2024
e1c25b2
Merge remote-tracking branch 'spcl/master' into read-write-sets
philip-paul-mueller Oct 14, 2024
70fa3db
Added the new memlet creation syntax.
philip-paul-mueller Oct 14, 2024
b187a82
Modified some comments to make them clearer.
philip-paul-mueller Oct 14, 2024
9c6cb6c
Modified the `tests/transformations/move_loop_into_map_test.py::test_…
philip-paul-mueller Oct 14, 2024
251833f
Adapt scalar fission and symbol ssa
phschaad Oct 15, 2024
5b7bdad
Fixes
phschaad Oct 15, 2024
99225a5
Adapt reference reduction pass
phschaad Oct 15, 2024
d6c7c8b
Adapt constant propagation
phschaad Oct 16, 2024
607b098
Fix pytest arguments
phschaad Oct 16, 2024
05b3e59
Fixes
phschaad Oct 16, 2024
f5b617c
Fix invalid graph manipulation in test
phschaad Oct 16, 2024
33c9287
Fixes
phschaad Oct 16, 2024
9185897
Adapt composite fusion
phschaad Oct 16, 2024
bf7d822
Fix StencilTiling
phschaad Oct 17, 2024
78803d5
Add some region inlining
phschaad Oct 17, 2024
26a7fbc
Fixes to deepcopy
phschaad Oct 17, 2024
1862ba4
More various fixes
phschaad Oct 17, 2024
422edb5
Yet more fixes
phschaad Oct 17, 2024
a66c610
More fixes
phschaad Oct 18, 2024
b507a4b
Fixes
phschaad Oct 18, 2024
4084dfe
More fixes
phschaad Oct 18, 2024
d010620
More fixes
phschaad Oct 18, 2024
a047d37
And yet more
phschaad Oct 18, 2024
39db909
Another one
phschaad Oct 18, 2024
b5fc16f
Merge branch 'master' into read-write-sets
philip-paul-mueller Oct 22, 2024
b7fe242
Added a test to highlights the error.
philip-paul-mueller Oct 22, 2024
b546b07
I now removed the filtering inside the read and write set.
philip-paul-mueller Oct 22, 2024
8f9e72f
And more
phschaad Oct 22, 2024
6067e3b
Merge remote-tracking branch 'origin/master' into users/phschaad/adap…
phschaad Oct 22, 2024
2c4c17b
Fix inline multistate
phschaad Oct 22, 2024
ae20590
Fixed `state_test.py::test_read_and_write_set_filter`.
philip-paul-mueller Oct 23, 2024
db211fa
Fixed the `state_test.py::test_read_write_set` test.
philip-paul-mueller Oct 23, 2024
570437b
Fixed the `state_test.py::test_read_write_set_y_formation` test.
philip-paul-mueller Oct 23, 2024
6806dc1
Fix cyclic dependency
phschaad Oct 23, 2024
ab255e8
Fixes to codegen and data instrumentation
phschaad Oct 23, 2024
e97f5bc
Fix subgraph nesting
phschaad Oct 23, 2024
8de1c1e
Fixes to GPU codegen
phschaad Oct 23, 2024
cb80f0b
Fixed `move_loop_into_map_test.py::MoveLoopIntoMapTest::test_more_tha…
philip-paul-mueller Oct 23, 2024
b704a43
Fixed `prune_connectors_test.py::test_read_write_*`.
philip-paul-mueller Oct 23, 2024
f74d6e8
General improvements to some tests.
philip-paul-mueller Oct 23, 2024
e103924
Updated `refine_nested_access_test.py::test_rna_read_and_write_sets_d…
philip-paul-mueller Oct 23, 2024
56e756d
More GPU fixes
phschaad Oct 23, 2024
6e14e6d
More fixes
phschaad Oct 23, 2024
1902e3d
Merge branch 'read-write-sets' into users/phschaad/adapt_passes
phschaad Oct 23, 2024
0c65359
More bugfixes
phschaad Oct 23, 2024
ac72cb1
Fixes
phschaad Oct 23, 2024
3c4cea0
Merge remote-tracking branch 'origin/master' into users/phschaad/adap…
phschaad Oct 24, 2024
cc9e5f4
FPGA fixes
phschaad Oct 24, 2024
cfa0299
Merge remote-tracking branch 'origin/main' into users/phschaad/adapt_…
phschaad Oct 25, 2024
bc9f61e
Adapt state propagation into a pass to adapt it
phschaad Oct 29, 2024
2a26d24
Merge remote-tracking branch 'origin/main' into users/phschaad/adapt_…
phschaad Oct 29, 2024
f63f75d
Fix w/d test inlining
phschaad Oct 29, 2024
48c7cb4
Fix to block fusion
phschaad Oct 29, 2024
a0e2c59
Derped a test..
phschaad Oct 29, 2024
ca00134
Merge branch 'main' into users/phschaad/adapt_passes
phschaad Oct 31, 2024
8d03dc5
Merge branch 'main' into users/phschaad/adapt_passes
phschaad Nov 5, 2024
0828fa1
Fix inlining
phschaad Nov 5, 2024
07c822c
Merge branch 'main' into users/phschaad/adapt_passes
phschaad Nov 12, 2024
3094fa1
Update SDFV
phschaad Nov 12, 2024
38fcbaf
Endless loop in constant prop fix
phschaad Nov 12, 2024
567d307
Fix propagation
phschaad Nov 12, 2024
61ac6ee
More fixes
phschaad Nov 13, 2024
8c488de
Update gitignore
phschaad Nov 13, 2024
7e4bc3d
Fix loop symbol type inference and loop to map
phschaad Nov 13, 2024
40d4a12
Fix traversal for defined symbols
phschaad Nov 13, 2024
c61af96
Merge branch 'main' into users/phschaad/adapt_passes
phschaad Nov 25, 2024
ef595b4
Skip FV3 pipeline until adapted to V2
phschaad Dec 3, 2024
3af9a70
Merge branch 'main' into users/phschaad/adapt_passes
phschaad Dec 3, 2024
d27ec31
Fix bug introduced through merge
phschaad Dec 3, 2024
786a07e
Merge branch 'main' into users/phschaad/adapt_passes
phschaad Dec 9, 2024
72508bd
Update dace/transformation/interstate/block_fusion.py
phschaad Dec 9, 2024
ce82732
Address review comments
phschaad Dec 9, 2024
83b26c1
Merge branch 'users/phschaad/adapt_passes' of github.com:spcl/dace in…
phschaad Dec 9, 2024
3c13369
More comments
phschaad Dec 9, 2024
c4e78d7
Add doc comments
phschaad Dec 9, 2024
3a2b342
Address more comments
phschaad Dec 9, 2024
2f7d6aa
Address more review comments
phschaad Dec 10, 2024
a469286
Inlining fix
phschaad Dec 10, 2024
a8546ab
Fixes to control flow raising and codegen
phschaad Dec 11, 2024
c9d6b51
Renamed experimental_cfg_blocks to explicit_control_flow
phschaad Dec 11, 2024
e2a6466
Added more extensible meta access replacement function
phschaad Dec 11, 2024
6f82fc6
Fixes
phschaad Dec 11, 2024
34a3247
Add more API methods
phschaad Dec 11, 2024
fad3424
Address comments
phschaad Dec 12, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions .github/workflows/pyFV3-ci.yml
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
name: NASA/NOAA pyFV3 repository build test

# Temporarily disabled for main, and instead applied to a specific DaCe v1 maintenance branch (v1/maintenance). Once
# the FV3 bridge has been adapted to DaCe v1, this will need to be reverted back to apply to main.
on:
push:
branches: [ main, ci-fix ]
#branches: [ main, ci-fix ]
branches: [ v1/maintenance, ci-fix ]
pull_request:
branches: [ main, ci-fix ]
#branches: [ main, ci-fix ]
branches: [ v1/maintenance, ci-fix ]
merge_group:
branches: [ main, ci-fix ]
#branches: [ main, ci-fix ]
branches: [ v1/maintenance, ci-fix ]

defaults:
run:
Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ src.VC.VC.opendb

# DaCe
.dacecache/
# Ignore dacecache if added as a symlink
.dacecache
out.sdfg
*.out
results.log
Expand Down
67 changes: 42 additions & 25 deletions dace/codegen/control_flow.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
"""
Various classes to facilitate the code generation of structured control
flow elements (e.g., ``for``, ``if``, ``while``) from state machines in SDFGs.
Expand Down Expand Up @@ -62,8 +62,8 @@
import sympy as sp
from dace import dtypes
from dace.sdfg.analysis import cfg as cfg_analysis
from dace.sdfg.state import (BreakBlock, ConditionalBlock, ContinueBlock, ControlFlowBlock, ControlFlowRegion, LoopRegion,
ReturnBlock, SDFGState)
from dace.sdfg.state import (BreakBlock, ConditionalBlock, ContinueBlock, ControlFlowBlock, ControlFlowRegion,
LoopRegion, ReturnBlock, SDFGState)
from dace.sdfg.sdfg import SDFG, InterstateEdge
from dace.sdfg.graph import Edge
from dace.properties import CodeBlock
Expand Down Expand Up @@ -200,7 +200,10 @@ class BreakCFBlock(ControlFlow):
block: BreakBlock

def as_cpp(self, codegen, symbols) -> str:
return 'break;\n'
cfg = self.block.parent_graph
expr = '__state_{}_{}:;\n'.format(cfg.cfg_id, self.block.label)
expr += 'break;\n'
return expr

@property
def first_block(self) -> BreakBlock:
Expand All @@ -214,7 +217,10 @@ class ContinueCFBlock(ControlFlow):
block: ContinueBlock

def as_cpp(self, codegen, symbols) -> str:
return 'continue;\n'
cfg = self.block.parent_graph
expr = '__state_{}_{}:;\n'.format(cfg.cfg_id, self.block.label)
expr += 'continue;\n'
return expr

@property
def first_block(self) -> ContinueBlock:
Expand All @@ -228,7 +234,10 @@ class ReturnCFBlock(ControlFlow):
block: ReturnBlock

def as_cpp(self, codegen, symbols) -> str:
return 'return;\n'
cfg = self.block.parent_graph
expr = '__state_{}_{}:;\n'.format(cfg.cfg_id, self.block.label)
expr += 'return;\n'
return expr

@property
def first_block(self) -> ReturnBlock:
Expand Down Expand Up @@ -316,7 +325,13 @@ def as_cpp(self, codegen, symbols) -> str:
# One unconditional edge
if (len(out_edges) == 1 and out_edges[0].data.is_unconditional()):
continue
expr += f'goto __state_exit_{sdfg.cfg_id};\n'
if self.region:
expr += f'goto __state_exit_{self.region.cfg_id};\n'
else:
expr += f'goto __state_exit_{sdfg.cfg_id};\n'

if self.region and not isinstance(self.region, SDFG):
expr += f'__state_exit_{self.region.cfg_id}:;\n'

return expr

Expand Down Expand Up @@ -536,10 +551,14 @@ def as_cpp(self, codegen, symbols) -> str:
expr = ''

if self.loop.update_statement and self.loop.init_statement and self.loop.loop_variable:
init = unparse_interstate_edge(self.loop.init_statement.code[0], sdfg, codegen=codegen, symbols=symbols)
lsyms = {}
lsyms.update(symbols)
if codegen.dispatcher.defined_vars.has(self.loop.loop_variable) and not self.loop.loop_variable in lsyms:
lsyms[self.loop.loop_variable] = codegen.dispatcher.defined_vars.get(self.loop.loop_variable)[1]
init = unparse_interstate_edge(self.loop.init_statement.code[0], sdfg, codegen=codegen, symbols=lsyms)
init = init.strip(';')

update = unparse_interstate_edge(self.loop.update_statement.code[0], sdfg, codegen=codegen, symbols=symbols)
update = unparse_interstate_edge(self.loop.update_statement.code[0], sdfg, codegen=codegen, symbols=lsyms)
update = update.strip(';')

if self.loop.inverted:
Expand Down Expand Up @@ -571,6 +590,8 @@ def as_cpp(self, codegen, symbols) -> str:
expr += _clean_loop_body(self.body.as_cpp(codegen, symbols))
expr += '\n}\n'

expr += f'__state_exit_{self.loop.cfg_id}:;\n'

return expr

@property
Expand Down Expand Up @@ -1018,21 +1039,16 @@ def _structured_control_flow_traversal_with_regions(cfg: ControlFlowRegion,
start: Optional[ControlFlowBlock] = None,
stop: Optional[ControlFlowBlock] = None,
generate_children_of: Optional[ControlFlowBlock] = None,
branch_merges: Optional[Dict[ControlFlowBlock,
ControlFlowBlock]] = None,
ptree: Optional[Dict[ControlFlowBlock, ControlFlowBlock]] = None,
visited: Optional[Set[ControlFlowBlock]] = None):
if branch_merges is None:
branch_merges = cfg_analysis.branch_merges(cfg)

if ptree is None:
ptree = cfg_analysis.block_parent_tree(cfg, with_loops=False)

start = start if start is not None else cfg.start_block

def make_empty_block():
def make_empty_block(region):
return GeneralBlock(dispatch_state, parent_block,
last_block=False, region=None, elements=[], gotos_to_ignore=[],
last_block=False, region=region, elements=[], gotos_to_ignore=[],
gotos_to_break=[], gotos_to_continue=[], assignments_to_ignore=[], sequential=True)

# Traverse states in custom order
Expand All @@ -1059,18 +1075,18 @@ def make_empty_block():
cfg_block = GeneralConditionalScope(dispatch_state, parent_block, False, node, [])
for cond, branch in node.branches:
if branch is not None:
body = make_empty_block()
body = make_empty_block(branch)
body.parent = cfg_block
_structured_control_flow_traversal_with_regions(branch, dispatch_state, body)
cfg_block.branch_bodies.append((cond, body))
elif isinstance(node, ControlFlowRegion):
if isinstance(node, LoopRegion):
body = make_empty_block()
body = make_empty_block(node)
cfg_block = GeneralLoopScope(dispatch_state, parent_block, False, node, body)
body.parent = cfg_block
_structured_control_flow_traversal_with_regions(node, dispatch_state, body)
else:
cfg_block = make_empty_block()
cfg_block = make_empty_block(node)
cfg_block.region = node
_structured_control_flow_traversal_with_regions(node, dispatch_state, cfg_block)

Expand All @@ -1095,13 +1111,14 @@ def make_empty_block():
return visited - {stop}


def structured_control_flow_tree_with_regions(sdfg: SDFG, dispatch_state: Callable[[SDFGState], str]) -> ControlFlow:
def structured_control_flow_tree_with_regions(cfg: ControlFlowRegion,
dispatch_state: Callable[[SDFGState], str]) -> ControlFlow:
"""
Returns a structured control-flow tree (i.e., with constructs such as branches and loops) from an SDFG based on the
Returns a structured control-flow tree (i.e., with constructs such as branches and loops) from a CFG based on the
control flow regions it contains.
:param sdfg: The SDFG to iterate over.
:return: Control-flow block representing the entire SDFG.
:param cfg: The graph to iterate over.
:return: Control-flow block representing the entire graph.
"""
root_block = GeneralBlock(dispatch_state=dispatch_state,
parent=None,
Expand All @@ -1113,7 +1130,7 @@ def structured_control_flow_tree_with_regions(sdfg: SDFG, dispatch_state: Callab
gotos_to_break=[],
assignments_to_ignore=[],
sequential=True)
_structured_control_flow_traversal_with_regions(sdfg, dispatch_state, root_block)
_structured_control_flow_traversal_with_regions(cfg, dispatch_state, root_block)
_reset_block_parents(root_block)
return root_block

Expand All @@ -1127,7 +1144,7 @@ def structured_control_flow_tree(sdfg: SDFG, dispatch_state: Callable[[SDFGState
:param sdfg: The SDFG to iterate over.
:return: Control-flow block representing the entire SDFG.
"""
if sdfg.root_sdfg.using_experimental_blocks:
if sdfg.root_sdfg.using_explicit_control_flow:
return structured_control_flow_tree_with_regions(sdfg, dispatch_state)

# Avoid import loops
Expand Down
66 changes: 34 additions & 32 deletions dace/codegen/instrumentation/data/data_dump.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved.
from dace import config, data as dt, dtypes, registry, SDFG
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
from dace import data as dt, dtypes, registry, SDFG
from dace.sdfg import nodes, is_devicelevel_gpu
from dace.codegen.prettycode import CodeIOStream
from dace.codegen.instrumentation.provider import InstrumentationProvider
from dace.sdfg.scope import is_devicelevel_fpga
from dace.sdfg.state import SDFGState
from dace.sdfg.state import ControlFlowRegion, SDFGState
from dace.codegen import common
from dace.codegen import cppunparse
from dace.codegen.targets import cpp
Expand Down Expand Up @@ -101,7 +101,8 @@ def on_sdfg_end(self, sdfg: SDFG, local_stream: CodeIOStream, global_stream: Cod
if sdfg.parent is None:
sdfg.append_exit_code('delete __state->serializer;\n')

def on_state_begin(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStream, global_stream: CodeIOStream):
def on_state_begin(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, local_stream: CodeIOStream,
global_stream: CodeIOStream):
if state.symbol_instrument == dtypes.DataInstrumentationType.No_Instrumentation:
return

Expand All @@ -119,17 +120,17 @@ def on_state_begin(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStrea
condition_preamble = f'if ({cond_string})' + ' {'
condition_postamble = '}'

state_id = sdfg.node_id(state)
local_stream.write(condition_preamble, sdfg, state_id)
state_id = cfg.node_id(state)
local_stream.write(condition_preamble, cfg, state_id)
defined_symbols = state.defined_symbols()
for sym, _ in defined_symbols.items():
local_stream.write(
f'__state->serializer->save_symbol("{sym}", "{state_id}", {cpp.sym2cpp(sym)});\n', sdfg, state_id
f'__state->serializer->save_symbol("{sym}", "{state_id}", {cpp.sym2cpp(sym)});\n', cfg, state_id
)
local_stream.write(condition_postamble, sdfg, state_id)
local_stream.write(condition_postamble, cfg, state_id)

def on_node_end(self, sdfg: SDFG, state: SDFGState, node: nodes.AccessNode, outer_stream: CodeIOStream,
inner_stream: CodeIOStream, global_stream: CodeIOStream):
def on_node_end(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node: nodes.AccessNode,
outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream):
from dace.codegen.dispatcher import DefinedType # Avoid import loop

if is_devicelevel_gpu(sdfg, state, node) or is_devicelevel_fpga(sdfg, state, node):
Expand Down Expand Up @@ -159,9 +160,9 @@ def on_node_end(self, sdfg: SDFG, state: SDFGState, node: nodes.AccessNode, oute
ptrname = '&' + ptrname

# Create UUID
state_id = sdfg.node_id(state)
state_id = cfg.node_id(state)
node_id = state.node_id(node)
uuid = f'{sdfg.cfg_id}_{state_id}_{node_id}'
uuid = f'{cfg.cfg_id}_{state_id}_{node_id}'

# Get optional pre/postamble for instrumenting device data
preamble, postamble = '', ''
Expand All @@ -174,13 +175,13 @@ def on_node_end(self, sdfg: SDFG, state: SDFGState, node: nodes.AccessNode, oute
strides = ', '.join(cpp.sym2cpp(s) for s in desc.strides)

# Write code
inner_stream.write(condition_preamble, sdfg, state_id, node_id)
inner_stream.write(preamble, sdfg, state_id, node_id)
inner_stream.write(condition_preamble, cfg, state_id, node_id)
inner_stream.write(preamble, cfg, state_id, node_id)
inner_stream.write(
f'__state->serializer->save({ptrname}, {cpp.sym2cpp(desc.total_size - desc.start_offset)}, '
f'"{node.data}", "{uuid}", {shape}, {strides});\n', sdfg, state_id, node_id)
inner_stream.write(postamble, sdfg, state_id, node_id)
inner_stream.write(condition_postamble, sdfg, state_id, node_id)
f'"{node.data}", "{uuid}", {shape}, {strides});\n', cfg, state_id, node_id)
inner_stream.write(postamble, cfg, state_id, node_id)
inner_stream.write(condition_postamble, cfg, state_id, node_id)


@registry.autoregister_params(type=dtypes.DataInstrumentationType.Restore)
Expand Down Expand Up @@ -216,7 +217,8 @@ def on_sdfg_end(self, sdfg: SDFG, local_stream: CodeIOStream, global_stream: Cod
if sdfg.parent is None:
sdfg.append_exit_code('delete __state->serializer;\n')

def on_state_begin(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStream, global_stream: CodeIOStream):
def on_state_begin(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, local_stream: CodeIOStream,
global_stream: CodeIOStream):
if state.symbol_instrument == dtypes.DataInstrumentationType.No_Instrumentation:
return

Expand All @@ -234,18 +236,18 @@ def on_state_begin(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStrea
condition_preamble = f'if ({cond_string})' + ' {'
condition_postamble = '}'

state_id = sdfg.node_id(state)
local_stream.write(condition_preamble, sdfg, state_id)
state_id = state.block_id
local_stream.write(condition_preamble, cfg, state_id)
defined_symbols = state.defined_symbols()
for sym, sym_type in defined_symbols.items():
local_stream.write(
f'{cpp.sym2cpp(sym)} = __state->serializer->restore_symbol<{sym_type.ctype}>("{sym}", "{state_id}");\n',
sdfg, state_id
cfg, state_id
)
local_stream.write(condition_postamble, sdfg, state_id)
local_stream.write(condition_postamble, cfg, state_id)

def on_node_begin(self, sdfg: SDFG, state: SDFGState, node: nodes.AccessNode, outer_stream: CodeIOStream,
inner_stream: CodeIOStream, global_stream: CodeIOStream):
def on_node_begin(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node: nodes.AccessNode,
outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream):
from dace.codegen.dispatcher import DefinedType # Avoid import loop

if is_devicelevel_gpu(sdfg, state, node) or is_devicelevel_fpga(sdfg, state, node):
Expand Down Expand Up @@ -275,21 +277,21 @@ def on_node_begin(self, sdfg: SDFG, state: SDFGState, node: nodes.AccessNode, ou
ptrname = '&' + ptrname

# Create UUID
state_id = sdfg.node_id(state)
state_id = cfg.node_id(state)
node_id = state.node_id(node)
uuid = f'{sdfg.cfg_id}_{state_id}_{node_id}'
uuid = f'{cfg.cfg_id}_{state_id}_{node_id}'

# Get optional pre/postamble for instrumenting device data
preamble, postamble = '', ''
if desc.storage == dtypes.StorageType.GPU_Global:
self._setup_gpu_runtime(sdfg, global_stream)
self._setup_gpu_runtime(cfg, global_stream)
preamble, postamble, ptrname = self._generate_copy_to_device(node, desc, ptrname)

# Write code
inner_stream.write(condition_preamble, sdfg, state_id, node_id)
inner_stream.write(preamble, sdfg, state_id, node_id)
inner_stream.write(condition_preamble, cfg, state_id, node_id)
inner_stream.write(preamble, cfg, state_id, node_id)
inner_stream.write(
f'__state->serializer->restore({ptrname}, {cpp.sym2cpp(desc.total_size - desc.start_offset)}, '
f'"{node.data}", "{uuid}");\n', sdfg, state_id, node_id)
inner_stream.write(postamble, sdfg, state_id, node_id)
inner_stream.write(condition_postamble, sdfg, state_id, node_id)
f'"{node.data}", "{uuid}");\n', cfg, state_id, node_id)
inner_stream.write(postamble, cfg, state_id, node_id)
inner_stream.write(condition_postamble, cfg, state_id, node_id)
Loading
Loading