From 77a07b7e4d4a40fb3151a9d7172e02c11c556c5f Mon Sep 17 00:00:00 2001
From: Ben Weber
Date: Mon, 18 Jan 2021 12:06:24 +0100
Subject: [PATCH] Added symbol resolution with externals support

---
NOTE(review): text between this separator and the first "diff --git" line is
commentary; it is not part of the commit message or of the applied change.

- Leading indentation of unchanged context lines is assumed to follow the
  4-space layout of the added lines; verify against the target tree, or apply
  with `git apply --ignore-whitespace`.
- symbol_resolution.py: `SymbolResolver.vertical_loop` resets
  `_current_scope` only when the body resolves without raising; wrapping the
  body in try/finally would make the restore unconditional.
- symbol_resolution.py: `SymbolResolver.stencil` collects `remaining_stmts`
  but never uses it (the `body.clear()` / `body.extend()` calls are
  commented out).
- symbol_resolution.py: `CellType` and `does_match` are imported but not
  referenced in the new module; the locals `globals`, `stmt` and `name`
  shadow a builtin, `ast.stmt` and the imported `name` helper respectively.

 dusk/grammar.py                  |  10 +-
 dusk/match.py                    |  17 ++-
 dusk/passes/symbol_resolution.py | 240 ++++++++++++++++++++++++++++++-
 dusk/script/__init__.py          |   6 +
 dusk/semantics.py                |   7 +-
 dusk/test.py                     |   7 +-
 6 files changed, 270 insertions(+), 17 deletions(-)

diff --git a/dusk/grammar.py b/dusk/grammar.py
index 7cc0ced..56d5460 100644
--- a/dusk/grammar.py
+++ b/dusk/grammar.py
@@ -40,6 +40,8 @@
     Capture,
     Repeat,
     FixedList,
