Skip to content

Commit

Permalink
Added symbol resolution with externals support
Browse files Browse the repository at this point in the history
  • Loading branch information
BenWeber42 committed Jan 18, 2021
1 parent a5eb74c commit 77a07b7
Show file tree
Hide file tree
Showing 6 changed files with 270 additions and 17 deletions.
10 changes: 5 additions & 5 deletions dusk/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
Capture,
Repeat,
FixedList,
EmptyList,
name,
BreakPoint,
)
from dusk.semantics import (
Expand All @@ -61,14 +63,9 @@


# Short cuts
EmptyList = FixedList()
AnyContext = OneOf(Load, Store, Del, AugLoad, AugStore, Param)


def name(id, ctx=Load) -> Name:
return Name(id=id, ctx=ctx)


def transform(matcher) -> t.Callable:
def decorator(transformer: t.Callable) -> t.Callable:
def transformer_with_matcher(self, node, *args, **kwargs):
Expand All @@ -95,6 +92,9 @@ class Grammar:
def __init__(self):
self.ctx = DuskContextHelper()

# TODO: somewhere we should check that the function is a valid stencil
# e.g., no kwargs etc, mostly `Grammar.stencil` without capturing/processing

@transform(
FunctionDef(
name=Capture(str).to("name"),
Expand Down
17 changes: 13 additions & 4 deletions dusk/match.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from ast import AST, stmt, expr
import ast

from dusk.errors import ASTError
from dusk.util import pprint_matcher as pprint
Expand All @@ -16,8 +16,10 @@
"Optional",
"Capture",
"FixedList",
"EmptyList",
"BreakPoint",
"NoMatch",
"name",
]


Expand Down Expand Up @@ -55,6 +57,9 @@ def match(self, nodes, **kwargs):
match(matcher, node, **kwargs)


EmptyList = FixedList()


class Repeat(Matcher):
_fields = ("matcher", "n")

Expand Down Expand Up @@ -162,11 +167,15 @@ def match(self, node, **kwargs):
match(self.matcher, node, **kwargs)


def name(id, ctx=ast.Load) -> ast.Name:
return ast.Name(id=id, ctx=ctx)


def match(matcher, node, **kwargs) -> None:
# this should be probably more flexible than hardcoding all possibilities
if isinstance(matcher, Matcher):
matcher.match(node, **kwargs)
elif isinstance(matcher, AST):
elif isinstance(matcher, ast.AST):
match_ast(matcher, node, **kwargs)
elif isinstance(matcher, type):
match_type(matcher, node, **kwargs)
Expand All @@ -184,7 +193,7 @@ def does_match(matcher, node, **kwargs) -> bool:
return False


def match_ast(matcher: AST, node, **kwargs):
def match_ast(matcher: ast.AST, node, **kwargs):
if not isinstance(node, type(matcher)):
raise NoMatch(
f"Expected node type '{type(matcher)}', but got '{type(node)}'!", node
Expand All @@ -194,7 +203,7 @@ def match_ast(matcher: AST, node, **kwargs):
try:
match(getattr(matcher, field), getattr(node, field), **kwargs)
except NoMatch as e:
if e.loc is None and isinstance(node, (stmt, expr)):
if e.loc is None and isinstance(node, (ast.stmt, ast.expr)):
# add location info if possible
e.loc_from_node(node)
raise e
Expand Down
240 changes: 238 additions & 2 deletions dusk/passes/symbol_resolution.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,242 @@
from __future__ import annotations

import typing as t
from types import CellType

import builtins
from ast import *

from dusk import grammar, errors
from dusk.match import (
does_match,
Ignore as _,
Optional,
Capture,
FixedList,
EmptyList,
name,
)
from dusk.integration import StencilObject


def resolve_symbols(stencil_object: StencilObject) -> None:
# FIXME: implement
pass
SymbolResolver(stencil_object).resolve_symbols()


class SymbolResolver:

# TODO: check preconditions?
# TODO: check postconditions?

externals: DictScope[t.Any]
api_fields: DictScope[t.Any]
temp_fields: DictScope[t.Any]
_current_scope: DictScope[t.Any]

def __init__(self, stencil_object: StencilObject):
self.stencil_object = stencil_object

_builtins: DictScope[t.Any] = DictScope(
symbols=builtins.__dict__,
can_add_symbols=False,
allow_shadowing=True,
parent=None,
)
globals: DictScope[t.Any] = DictScope(
symbols=stencil_object.callable.__globals__,
can_add_symbols=False,
allow_shadowing=True,
parent=_builtins,
)
closure = {}
if stencil_object.callable.__closure__ is not None:
# FIXME: add a test for a proper closure
closure = dict(
zip(
stencil_object.callable.__code__.co_freevars,
(c.cell_contents for c in stencil_object.callable.__closure__),
)
)

self.externals: DictScope[t.Any] = DictScope(
symbols=closure,
can_add_symbols=False,
allow_shadowing=True,
parent=globals,
)
self.api_fields = DictScope(parent=self.externals)
self.temp_fields = DictScope(parent=self.api_fields)
self._current_scope = self.temp_fields

def resolve_symbols(self):
self.stencil(self.stencil_object.pyast)

@grammar.transform(
FunctionDef(
name=_,
args=arguments(
posonlyargs=EmptyList,
args=Capture(list).to("api_fields"),
vararg=None,
kwonlyargs=EmptyList,
kw_defaults=EmptyList,
kwarg=None,
defaults=EmptyList,
),
body=Capture(list).to("body"),
decorator_list=_,
returns=_,
type_comment=None,
)
)
def stencil(self, api_fields: t.List[arg], body: t.List[stmt]):
for field in api_fields:
self.api_field(field)

remaining_stmts = []
for stmt in body:

if isinstance(stmt, AnnAssign):
self.temp_field(stmt)
# vertical iteration variables:
elif isinstance(stmt, With):
remaining_stmts.append(stmt)
self.vertical_loop(stmt)

else:
remaining_stmts.append(stmt)
self.resolve_names(stmt)

# TODO: reenable when symbol resolution properly moved to only this pass
# body.clear()
# body.extend(remaining_stmts)

@grammar.transform(
arg(
arg=Capture(str).to("name"),
annotation=Capture(expr).to("field_type"),
type_comment=None,
)
)
def api_field(self, name: str, field_type: expr):
self.resolve_names(field_type)
self.api_fields.try_add(name, field_type)

@grammar.transform(
AnnAssign(
target=name(Capture(str).to("name"), ctx=Store),
value=None,
annotation=Capture(expr).to("field_type"),
simple=1,
),
)
def temp_field(self, name: str, field_type: expr):
self.resolve_names(field_type)
self.temp_fields.try_add(name, field_type)

@grammar.transform(
With(
items=FixedList(
withitem(
context_expr=Capture(expr).to("domain"),
optional_vars=Optional(name(Capture(str).to("var"), ctx=Store)),
),
),
body=Capture(list).to("body"),
type_comment=None,
),
)
def vertical_loop(self, domain: expr, body: t.List, var: str = None):
self.resolve_names(domain)

# FIXME: should we add a context manager again?
previous_scope = self._current_scope
self._current_scope = DictScope(parent=previous_scope)

if var is not None:
self._current_scope.try_add(var, domain)

for stmt in body:
self.resolve_names(stmt)

self._current_scope = previous_scope

def resolve_names(self, node: AST):
for child in walk(node):
if not isinstance(child, Name):
continue

name = child.id

if not self._current_scope.contains(name):
raise errors.SemanticError(f"Undeclared variable '{name}'!", child)
child.decl = self._current_scope.fetch(name)


T = t.TypeVar("T")


class DictScope(t.Generic[T]):

symbols: t.Dict[str, T]
can_add_symbols: bool
# whether child scopes are allowed to shadow symbols from this scope
allow_shadowing: bool
parent: t.Optional[DictScope[T]]

def __init__(
self,
symbols: t.Optional[t.Dict[str, T]] = None,
can_add_symbols: bool = True,
allow_shadowing: bool = False,
parent: t.Optional[DictScope[T]] = None,
):
if symbols is None:
symbols = {}
self.symbols = symbols
self.allow_shadowing = allow_shadowing
self.can_add_symbols = can_add_symbols
self.parent = parent

# this is only to check fetching!
def contains(self, name: str) -> bool:
if name in self.symbols:
return True
if self.parent is not None:
return self.parent.contains(name)
return False

def fetch(self, name: str) -> T:
if name in self.symbols:
return self.symbols[name]
if self.parent is not None:
return self.parent.fetch(name)
raise KeyError(f"Couldn't find symbol '{name}' in scope!")

def try_add(self, name: str, symbol: T) -> None:

if not self.can_add_symbols:
raise KeyError(
f"This scope doesn't allow adding of symbols ('{name}', '{symbol}')!"
)

if name in self.symbols:
raise KeyError(
f"Symbol '{name}' already exists in this scope ('{symbol}')!"
)

parent = self.parent
while parent is not None:

if not parent.allow_shadowing and name in parent.symbols:
raise KeyError(
f"Symbol '{name}' illegally shadows another symbol "
f"('{symbol}', '{parent.symbols[name]}')!"
)

parent = parent.parent

self.symbols[name] = symbol

def local_iter(self) -> t.Iterator[T]:
return iter(self.symbols.values())
6 changes: 6 additions & 0 deletions dusk/script/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
"Field",
"IndexField",
"domain",
"levels_upward",
"levels_downward",
"HorizontalDomains",
"sparse",
"reduce_over",
Expand All @@ -27,6 +29,10 @@
] + __math_all__


# FIXME: remove this hack when `domain` properly works
levels_upward = levels_downward = "levels_hack"


def stencil(stencil: typing.Callable) -> typing.Callable:
integration.stencil_collection.append(integration.StencilObject(stencil))
return stencil
Expand Down
7 changes: 3 additions & 4 deletions dusk/semantics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import NewType, Optional, ClassVar, Iterator, Iterable, List, Dict
from typing import NewType, Optional, ClassVar, Iterator, List, Dict

from enum import Enum, auto, unique
from dataclasses import dataclass
Expand Down Expand Up @@ -44,7 +44,6 @@ class Scope:
symbols: Dict[str, Symbol]
parent: Optional[Scope]

# TODO: better error messages
def __init__(self, parent: Optional[Scope] = None) -> None:
self.symbols = {}
self.parent = parent
Expand All @@ -61,11 +60,11 @@ def fetch(self, name: str) -> Symbol:
return self.symbols[name]
if self.parent is not None:
return self.parent.fetch(name)
raise KeyError
raise KeyError(f"Couldn't find symbol '{name}' in scope!")

def add(self, name: str, symbol: Symbol) -> None:
if self.contains(name):
raise KeyError
raise KeyError(f"Symbol '{name}' already exists in scope!")

self.symbols[name] = symbol

Expand Down
Loading

0 comments on commit 77a07b7

Please sign in to comment.