diff --git a/src/scenic/core/requirements.py b/src/scenic/core/requirements.py index 1774e6e7e..6ef6ab2f8 100644 --- a/src/scenic/core/requirements.py +++ b/src/scenic/core/requirements.py @@ -53,7 +53,7 @@ def __init__(self, ty, condition, line, prob, name, ego): bindings = {} atomGlobals = None for atom in atoms: - bindings.update(getAllGlobals(atom.closure)) + bindings.update(getNameBindings(atom.closure)) globs = atom.closure.__globals__ if atomGlobals is not None: assert globs is atomGlobals @@ -134,24 +134,30 @@ def closure(values, monitor=None): return CompiledRequirement(self, closure, deps, condition) -def getAllGlobals(req, restrictTo=None): +def getNameBindings(req, restrictTo=None): """Find all names the given lambda depends on, along with their current bindings.""" namespace = req.__globals__ if restrictTo is not None and restrictTo is not namespace: return {} externals = inspect.getclosurevars(req) - assert not externals.nonlocals # TODO handle these - globs = dict(externals.builtins) - for name, value in externals.globals.items(): - globs[name] = value - if inspect.isfunction(value): - subglobs = getAllGlobals(value, restrictTo=namespace) - for name, value in subglobs.items(): - if name in globs: - assert value is globs[name] - else: - globs[name] = value - return globs + allBindings = dict(externals.builtins) + + def addBindings(bindings): + for name, value in bindings.items(): + allBindings[name] = value + if inspect.isfunction(value): + subglobs = getNameBindings(value, restrictTo=namespace) + for name, value in subglobs.items(): + if name in allBindings: + assert value is allBindings[name] + else: + allBindings[name] = value + + addBindings(externals.globals) + if restrictTo is None: + # At the top level, include nonlocal variables captured in the closure + addBindings(externals.nonlocals) + return allBindings class BoundRequirement: diff --git a/tests/syntax/test_requirements.py b/tests/syntax/test_requirements.py index a1ffb9d36..f19179f9d 100644 --- a/tests/syntax/test_requirements.py +++ b/tests/syntax/test_requirements.py @@ -31,6 +31,38 @@ def test_requirement_in_loop(): assert all(0 <= pos.x <= 10 and 0 <= pos.y <= 10 for pos in poss) +def test_requirement_in_function(): + scenario = compileScenic( + """ + ego = new Object at Range(-10, 10) @ Range(-10, 10) + def f(i): + require ego.position[i] >= 0 + for i in range(2): + f(i) + """ + ) + poss = [sampleEgo(scenario, maxIterations=150).position for i in range(60)] + assert all(0 <= pos.x <= 10 and 0 <= pos.y <= 10 for pos in poss) + + +def test_requirement_in_function_helper(): + scenario = compileScenic( + """ + ego = new Object at Range(-10, 10) @ Range(-10, 10) + m = 0 + def f(): + assert m == 0 + return ego.y + m + def g(): + require ego.x < f() + g() + m = -100 + """ + ) + poss = [sampleEgo(scenario, maxIterations=60).position for i in range(60)] + assert all(pos.x < pos.y for pos in poss) + + def test_soft_requirement(): scenario = compileScenic( """