Skip to content

Commit

Permalink
[Feature] Boundary visualization for limited-size environments (#142)
Browse files Browse the repository at this point in the history
* add boundary visualization for limited-size environments

* add boundary visualization for limited-size environments

* add boundary visualization for limited-size environments

* introduce "visualize_semidims" to display boundaries

* Update vmas/simulator/environment/environment.py

Co-authored-by: Matteo Bettini <[email protected]>

* Update vmas/simulator/environment/environment.py

Co-authored-by: Matteo Bettini <[email protected]>

* Update vmas/simulator/environment/environment.py

Co-authored-by: Matteo Bettini <[email protected]>

* Update vmas/simulator/environment/environment.py

Co-authored-by: Matteo Bettini <[email protected]>

* Update vmas/simulator/environment/environment.py

Co-authored-by: Matteo Bettini <[email protected]>

* Update vmas/simulator/scenario.py

Co-authored-by: Matteo Bettini <[email protected]>

* add boundary visualization for limited-size environments

* disabled "visualize_semidims" as boundaries are already being plotted in this scenario

* Update vmas/simulator/environment/environment.py

Co-authored-by: Matteo Bettini <[email protected]>

* Update vmas/simulator/environment/environment.py

Co-authored-by: Matteo Bettini <[email protected]>

* disabled "visualize_semidims" as boundaries are already being plotted in this scenario

* add boundary visualization for limited-size environments

* Update vmas/simulator/environment/environment.py

Co-authored-by: Matteo Bettini <[email protected]>

* add boundary visualization for limited-size environments

---------

Co-authored-by: Matteo Bettini <[email protected]>
  • Loading branch information
Giovannibriglia and matteobettini authored Sep 20, 2024
1 parent 26ceb42 commit 132d97b
Show file tree
Hide file tree
Showing 11 changed files with 77 additions and 0 deletions.
2 changes: 2 additions & 0 deletions vmas/scenarios/balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
self.shaping_factor = 100
self.fall_reward = -10

self.visualize_semidims = False

# Make world
world = World(batch_dim, device, gravity=(0.0, -0.05), y_semidim=1)
# Add agents
Expand Down
2 changes: 2 additions & 0 deletions vmas/scenarios/ball_passage.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
self.passage_width = 0.2
self.passage_length = 0.103

self.visualize_semidims = False

# Make world
world = World(
batch_dim,
Expand Down
1 change: 1 addition & 0 deletions vmas/scenarios/football.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
class Scenario(BaseScenario):
def make_world(self, batch_dim: int, device: torch.device, **kwargs):
self.init_params(**kwargs)
self.visualize_semidims = False
world = self.init_world(batch_dim, device)
self.init_agents(world)
self.init_ball(world)
Expand Down
2 changes: 2 additions & 0 deletions vmas/scenarios/joint_passage.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
ScenarioUtils.check_kwargs_consumed(kwargs)

self.plot_grid = True
self.visualize_semidims = False

# Make world
world = World(
batch_dim,
Expand Down
1 change: 1 addition & 0 deletions vmas/scenarios/joint_passage_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
assert self.n_passages == 3 or self.n_passages == 4

self.plot_grid = False
self.visualize_semidims = False

# Make world
world = World(
Expand Down
2 changes: 2 additions & 0 deletions vmas/scenarios/mpe/simple_tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
self.respawn_at_catch = kwargs.pop("respawn_at_catch", False)
ScenarioUtils.check_kwargs_consumed(kwargs)

self.visualize_semidims = False

world = World(
batch_dim=batch_dim,
device=device,
Expand Down
2 changes: 2 additions & 0 deletions vmas/scenarios/passage.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
self.passage_width = 0.2
self.passage_length = 0.103

self.visualize_semidims = False

# Make world
world = World(batch_dim, device, x_semidim=1, y_semidim=1)
# Add agents
Expand Down
1 change: 1 addition & 0 deletions vmas/scenarios/road_traffic.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class Scenario(BaseScenario):

def make_world(self, batch_dim: int, device: torch.device, **kwargs):
self.init_params(batch_dim, device, **kwargs)
self.visualize_semidims = False
world = self.init_world(batch_dim, device)
self.init_agents(world)
return world
Expand Down
1 change: 1 addition & 0 deletions vmas/scenarios/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
assert len(self.covs) == self.n_gaussians

self.plot_grid = False
self.visualize_semidims = False
self.n_x_cells = int((2 * self.xdim) / self.grid_spacing)
self.n_y_cells = int((2 * self.ydim) / self.grid_spacing)
self.max_pdf = torch.zeros((batch_dim,), device=device, dtype=torch.float32)
Expand Down
61 changes: 61 additions & 0 deletions vmas/simulator/environment/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,9 @@ def render(
)

# Render
if self.scenario.visualize_semidims:
self.plot_boundary()

self._set_agent_comm_messages(env_index)

if plot_position_function is not None:
Expand Down Expand Up @@ -770,6 +773,64 @@ def render(
# render to display or array
return self.viewer.render(return_rgb_array=mode == "rgb_array")

def plot_boundary(self):
# include boundaries in the rendering if the environment is dimension-limited
if self.world.x_semidim is not None or self.world.y_semidim is not None:
from vmas.simulator.rendering import Line
from vmas.simulator.utils import Color

# set a big value for the cases where the environment is dimension-limited only in one coordinate
infinite_value = 100

x_semi = (
self.world.x_semidim
if self.world.x_semidim is not None
else infinite_value
)
y_semi = (
self.world.y_semidim
if self.world.y_semidim is not None
else infinite_value
)

# set the color for the boundary line
color = Color.GRAY.value

# Define boundary points based on whether world semidims are provided
if (
self.world.x_semidim is not None and self.world.y_semidim is not None
) or self.world.y_semidim is not None:
boundary_points = [
(-x_semi, y_semi),
(x_semi, y_semi),
(x_semi, -y_semi),
(-x_semi, -y_semi),
]
else:
boundary_points = [
(-x_semi, y_semi),
(-x_semi, -y_semi),
(x_semi, y_semi),
(x_semi, -y_semi),
]

# Create lines by connecting points
for i in range(
0,
len(boundary_points),
1
if (
self.world.x_semidim is not None
and self.world.y_semidim is not None
)
else 2,
):
start = boundary_points[i]
end = boundary_points[(i + 1) % len(boundary_points)]
line = Line(start, end, width=0.7)
line.set_color(*color)
self.viewer.add_onetime(line)

def plot_function(
self, f, precision, plot_range, cmap_range, cmap_alpha, cmap_name
):
Expand Down
2 changes: 2 additions & 0 deletions vmas/simulator/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def __init__(self):
"""Whether to plot a grid in the scenario rendering background. This can be changed in the :class:`~make_world` function. """
self.grid_spacing = 0.1
"""If :class:`~plot_grid`, the distance between lines in the background grid. This can be changed in the :class:`~make_world` function. """
self.visualize_semidims = True
"""Whether to display boundaries in dimension-limited environment. This can be changed in the :class:`~make_world` function. """

@property
def world(self):
Expand Down

0 comments on commit 132d97b

Please sign in to comment.