Skip to content

Commit

Permalink
fix allowCollisions; tests for built-in requirements
Browse files Browse the repository at this point in the history
Thanks to @afsafzal on GitHub for pointing out the bug.
  • Loading branch information
dfremont committed Mar 31, 2021
1 parent a4630db commit 5eac20e
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 14 deletions.
28 changes: 15 additions & 13 deletions src/scenic/core/scenarios.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,14 +141,15 @@ def validate(self):
if staticVisibility and oi.requireVisible is True and oi is not self.egoObject:
if not self.egoObject.canSee(oi):
raise InvalidScenarioError(f'Object at {oi.position} is not visible from ego')
# Require object to not intersect another object
for j in range(i):
oj = objects[j]
if not staticBounds[j]:
continue
if oi.intersects(oj):
raise InvalidScenarioError(f'Object at {oi.position} intersects'
f' object at {oj.position}')
if not oi.allowCollisions:
# Require object to not intersect another object
for j in range(i):
oj = objects[j]
if oj.allowCollisions or not staticBounds[j]:
continue
if oi.intersects(oj):
raise InvalidScenarioError(f'Object at {oi.position} intersects'
f' object at {oj.position}')

def hasStaticBounds(self, obj):
if needsSampling(obj.position):
Expand Down Expand Up @@ -224,11 +225,12 @@ def generate(self, maxIterations=2000, verbosity=0, feedback=None):
rejection = 'object visibility'
break
# Require object to not intersect another object
for j in range(i):
vj = sample[objects[j]]
if vi.intersects(vj):
rejection = 'object intersection'
break
if not vi.allowCollisions:
for j in range(i):
vj = sample[objects[j]]
if not vj.allowCollisions and vi.intersects(vj):
rejection = 'object intersection'
break
if rejection is not None:
break
if rejection is not None:
Expand Down
64 changes: 63 additions & 1 deletion tests/syntax/test_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import scenic
from scenic.core.errors import ScenicSyntaxError, InvalidScenarioError
from tests.utils import compileScenic, sampleScene, sampleEgo
from tests.utils import compileScenic, sampleScene, sampleSceneFrom, sampleEgo

## Basic

Expand Down Expand Up @@ -66,6 +66,56 @@ def test_runtime_parse_error_in_requirement():
with pytest.raises(ScenicSyntaxError):
sampleScene(scenario, maxIterations=1)

## Enforcement of built-in requirements

def test_containment_requirement():
scenario = compileScenic("""
foo = RectangularRegion(0@0, 0, 10, 10)
ego = Object at Range(0, 10) @ 0, with regionContainedIn foo
""")
xs = [sampleEgo(scenario, maxIterations=60).position.x for i in range(60)]
assert all(0 <= x <= 5 for x in xs)

def test_visibility_requirement():
scenario = compileScenic("""
ego = Object with visibleDistance 10, with viewAngle 90 deg, facing 45 deg
other = Object at Range(-10, 10) @ 0
""")
xs = [sampleScene(scenario, maxIterations=60).objects[1].position.x for i in range(60)]
assert all(-10 <= x <= 0.5 for x in xs)

def test_visibility_requirement_disabled():
scenario = compileScenic("""
ego = Object with visibleDistance 10, with viewAngle 90 deg, facing 45 deg
other = Object at Range(-10, 10) @ 0, with requireVisible False
""")
xs = [sampleScene(scenario, maxIterations=60).objects[1].position.x for i in range(60)]
assert any(x > 0.5 for x in xs)

def test_intersection_requirement():
scenario = compileScenic("""
ego = Object at Range(0, 2) @ 0
other = Object
""")
xs = [sampleEgo(scenario, maxIterations=60).position.x for i in range(60)]
assert all(x >= 1 for x in xs)

def test_intersection_requirement_disabled_1():
scenario = compileScenic("""
ego = Object at Range(0, 2) @ 0, with allowCollisions True
other = Object
""")
xs = [sampleEgo(scenario, maxIterations=60).position.x for i in range(60)]
assert any(x < 1 for x in xs)

def test_intersection_requirement_disabled_2():
scenario = compileScenic("""
ego = Object at Range(0, 2) @ 0
other = Object with allowCollisions True
""")
xs = [sampleEgo(scenario, maxIterations=60).position.x for i in range(60)]
assert any(x < 1 for x in xs)

## Static violations of built-in requirements

def test_static_containment_violation():
Expand All @@ -89,9 +139,21 @@ def test_static_visibility_violation():
Object at 0@10
""")

def test_static_visibility_violation_disabled():
sampleSceneFrom("""
ego = Object at 10@0, facing -90 deg, with viewAngle 90 deg
Object at 0@10, with requireVisible False
""")

def test_static_intersection_violation():
with pytest.raises(InvalidScenarioError):
compileScenic("""
ego = Object at 0@0
Object at 1@0
""")

def test_static_intersection_violation_disabled():
sampleSceneFrom("""
ego = Object at 0@0
Object at 1@0, with allowCollisions True
""")

0 comments on commit 5eac20e

Please sign in to comment.