+    EmptyList,
+    name,
     BreakPoint,
 )
 from dusk.semantics import (
@@ -61,14 +63,9 @@
 
 
 # Short cuts
-EmptyList = FixedList()
 AnyContext = OneOf(Load, Store, Del, AugLoad, AugStore, Param)
 
 
-def name(id, ctx=Load) -> Name:
-    return Name(id=id, ctx=ctx)
-
-
 def transform(matcher) -> t.Callable:
     def decorator(transformer: t.Callable) -> t.Callable:
         def transformer_with_matcher(self, node, *args, **kwargs):
@@ -95,6 +92,9 @@ class Grammar:
     def __init__(self):
         self.ctx = DuskContextHelper()
 
+    # TODO: somewhere we should check that the function is a valid stencil
+    # e.g., no kwargs etc, mostly `Grammar.stencil` without capturing/processing
+
     @transform(
         FunctionDef(
             name=Capture(str).to("name"),
diff --git a/dusk/match.py b/dusk/match.py
index 25b4edb..a709b93 100644
--- a/dusk/match.py
+++ b/dusk/match.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from ast import AST, stmt, expr
+import ast
 
 from dusk.errors import ASTError
 from dusk.util import pprint_matcher as pprint
@@ -16,8 +16,10 @@
     "Optional",
     "Capture",
     "FixedList",
+    "EmptyList",
     "BreakPoint",
     "NoMatch",
+    "name",
 ]
 
 
@@ -55,6 +57,9 @@ def match(self, nodes, **kwargs):
             match(matcher, node, **kwargs)
 
 
+EmptyList = FixedList()
+
+
 class Repeat(Matcher):
     _fields = ("matcher", "n")
 
@@ -162,11 +167,15 @@ def match(self, node, **kwargs):
         match(self.matcher, node, **kwargs)
 
 
+def name(id, ctx=ast.Load) -> ast.Name:
+    return ast.Name(id=id, ctx=ctx)
+
+
 def match(matcher, node, **kwargs) -> None:
     # this should be probably more flexible than hardcoding all possibilities
     if isinstance(matcher, Matcher):
         matcher.match(node, **kwargs)
-    elif isinstance(matcher, AST):
+    elif isinstance(matcher, ast.AST):
         match_ast(matcher, node, **kwargs)
     elif isinstance(matcher, type):
         match_type(matcher, node, **kwargs)
@@ -184,7 +193,7 @@ def does_match(matcher, node, **kwargs) -> bool:
         return False
 
 
-def match_ast(matcher: AST, node, **kwargs):
+def match_ast(matcher: ast.AST, node, **kwargs):
     if not isinstance(node, type(matcher)):
         raise NoMatch(
             f"Expected node type '{type(matcher)}', but got '{type(node)}'!", node
@@ -194,7 +203,7 @@ def match_ast(matcher: AST, node, **kwargs):
         try:
             match(getattr(matcher, field), getattr(node, field), **kwargs)
         except NoMatch as e:
-            if e.loc is None and isinstance(node, (stmt, expr)):
+            if e.loc is None and isinstance(node, (ast.stmt, ast.expr)):
                 # add location info if possible
                 e.loc_from_node(node)
             raise e
diff --git a/dusk/passes/symbol_resolution.py b/dusk/passes/symbol_resolution.py
index 83704c2..ba5bf96 100644
--- a/dusk/passes/symbol_resolution.py
+++ b/dusk/passes/symbol_resolution.py
@@ -1,6 +1,242 @@
+from __future__ import annotations
+
+import typing as t
+from types import CellType
+
+import builtins
+from ast import *
+
+from dusk import grammar, errors
+from dusk.match import (
+    does_match,
+    Ignore as _,
+    Optional,
+    Capture,
+    FixedList,
+    EmptyList,
+    name,
+)
 from dusk.integration import StencilObject
 
 
 def resolve_symbols(stencil_object: StencilObject) -> None:
-    # FIXME: implement
-    pass
+    SymbolResolver(stencil_object).resolve_symbols()
+
+
+class SymbolResolver:
+
+    # TODO: check preconditions?
+    # TODO: check postconditions?
+
+    externals: DictScope[t.Any]
+    api_fields: DictScope[t.Any]
+    temp_fields: DictScope[t.Any]
+    _current_scope: DictScope[t.Any]
+
+    def __init__(self, stencil_object: StencilObject):
+        self.stencil_object = stencil_object
+
+        _builtins: DictScope[t.Any] = DictScope(
+            symbols=builtins.__dict__,
+            can_add_symbols=False,
+            allow_shadowing=True,
+            parent=None,
+        )
+        globals: DictScope[t.Any] = DictScope(
+            symbols=stencil_object.callable.__globals__,
+            can_add_symbols=False,
+            allow_shadowing=True,
+            parent=_builtins,
+        )
+        closure = {}
+        if stencil_object.callable.__closure__ is not None:
+            # FIXME: add a test for a proper closure
+            closure = dict(
+                zip(
+                    stencil_object.callable.__code__.co_freevars,
+                    (c.cell_contents for c in stencil_object.callable.__closure__),
+                )
+            )
+
+        self.externals: DictScope[t.Any] = DictScope(
+            symbols=closure,
+            can_add_symbols=False,
+            allow_shadowing=True,
+            parent=globals,
+        )
+        self.api_fields = DictScope(parent=self.externals)
+        self.temp_fields = DictScope(parent=self.api_fields)
+        self._current_scope = self.temp_fields
+
+    def resolve_symbols(self):
+        self.stencil(self.stencil_object.pyast)
+
+    @grammar.transform(
+        FunctionDef(
+            name=_,
+            args=arguments(
+                posonlyargs=EmptyList,
+                args=Capture(list).to("api_fields"),
+                vararg=None,
+                kwonlyargs=EmptyList,
+                kw_defaults=EmptyList,
+                kwarg=None,
+                defaults=EmptyList,
+            ),
+            body=Capture(list).to("body"),
+            decorator_list=_,
+            returns=_,
+            type_comment=None,
+        )
+    )
+    def stencil(self, api_fields: t.List[arg], body: t.List[stmt]):
+        for field in api_fields:
+            self.api_field(field)
+
+        remaining_stmts = []
+        for stmt in body:
+
+            if isinstance(stmt, AnnAssign):
+                self.temp_field(stmt)
+            # vertical iteration variables:
+            elif isinstance(stmt, With):
+                remaining_stmts.append(stmt)
+                self.vertical_loop(stmt)
+
+            else:
+                remaining_stmts.append(stmt)
+                self.resolve_names(stmt)
+
+        # TODO: reenable when symbol resolution properly moved to only this pass
+        # body.clear()
+        # body.extend(remaining_stmts)
+
+    @grammar.transform(
+        arg(
+            arg=Capture(str).to("name"),
+            annotation=Capture(expr).to("field_type"),
+            type_comment=None,
+        )
+    )
+    def api_field(self, name: str, field_type: expr):
+        self.resolve_names(field_type)
+        self.api_fields.try_add(name, field_type)
+
+    @grammar.transform(
+        AnnAssign(
+            target=name(Capture(str).to("name"), ctx=Store),
+            value=None,
+            annotation=Capture(expr).to("field_type"),
+            simple=1,
+        ),
+    )
+    def temp_field(self, name: str, field_type: expr):
+        self.resolve_names(field_type)
+        self.temp_fields.try_add(name, field_type)
+
+    @grammar.transform(
+        With(
+            items=FixedList(
+                withitem(
+                    context_expr=Capture(expr).to("domain"),
+                    optional_vars=Optional(name(Capture(str).to("var"), ctx=Store)),
+                ),
+            ),
+            body=Capture(list).to("body"),
+            type_comment=None,
+        ),
+    )
+    def vertical_loop(self, domain: expr, body: t.List, var: str = None):
+        self.resolve_names(domain)
+
+        # FIXME: should we add a context manager again?
+        previous_scope = self._current_scope
+        self._current_scope = DictScope(parent=previous_scope)
+
+        if var is not None:
+            self._current_scope.try_add(var, domain)
+
+        for stmt in body:
+            self.resolve_names(stmt)
+
+        self._current_scope = previous_scope
+
+    def resolve_names(self, node: AST):
+        for child in walk(node):
+            if not isinstance(child, Name):
+                continue
+
+            name = child.id
+
+            if not self._current_scope.contains(name):
+                raise errors.SemanticError(f"Undeclared variable '{name}'!", child)
+            child.decl = self._current_scope.fetch(name)
+
+
+T = t.TypeVar("T")
+
+
+class DictScope(t.Generic[T]):
+
+    symbols: t.Dict[str, T]
+    can_add_symbols: bool
+    # whether child scopes are allowed to shadow symbols from this scope
+    allow_shadowing: bool
+    parent: t.Optional[DictScope[T]]
+
+    def __init__(
+        self,
+        symbols: t.Optional[t.Dict[str, T]] = None,
+        can_add_symbols: bool = True,
+        allow_shadowing: bool = False,
+        parent: t.Optional[DictScope[T]] = None,
+    ):
+        if symbols is None:
+            symbols = {}
+        self.symbols = symbols
+        self.allow_shadowing = allow_shadowing
+        self.can_add_symbols = can_add_symbols
+        self.parent = parent
+
+    # this is only to check fetching!
+    def contains(self, name: str) -> bool:
+        if name in self.symbols:
+            return True
+        if self.parent is not None:
+            return self.parent.contains(name)
+        return False
+
+    def fetch(self, name: str) -> T:
+        if name in self.symbols:
+            return self.symbols[name]
+        if self.parent is not None:
+            return self.parent.fetch(name)
+        raise KeyError(f"Couldn't find symbol '{name}' in scope!")
+
+    def try_add(self, name: str, symbol: T) -> None:
+
+        if not self.can_add_symbols:
+            raise KeyError(
+                f"This scope doesn't allow adding of symbols ('{name}', '{symbol}')!"
+            )
+
+        if name in self.symbols:
+            raise KeyError(
+                f"Symbol '{name}' already exists in this scope ('{symbol}')!"
+            )
+
+        parent = self.parent
+        while parent is not None:
+
+            if not parent.allow_shadowing and name in parent.symbols:
+                raise KeyError(
+                    f"Symbol '{name}' illegally shadows another symbol "
+                    f"('{symbol}', '{parent.symbols[name]}')!"
+                )
+
+            parent = parent.parent
+
+        self.symbols[name] = symbol
+
+    def local_iter(self) -> t.Iterator[T]:
+        return iter(self.symbols.values())
diff --git a/dusk/script/__init__.py b/dusk/script/__init__.py
index 89d705f..51e8725 100644
--- a/dusk/script/__init__.py
+++ b/dusk/script/__init__.py
@@ -17,6 +17,8 @@
     "Field",
     "IndexField",
     "domain",
+    "levels_upward",
+    "levels_downward",
     "HorizontalDomains",
     "sparse",
     "reduce_over",
@@ -27,6 +29,10 @@
 ] + __math_all__
 
 
+# FIXME: remove this hack when `domain` properly works
+levels_upward = levels_downward = "levels_hack"
+
+
 def stencil(stencil: typing.Callable) -> typing.Callable:
     integration.stencil_collection.append(integration.StencilObject(stencil))
     return stencil
diff --git a/dusk/semantics.py b/dusk/semantics.py
index ba5a539..fbaad8a 100644
--- a/dusk/semantics.py
+++ b/dusk/semantics.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import NewType, Optional, ClassVar, Iterator, Iterable, List, Dict
+from typing import NewType, Optional, ClassVar, Iterator, List, Dict
 from enum import Enum, auto, unique
 from dataclasses import dataclass
 
@@ -44,7 +44,6 @@ class Scope:
     symbols: Dict[str, Symbol]
     parent: Optional[Scope]
 
-    # TODO: better error messages
     def __init__(self, parent: Optional[Scope] = None) -> None:
         self.symbols = {}
         self.parent = parent
@@ -61,11 +60,11 @@ def fetch(self, name: str) -> Symbol:
             return self.symbols[name]
         if self.parent is not None:
             return self.parent.fetch(name)
-        raise KeyError
+        raise KeyError(f"Couldn't find symbol '{name}' in scope!")
 
     def add(self, name: str, symbol: Symbol) -> None:
         if self.contains(name):
-            raise KeyError
+            raise KeyError(f"Symbol '{name}' already exists in scope!")
         self.symbols[name] = symbol
 
 
diff --git a/dusk/test.py b/dusk/test.py
index 18e9230..25f2e05 100644
--- a/dusk/test.py
+++ b/dusk/test.py
@@ -1,6 +1,6 @@
 from typing import Callable
 
-from dusk.integration import StencilObject
+from dusk.integration import StencilObject, stencil_collection
 from dusk.transpile import validate as validate_sir, stencil_object_to_sir
 
 
@@ -14,8 +14,11 @@ def stencil_test(validate: bool = True) -> Callable:
     def decorator(stencil: Callable) -> Callable:
         assert stencil.__name__.startswith("test_")
 
+        stencil_object = StencilObject(stencil)
+        stencil_collection.append(stencil_object)
+
         def test_stencil() -> None:
-            sir = stencil_object_to_sir(StencilObject(stencil))
+            sir = stencil_object_to_sir(stencil_object)
             if validate:
                 validate_sir(sir)
 