Skip to content

Commit

Permalink
Preliminary implementation of horizontal domains
Browse files Browse the repository at this point in the history
This is a preliminary implementation. It mainly contains the
functionality for horizontal domains. There's still quite a bit of
cleanup left to do.
  • Loading branch information
BenWeber42 committed Jan 22, 2021
1 parent 4b38e24 commit 5a4ad62
Show file tree
Hide file tree
Showing 23 changed files with 408 additions and 143 deletions.
95 changes: 44 additions & 51 deletions dusk/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
name,
BreakPoint,
)
from dusk.passes.constant_folder import DUSK_CONSTANT_KIND
from dusk.semantics import (
Symbol,
SymbolKind,
Expand All @@ -52,7 +53,7 @@
VerticalIterationVariable,
DuskContextHelper,
)
from dusk.script import stencil as stencil_decorator
from dusk.script import internal
from dusk.script.stubs import (
LOCATION_TYPES,
UNARY_MATH_FUNCTIONS,
Expand Down Expand Up @@ -328,58 +329,37 @@ def if_stmt(self, condition: expr, body: t.List, orelse: t.List):
return make_if_stmt(condition, body, orelse)

@transform(
With(
items=FixedList(
# TODO: hardcoded strings
withitem(
context_expr=OneOf(
name(
Capture(OneOf("levels_upward", "levels_downward")).to(
"order"
),
),
Subscript(
value=name(
id=Capture(
OneOf("levels_upward", "levels_downward")
).to("order")
),
slice=Slice(
lower=Capture(_).to("lower"),
upper=Capture(_).to("upper"),
step=None,
),
ctx=Load,
BreakPoint(
With(
items=FixedList(
# TODO: hardcoded strings
withitem(
context_expr=Constant(
value=Capture(internal.Domain).to("domain"),
kind=DUSK_CONSTANT_KIND,
),
optional_vars=Optional(name(Capture(str).to("var"), ctx=Store)),
),
optional_vars=Optional(name(Capture(str).to("var"), ctx=Store)),
),
body=Capture(_).to("body"),
type_comment=None,
),
body=Capture(_).to("body"),
type_comment=None,
),
active=False,
)
)
def vertical_loop(self, order, body, upper=None, lower=None, var: str = None):
def vertical_loop(self, domain: internal.Domain, body, var: str = None):

if lower is None:
lower_level, lower_offset = sir.Interval.Start, 0
else:
lower_level, lower_offset = self.vertical_interval_bound(lower)
# FIXME: gracefully handle if the domain can't be constant folded

if upper is None:
upper_level, upper_offset = sir.Interval.End, 0
else:
upper_level, upper_offset = self.vertical_interval_bound(upper)
if not domain.valid():
raise SemanticError("Invalid domain!")

order_mapper = {
"levels_upward": sir.VerticalRegion.Forward,
"levels_downward": sir.VerticalRegion.Backward,
}
with self.ctx.vertical_region(var):
return make_vertical_region_decl_stmt(
make_ast(self.statements(body)),
make_interval(lower_level, upper_level, lower_offset, upper_offset),
order_mapper[order],
ast=make_ast(self.statements(body)),
interval=domain.vertical_domain.to_sir(),
loop_order=domain.vertical_direction.to_sir(),
IRange=domain.horizontal_domain.to_sir(),
)

# TODO: richer vertical interval bounds
Expand All @@ -399,7 +379,7 @@ def vertical_interval_bound(self, bound):
@transform(
With(
items=FixedList(
# TODO: bad hardcoded string `neighbors`
# TODO: bad hardcoded string `sparse`
withitem(
context_expr=Subscript(
value=name(id="sparse"),
Expand All @@ -421,12 +401,12 @@ def loop_stmt(self, neighborhood, body: t.List):

return make_loop_stmt(body, neighborhood, include_center)

@transform(Capture(expr).to("expr"))
def expression(self, expr: expr):
@transform(Capture(OneOf(bool, int, float, expr)).to("expr"))
def expression(self, expr: t.Union[bool, int, float, expr]):
return make_expr(
dispatch(
{
Constant: self.constant,
OneOf(bool, int, float, Constant): self.constant,
Name: self.var,
Subscript: self.subscript,
UnaryOp: self.unop,
Expand All @@ -440,8 +420,13 @@ def expression(self, expr: expr):
)
)

@transform(Constant(value=Capture(_).to("value"), kind=None))
def constant(self, value):
@transform(
OneOf(
Constant(value=Capture(_).to("value"), kind=Optional(str)),
Capture(OneOf(bool, int, float)).to("value"),
)
)
def constant(self, value: Any):
# TODO: properly distinguish between float and double
built_in_type_map = {bool: "Boolean", int: "Integer", float: "Double"}

Expand Down Expand Up @@ -887,9 +872,17 @@ def reduction(

weights = None
if "weights" in kwargs:
# TODO: check for `kwargs["weight"].ctx == Load`?
weights = [self.expression(weight) for weight in kwargs["weights"].elts]
weights = self.list_of_expressions(kwargs["weights"])

return make_reduction_over_neighbor_expr(
op, expr, init, neighborhood, weights, include_center
)

@transform(
OneOf(
List(elts=Capture(list).to("exprs"), ctx=AnyContext),
Constant(value=Capture(list).to("exprs"), kind=DUSK_CONSTANT_KIND),
)
)
def list_of_expressions(self, exprs: list):
return [self.expression(expr) for expr in exprs]
3 changes: 2 additions & 1 deletion dusk/integration.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Callable, List
from typing import Any, Optional, Callable, List

from importlib.util import spec_from_file_location, module_from_spec
import ast
Expand All @@ -9,6 +9,7 @@
class StencilObject:
callable: Callable
filename: str
stencil_scope: Any
pyast: Optional[ast.FunctionDef] = None
globals: Optional[sir.GlobalVariableMap]
sir_node: Optional[sir.SIR] = None
Expand Down
Empty file added dusk/passes/__init__.py
Empty file.
165 changes: 165 additions & 0 deletions dusk/passes/constant_folder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
from typing import Any, Union, Tuple, Optional, Callable, Iterator

import ast

from dusk import integration, errors
from dusk.script import internal


DUSK_CONSTANT_KIND = "__dusk_constant_kind__"


def constant_fold(stencil_object: integration.StencilObject) -> None:

inline_compiletime_constants(stencil_object)
constant_fold_expr(stencil_object)


def inline_compiletime_constants(stencil_object: integration.StencilObject) -> None:

for node, setter in post_order_mut_iter(stencil_object.pyast.body):
# FIXME: is this in invariant after symbol resolution?
# or should we do `hasattr(node, "decl")` instead?
if (
isinstance(node, ast.Name)
and isinstance(node.ctx, ast.Load)
and isinstance(node.decl, internal.CompileTimeConstant)
):
constant = ast.Constant(value=node.decl, kind=DUSK_CONSTANT_KIND)
ast.copy_location(constant, node)
setter(constant)


def constant_fold_expr(stencil_object: integration.StencilObject) -> None:
for node, setter in post_order_mut_iter(stencil_object.pyast.body):
if expr_is_constant_foldable(node):
setter(evaluate_constant_foldable(node))


def expr_is_constant_foldable(node: ast.AST):
# This deliberately doesn't work for nested expressions

if not isinstance(node, ast.expr):
return False

if isinstance(node, (ast.Constant, ast.Name)):
return False

for field in node._fields:
child = getattr(node, field)

if isinstance(child, ast.AST):
if not is_constant_or_childless(child):
return False

if isinstance(child, list) and not all(
is_constant_or_childless(subchild) for subchild in child
):
return False

return True


def is_constant_or_childless(node: ast.AST):
return 0 == len(node._fields) or isinstance(node, ast.Constant)


def evaluate_constant_foldable(node: ast.AST) -> Any:
# TODO: replace `ast.Constant(value=$value, kind="dusk")` with
# `ast.Name(id=$name, ctx=ast.Load())` and add
# `$name: $value` to the locals dict ($name is a fresh variable name)
# should probably try to clone the expression?

var_counter = 0

def make_fresh_name():
nonlocal var_counter
var_counter += 1
return f"var{var_counter}"

locals = {}

def make_local_var_node(node: ast.Constant):
fresh_name = make_fresh_name()
assert fresh_name not in locals
locals[fresh_name] = node.value
var = ast.Name(id=fresh_name, ctx=ast.Load())
ast.copy_location(var, node)
return var

copy = type(node)()
ast.copy_location(copy, node)

for field in node._fields:
child = getattr(node, field)

if is_dusk_constant(child):
child_copy = make_local_var_node(child)

elif isinstance(child, list):
child_copy = [
subchild
if not is_dusk_constant(subchild)
else make_local_var_node(subchild)
for subchild in child
]

else:
child_copy = child

setattr(copy, field, child_copy)

copy = ast.Expression(body=copy)

# TODO: filename & location info?
# TODO: handle exceptions?
value = eval(compile(copy, mode="eval", filename="<unknown>"), {}, locals)
constant_node = ast.Constant(value=value, kind=DUSK_CONSTANT_KIND)
ast.copy_location(constant_node, node)

return constant_node


def is_dusk_constant(node: ast.AST):
return isinstance(node, ast.Constant) and node.kind == DUSK_CONSTANT_KIND


# possible strategy:
# replace `CompileTimeConstant` `ast.Name` with `ast.Constant`
# If a node only has `ast.Constant` children -> constant fold
# for `ast.Constant` that have python objects, replace with `ast.Name`
# and put the value in the locals dict

# TODO: how to do execution of `ast.Constant(..., kind="dusk")`?
# -> adding local closures


def post_order_mut_iter(
node: Union[ast.AST, list],
slot: Optional[Union[Tuple[ast.AST, str], Tuple[list, int]]] = None,
) -> Iterator[Tuple[ast.AST, Any]]:

if isinstance(node, ast.AST):
for field in node._fields:
yield from post_order_mut_iter(getattr(node, field), (node, field))

if slot is not None:
yield node, make_setter(slot)

elif isinstance(node, list):
for index, child in enumerate(node):
yield from post_order_mut_iter(child, (node, index))


def make_setter(
slot: Union[Tuple[ast.AST, str], Tuple[list, int]]
) -> Callable[[ast.AST], None]:
parent, entry = slot

def ast_setter(node: ast.AST) -> None:
setattr(parent, entry, node)

def list_setter(node: ast.AST) -> None:
parent[entry] = node

return ast_setter if isinstance(parent, ast.AST) else list_setter
3 changes: 3 additions & 0 deletions dusk/passes/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
from dusk.grammar import Grammar
from dusk.passes.symbol_resolution import resolve_symbols
from dusk.passes.resolve_globals import resolve_globals
from dusk.passes.constant_folder import constant_fold


def stencil_object_to_sir(stencil_object: StencilObject) -> sir.Stencil:

add_filename(stencil_object)
add_pyast(stencil_object)
resolve_symbols(stencil_object)
constant_fold(stencil_object)
resolve_globals(stencil_object)
add_sir(stencil_object)

Expand All @@ -32,6 +34,7 @@ def add_pyast(stencil_object: StencilObject) -> None:
source,
filename=stencil_object.filename,
type_comments=True,
feature_version=(3, 8),
)
assert isinstance(stencil_ast, ast.Module)
assert len(stencil_ast.body) == 1
Expand Down
3 changes: 3 additions & 0 deletions dusk/passes/symbol_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class SymbolResolver:
# TODO: check preconditions?
# TODO: check postconditions?

stencil_object: StencilObject
externals: DictScope[t.Any]
api_fields: DictScope[t.Any]
temp_fields: DictScope[t.Any]
Expand Down Expand Up @@ -68,6 +69,8 @@ def __init__(self, stencil_object: StencilObject):
self.temp_fields = DictScope(parent=self.api_fields)
self._current_scope = self.temp_fields

stencil_object.stencil_scope = self.temp_fields

def resolve_symbols(self):
self.stencil(self.stencil_object.pyast)

Expand Down
6 changes: 0 additions & 6 deletions dusk/script/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
"Field",
"IndexField",
"domain",
"levels_upward",
"levels_downward",
"HorizontalDomains",
"sparse",
"reduce_over",
Expand All @@ -30,10 +28,6 @@
] + __math_all__


# FIXME: remove this hack when `domain` properly works
levels_upward = levels_downward = "levels_hack"


def stencil(stencil: typing.Callable) -> typing.Callable:
integration.stencil_collection.append(integration.StencilObject(stencil))
return stencil
Expand Down
Loading

0 comments on commit 5a4ad62

Please sign in to comment.