Skip to content

Commit

Permalink
fix require statements inside loops
Browse files Browse the repository at this point in the history
  • Loading branch information
dfremont committed Nov 23, 2024
1 parent 4ae8ee2 commit 56995da
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 10 deletions.
7 changes: 3 additions & 4 deletions src/scenic/core/dynamics/scenarios.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(self, *args, **kwargs):
self._objects = [] # ordered for reproducibility
self._sampledObjects = self._objects
self._externalParameters = []
self._pendingRequirements = defaultdict(list)
self._pendingRequirements = []
self._requirements = []
# things needing to be sampled to evaluate the requirements
self._requirementDeps = set()
Expand Down Expand Up @@ -409,9 +409,8 @@ def _registerObject(self, obj):

def _addRequirement(self, ty, reqID, req, line, name, prob):
"""Save a requirement defined at compile-time for later processing."""
assert reqID not in self._pendingRequirements
preq = PendingRequirement(ty, req, line, prob, name, self._ego)
self._pendingRequirements[reqID] = preq
self._pendingRequirements.append((reqID, preq))

def _addDynamicRequirement(self, ty, req, line, name):
"""Add a requirement defined during a dynamic simulation."""
Expand All @@ -429,7 +428,7 @@ def _compileRequirements(self):
namespace = self._dummyNamespace if self._dummyNamespace else self.__dict__
requirementSyntax = self._requirementSyntax
assert requirementSyntax is not None
for reqID, requirement in self._pendingRequirements.items():
for reqID, requirement in self._pendingRequirements:
syntax = requirementSyntax[reqID] if requirementSyntax else None

# Catch the simple case where someone has most likely forgotten the "monitor"
Expand Down
6 changes: 4 additions & 2 deletions src/scenic/core/requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,14 @@ def compile(self, namespace, scenario, syntax=None):

# Construct closure
def closure(values, monitor=None):
# rebind any names referring to sampled objects
# rebind any names referring to sampled objects (for require statements,
# rebind all names, since we want their values at the time the requirement
# was created)
# note: need to extract namespace here rather than close over value
# from above because of https://github.com/uqfoundation/dill/issues/532
namespace = condition.atomics()[0].closure.__globals__
for name, value in bindings.items():
if value in values:
if ty == RequirementType.require or value in values:
namespace[name] = values[value]
# rebind ego object, which can be referred to implicitly
boundEgo = None if ego is None else values[ego]
Expand Down
9 changes: 5 additions & 4 deletions src/scenic/syntax/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1359,11 +1359,12 @@ def createRequirementLike(
"""Create a call to a function that implements requirement-like features, such as `record` and `terminate when`.
Args:
functionName (str): Name of the requirement-like function to call. Its signature must be `(reqId: int, body: () -> bool, lineno: int, name: str | None)`
functionName (str): Name of the requirement-like function to call. Its signature
must be `(reqId: int, body: () -> bool, lineno: int, name: str | None)`
body (ast.AST): AST node to evaluate for checking the condition
lineno (int): Line number in the source code
name (Optional[str], optional): Optional name for requirements. Defaults to None.
prob (Optional[float], optional): Optional probability for requirements. Defaults to None.
name (Optional[str]): Optional name for requirements. Defaults to None.
prob (Optional[float]): Optional probability for requirements. Defaults to None.
"""
propTransformer = PropositionTransformer(self.filename)
newBody, self.nextSyntaxId = propTransformer.transform(body, self.nextSyntaxId)
Expand All @@ -1374,7 +1375,7 @@ def createRequirementLike(
value=ast.Call(
func=ast.Name(functionName, loadCtx),
args=[
ast.Constant(requirementId), # requirement IDre
ast.Constant(requirementId), # requirement ID
newBody, # body
ast.Constant(lineno), # line number
ast.Constant(name), # requirement name
Expand Down
12 changes: 12 additions & 0 deletions tests/syntax/test_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,18 @@ def test_requirement():
assert all(0 <= x <= 10 for x in xs)


def test_requirement_in_loop():
scenario = compileScenic(
"""
ego = new Object at Range(-10, 10) @ Range(-10, 10)
for i in range(2):
require ego.position[i] >= 0
"""
)
poss = [sampleEgo(scenario, maxIterations=150).position for i in range(60)]
assert all(0 <= pos.x <= 10 and 0 <= pos.y <= 10 for pos in poss)


def test_soft_requirement():
scenario = compileScenic(
"""
Expand Down

0 comments on commit 56995da

Please sign in to comment.