Skip to content

Commit

Permalink
Requirement Boolean Negation Fix (#289)
Browse files Browse the repository at this point in the history
* Added test to reproduce Issue#286

* Initial attempt for not transforming not.

* convert globals used in requirements to distributions as needed

* Simplified test.

* Update src/scenic/syntax/compiler.py

Co-authored-by: Daniel Fremont <[email protected]>

* Fixed renaming.

---------

Co-authored-by: Daniel Fremont <[email protected]>
  • Loading branch information
Eric-Vin and dfremont authored Jul 11, 2024
1 parent a34460a commit 652a7ec
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 2 deletions.
6 changes: 5 additions & 1 deletion src/scenic/core/requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import rv_ltl
import trimesh

from scenic.core.distributions import Samplable, needsSampling
from scenic.core.distributions import Samplable, needsSampling, toDistribution
from scenic.core.errors import InvalidScenarioError
from scenic.core.lazy_eval import needsLazyEvaluation
from scenic.core.propositions import Atomic, PropositionNode
Expand Down Expand Up @@ -71,6 +71,10 @@ def compile(self, namespace, scenario, syntax=None):
bindings, ego, line = self.bindings, self.egoObject, self.line
condition, ty = self.condition, self.ty

# Convert bound values to distributions as needed
for name, value in bindings.items():
bindings[name] = toDistribution(value)

# Check whether requirement implies any relations used for pruning
canPrune = condition.check_constrains_sampling()
if canPrune:
Expand Down
11 changes: 10 additions & 1 deletion src/scenic/syntax/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ class PropositionTransformer(Transformer):
def __init__(self, filename="<unknown>") -> None:
super().__init__(filename)
self.nextSyntaxId = 0
self.inAtomic = False

def transform(
self, node: ast.AST, nextSyntaxId=0
Expand All @@ -260,6 +261,14 @@ def transform(
newNode = self._create_atomic_proposition_factory(node)
return newNode, self.nextSyntaxId

def generic_visit(self, node):
# Don't recurse inside atomics.
old_inAtomic = self.inAtomic
self.inAtomic = True
super_val = super().generic_visit(node)
self.inAtomic = old_inAtomic
return super_val

def _register_requirement_syntax(self, syntax):
"""register requirement syntax for later use
returns an ID for retrieving the syntax
Expand Down Expand Up @@ -337,7 +346,7 @@ def visit_BoolOp(self, node: ast.BoolOp) -> ast.AST:

def visit_UnaryOp(self, node):
# rewrite `not` in requirements into a proposition factory
if not isinstance(node.op, ast.Not):
if not isinstance(node.op, ast.Not) or self.inAtomic:
return self.generic_visit(node)

lineNum = ast.Constant(node.lineno)
Expand Down
11 changes: 11 additions & 0 deletions tests/syntax/test_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,3 +497,14 @@ def test_random_occlusion():
hasattr(obj, "name") and obj.name == "wall" and (not obj.occluding)
for obj in scene.objects
)


def test_deep_not():
"""Test that a not deep inside a requirement is interpreted correctly."""
with pytest.raises(RejectionException):
sampleSceneFrom(
"""
objs = [new Object at 10@10, new Object at 20@20]
require all(not o.x > 0 for o in objs)
"""
)

0 comments on commit 652a7ec

Please sign in to comment.