Skip to content

Commit

Permalink
fix require statements inside functions
Browse files Browse the repository at this point in the history
  • Loading branch information
dfremont committed Nov 23, 2024
1 parent 56995da commit bff61c4
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 14 deletions.
34 changes: 20 additions & 14 deletions src/scenic/core/requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(self, ty, condition, line, prob, name, ego):
bindings = {}
atomGlobals = None
for atom in atoms:
bindings.update(getAllGlobals(atom.closure))
bindings.update(getNameBindings(atom.closure))
globs = atom.closure.__globals__
if atomGlobals is not None:
assert globs is atomGlobals
Expand Down Expand Up @@ -134,24 +134,30 @@ def closure(values, monitor=None):
return CompiledRequirement(self, closure, deps, condition)


def getAllGlobals(req, restrictTo=None):
def getNameBindings(req, restrictTo=None):
"""Find all names the given lambda depends on, along with their current bindings."""
namespace = req.__globals__
if restrictTo is not None and restrictTo is not namespace:
return {}
externals = inspect.getclosurevars(req)
assert not externals.nonlocals # TODO handle these
globs = dict(externals.builtins)
for name, value in externals.globals.items():
globs[name] = value
if inspect.isfunction(value):
subglobs = getAllGlobals(value, restrictTo=namespace)
for name, value in subglobs.items():
if name in globs:
assert value is globs[name]
else:
globs[name] = value
return globs
allBindings = dict(externals.builtins)

def addBindings(bindings):
for name, value in bindings.items():
allBindings[name] = value
if inspect.isfunction(value):
subglobs = getNameBindings(value, restrictTo=namespace)
for name, value in subglobs.items():
if name in allBindings:
assert value is allBindings[name]
else:
allBindings[name] = value

addBindings(externals.globals)
if restrictTo is None:
# At the top level, include nonlocal variables captured in the closure
addBindings(externals.nonlocals)
return allBindings


class BoundRequirement:
Expand Down
32 changes: 32 additions & 0 deletions tests/syntax/test_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,38 @@ def test_requirement_in_loop():
assert all(0 <= pos.x <= 10 and 0 <= pos.y <= 10 for pos in poss)


def test_requirement_in_function():
scenario = compileScenic(
"""
ego = new Object at Range(-10, 10) @ Range(-10, 10)
def f(i):
require ego.position[i] >= 0
for i in range(2):
f(i)
"""
)
poss = [sampleEgo(scenario, maxIterations=150).position for i in range(60)]
assert all(0 <= pos.x <= 10 and 0 <= pos.y <= 10 for pos in poss)


def test_requirement_in_function_helper():
scenario = compileScenic(
"""
ego = new Object at Range(-10, 10) @ Range(-10, 10)
m = 0
def f():
assert m == 0
return ego.y + m
def g():
require ego.x < f()
g()
m = -100
"""
)
poss = [sampleEgo(scenario, maxIterations=60).position for i in range(60)]
assert all(pos.x < pos.y for pos in poss)


def test_soft_requirement():
scenario = compileScenic(
"""
Expand Down

0 comments on commit bff61c4

Please sign in to comment.