diff --git a/.github/unittest/install_dependencies.sh b/.github/unittest/install_dependencies.sh index 00e4a71f..adb7aa7b 100644 --- a/.github/unittest/install_dependencies.sh +++ b/.github/unittest/install_dependencies.sh @@ -7,7 +7,7 @@ python -m pip install --upgrade pip -pip install -e . +pip install -e ".[gymnasium]" python -m pip install flake8 pytest pytest-cov tqdm matplotlib==3.8 python -m pip install cvxpylayers # Navigation heuristic diff --git a/README.md b/README.md index 87a1d635..ca89f4b9 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ Scenario creation is made simple and modular to incentivize contributions. VMAS simulates agents and landmarks of different shapes and supports rotations, elastic collisions, joints, and custom gravity. Holonomic motion models are used for the agents to simplify simulation. Custom sensors such as LIDARs are available and the simulator supports inter-agent communication. Vectorization in [PyTorch](https://pytorch.org/) allows VMAS to perform simulations in a batch, seamlessly scaling to tens of thousands of parallel environments on accelerated hardware. -VMAS has an interface compatible with [OpenAI Gym](https://github.com/openai/gym), with [RLlib](https://docs.ray.io/en/latest/rllib/index.html), with [torchrl](https://github.com/pytorch/rl) and its MARL training library: [BenchMARL](https://github.com/facebookresearch/BenchMARL), +VMAS has an interface compatible with [OpenAI Gym](https://github.com/openai/gym), with [Gymnasium](https://gymnasium.farama.org/), with [RLlib](https://docs.ray.io/en/latest/rllib/index.html), with [torchrl](https://github.com/pytorch/rl) and its MARL training library: [BenchMARL](https://github.com/facebookresearch/BenchMARL), enabling out-of-the-box integration with a wide range of RL algorithms. The implementation is inspired by [OpenAI's MPE](https://github.com/openai/multiagent-particle-envs). Alongside VMAS's scenarios, we port and vectorize all the scenarios in MPE. @@ -113,28 +113,37 @@ git clone https://github.com/proroklab/VectorizedMultiAgentSimulator.git cd VectorizedMultiAgentSimulator pip install -e . ``` -By default, vmas has only the core requirements. Here are some optional packages you may want to install: +By default, vmas has only the core requirements. To enable the [Gymnasium](https://gymnasium.farama.org/) wrappers, the [RLlib](https://docs.ray.io/en/latest/rllib/index.html) wrapper, rendering, and testing, you can install the following optional extras: ```bash -# Training -pip install "ray[rllib]"==2.1.0 # We support versions "ray[rllib]<=2.2,>=1.13" -pip install torchrl +# install gymnasium for gymnasium wrappers +pip install vmas[gymnasium] -# Logging -pip installl wandb +# install rllib for rllib wrapper +pip install vmas[rllib] -# Rendering -pip install opencv-python moviepy matplotlib +# install rendering dependencies +pip install vmas[render] -# Tests -pip install pytest pyyaml pytest-instafail tqdm +# install testing dependencies +pip install vmas[test] + +# install all dependencies +pip install vmas[all] +``` + +You can also install the following training libraries: + +```bash +pip install benchmarl # For training in BenchMARL +pip install torchrl # For training in TorchRL +pip install "ray[rllib]"==2.1.0 # For training in RLlib.
We support versions "ray[rllib]<=2.2,>=1.13" ``` ### Run To use the simulator, simply create an environment by passing the name of the scenario you want (from the `scenarios` folder) to the `make_env` function. -The function arguments are explained in the documentation. The function returns an environment -object with the OpenAI gym interface: +The function arguments are explained in the documentation. The function returns an environment object with the VMAS interface: Here is an example: ```python @@ -143,17 +152,24 @@ Here is an example: num_envs=32, device="cpu", # Or "cuda" for GPU continuous_actions=True, - wrapper=None, # One of: None, vmas.Wrapper.RLLIB, and vmas.Wrapper.GYM + wrapper=None, # One of: None, "rllib", "gym", "gymnasium", "gymnasium_vec" max_steps=None, # Defines the horizon. None is infinite horizon. seed=None, # Seed of the environment dict_spaces=False, # By default tuple spaces are used with each element in the tuple being an agent. # If dict_spaces=True, the spaces will become Dict with each key being the agent's name grad_enabled=False, # If grad_enabled the simulator is differentiable and gradients can flow from output to input + terminated_truncated=False, # If terminated_truncated the simulator will return separate `terminated` and `truncated` flags in the `done()`, `step()`, and `get_from_scenario()` functions instead of a single `done` flag **kwargs # Additional arguments you want to pass to the scenario initialization ) ``` A further example that you can run is contained in `use_vmas_env.py` in the `examples` directory. +With the `terminated_truncated` flag set to `True`, the simulator will return separate `terminated` and `truncated` flags +in the `done()`, `step()`, and `get_from_scenario()` functions instead of a single `done` flag. +This is useful when you want to know if the environment is done because the episode has ended or +because the maximum episode length/ timestep horizon has been reached. +See [the Gymnasium documentation](https://gymnasium.farama.org/tutorials/gymnasium_basics/handling_time_limits/) for more details on this. + #### RLlib To see how to use VMAS in RLlib, check out the script in `examples/rllib.py`. @@ -235,7 +251,7 @@ Each format will work regardless of the fact that tuples or dictionary spaces ha - **Simple**: Complex vectorized physics engines exist (e.g., [Brax](https://github.com/google/brax)), but they do not scale efficiently when dealing with multiple agents. This defeats the computational speed goal set by vectorization. VMAS uses a simple custom 2D dynamics engine written in PyTorch to provide fast simulation. - **General**: The core of VMAS is structured so that it can be used to implement general high-level multi-robot problems in 2D. It can support adversarial as well as cooperative scenarios. Holonomic point-robot simulation has been chosen to focus on general high-level problems, without learning low-level custom robot controls through MARL. - **Extensible**: VMAS is not just a simulator with a set of environments. It is a framework that can be used to create new multi-agent scenarios in a format that is usable by the whole MARL community. For this purpose, we have modularized the process of creating a task and introduced interactive rendering to debug it. You can define your own scenario in minutes. Have a look at the dedicated section in this document. 
-- **Compatible**: VMAS has wrappers for [RLlib](https://docs.ray.io/en/latest/rllib/index.html), [torchrl](https://pytorch.org/rl/reference/generated/torchrl.envs.libs.vmas.VmasEnv.html), and [OpenAI Gym](https://github.com/openai/gym). RLlib and torchrl have a large number of already implemented RL algorithms. +- **Compatible**: VMAS has wrappers for [RLlib](https://docs.ray.io/en/latest/rllib/index.html), [torchrl](https://pytorch.org/rl/reference/generated/torchrl.envs.libs.vmas.VmasEnv.html), [OpenAI Gym](https://github.com/openai/gym) and [Gymnasium](https://gymnasium.farama.org/). RLlib and torchrl have a large number of already implemented RL algorithms. Keep in mind that this interface is less efficient than the unwrapped version. For an example of wrapping, see the main of `make_env`. - **Tested**: Our scenarios come with tests which run a custom designed heuristic on each scenario. - **Entity shapes**: Our entities (agent and landmarks) can have different customizable shapes (spheres, boxes, lines). diff --git a/docs/source/conf.py b/docs/source/conf.py index 19983c59..ca1bef31 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -39,7 +39,7 @@ intersphinx_mapping = { "python": ("https://docs.python.org/3/", None), "sphinx": ("https://www.sphinx-doc.org/en/master/", None), - "torch": ("https://pytorch.org/docs/master", None), + "torch": ("https://pytorch.org/docs/stable/", None), "torchrl": ("https://pytorch.org/rl/stable/", None), "tensordict": ("https://pytorch.org/tensordict/stable", None), "benchmarl": ("https://benchmarl.readthedocs.io/en/latest/", None), diff --git a/docs/source/usage/installation.rst b/docs/source/usage/installation.rst index f81116b2..15a6f2b8 100644 --- a/docs/source/usage/installation.rst +++ b/docs/source/usage/installation.rst @@ -29,6 +29,21 @@ Install optional requirements By default, vmas has only the core requirements. Here are some optional packages you may want to install. +Wrappers +^^^^^^^^ + +If you want to use VMAS environment wrappers, you may want to install VMAS +with the following options: + +.. code-block:: console + + # install gymnasium for gymnasium wrapper + pip install vmas[gymnasium] + + # install rllib for rllib wrapper + pip install vmas[rllib] + + Training ^^^^^^^^ @@ -40,12 +55,14 @@ You may want to install one of the following training libraries pip install torchrl pip install "ray[rllib]"==2.1.0 # We support versions "ray[rllib]<=2.2,>=1.13" -Logging -^^^^^^^ +Utils +^^^^^ -You may want to install the following rendering and logging tools +You may want to install the following additional tools .. 
code-block:: console - pip install wandb - pip install opencv-python moviepy matplotlib + # install rendering dependencies + pip install vmas[render] + # install testing dependencies + pip install vmas[test] diff --git a/setup.py b/setup.py index feb6e156..c67edbf3 100644 --- a/setup.py +++ b/setup.py @@ -30,5 +30,11 @@ def get_version(): author_email="mb2389@cl.cam.ac.uk", packages=find_packages(), install_requires=["numpy", "torch", "pyglet<=1.5.27", "gym", "six"], + extras_require={ + "gymnasium": ["gymnasium", "shimmy"], + "rllib": ["ray[rllib]<=2.2"], + "render": ["opencv-python", "moviepy", "matplotlib"], + "test": ["pytest", "pytest-instafail", "pyyaml", "tqdm"], + }, include_package_data=True, ) diff --git a/tests/test_vmas.py b/tests/test_vmas.py index 6a9caa43..3a782b36 100644 --- a/tests/test_vmas.py +++ b/tests/test_vmas.py @@ -2,7 +2,6 @@ # ProrokLab (https://www.proroklab.org/) # All rights reserved. import math -import os import random import sys from pathlib import Path @@ -18,13 +17,9 @@ def scenario_names(): scenarios = [] scenarios_folder = Path(__file__).parent.parent / "vmas" / "scenarios" - for _, _, filenames in os.walk(scenarios_folder): - scenarios += filenames - scenarios = [ - scenario.split(".")[0] - for scenario in scenarios - if scenario.endswith(".py") and not scenario.startswith("__") - ] + for path in scenarios_folder.glob("**/*.py"): + if path.is_file() and not path.name.startswith("__"): + scenarios.append(path.stem) return scenarios diff --git a/tests/test_wrappers/__init__.py b/tests/test_wrappers/__init__.py new file mode 100644 index 00000000..7bc178aa --- /dev/null +++ b/tests/test_wrappers/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2024. +# ProrokLab (https://www.proroklab.org/) +# All rights reserved. diff --git a/tests/test_wrappers/test_gym_wrapper.py b/tests/test_wrappers/test_gym_wrapper.py new file mode 100644 index 00000000..e9f746fe --- /dev/null +++ b/tests/test_wrappers/test_gym_wrapper.py @@ -0,0 +1,141 @@ +# Copyright (c) 2024. +# ProrokLab (https://www.proroklab.org/) +# All rights reserved.
+ +import gym +import numpy as np +import pytest +from torch import Tensor + +from vmas import make_env +from vmas.simulator.environment import Environment + + +TEST_SCENARIOS = [ + "balance", + "discovery", + "give_way", + "joint_passage", + "navigation", + "passage", + "transport", + "waterfall", + "simple_world_comm", +] + + +def _check_obs_type(obss, obs_shapes, dict_space, return_numpy): + if dict_space: + assert isinstance( + obss, dict + ), f"Expected dictionary of observations, got {type(obss)}" + for k, obs in obss.items(): + obs_shape = obs_shapes[k] + assert ( + obs.shape == obs_shape + ), f"Expected shape {obs_shape}, got {obs.shape}" + if return_numpy: + assert isinstance( + obs, np.ndarray + ), f"Expected numpy array, got {type(obs)}" + else: + assert isinstance( + obs, Tensor + ), f"Expected torch tensor, got {type(obs)}" + else: + assert isinstance( + obss, list + ), f"Expected list of observations, got {type(obss)}" + for obs, shape in zip(obss, obs_shapes): + assert obs.shape == shape, f"Expected shape {shape}, got {obs.shape}" + if return_numpy: + assert isinstance( + obs, np.ndarray + ), f"Expected numpy array, got {type(obs)}" + else: + assert isinstance( + obs, Tensor + ), f"Expected torch tensor, got {type(obs)}" + + +@pytest.mark.parametrize("scenario", TEST_SCENARIOS) +@pytest.mark.parametrize("return_numpy", [True, False]) +@pytest.mark.parametrize("continuous_actions", [True, False]) +@pytest.mark.parametrize("dict_space", [True, False]) +def test_gym_wrapper( + scenario, return_numpy, continuous_actions, dict_space, max_steps=10 +): + env = make_env( + scenario=scenario, + num_envs=1, + device="cpu", + continuous_actions=continuous_actions, + dict_spaces=dict_space, + wrapper="gym", + wrapper_kwargs={"return_numpy": return_numpy}, + max_steps=max_steps, + ) + + assert ( + len(env.observation_space) == env.unwrapped.n_agents + ), "Expected one observation per agent" + assert ( + len(env.action_space) == env.unwrapped.n_agents + ), "Expected one action per agent" + if dict_space: + assert isinstance( + env.observation_space, gym.spaces.Dict + ), "Expected Dict observation space" + assert isinstance( + env.action_space, gym.spaces.Dict + ), "Expected Dict action space" + obs_shapes = { + k: obs_space.shape for k, obs_space in env.observation_space.spaces.items() + } + else: + assert isinstance( + env.observation_space, gym.spaces.Tuple + ), "Expected Tuple observation space" + assert isinstance( + env.action_space, gym.spaces.Tuple + ), "Expected Tuple action space" + obs_shapes = [obs_space.shape for obs_space in env.observation_space.spaces] + + assert isinstance( + env.unwrapped, Environment + ), "The unwrapped attribute of the Gym wrapper should be a VMAS Environment" + + obss = env.reset() + _check_obs_type(obss, obs_shapes, dict_space, return_numpy=return_numpy) + + for _ in range(max_steps): + actions = [ + env.unwrapped.get_random_action(agent).numpy() + for agent in env.unwrapped.agents + ] + obss, rews, done, info = env.step(actions) + _check_obs_type(obss, obs_shapes, dict_space, return_numpy=return_numpy) + + assert len(rews) == env.unwrapped.n_agents, "Expected one reward per agent" + if not dict_space: + assert isinstance( + rews, list + ), f"Expected list of rewards but got {type(rews)}" + + rew_values = rews + else: + assert isinstance( + rews, dict + ), f"Expected dictionary of rewards but got {type(rews)}" + rew_values = list(rews.values()) + assert all( + isinstance(rew, float) for rew in rew_values + ), f"Expected float rewards but got 
{type(rew_values[0])}" + + assert isinstance(done, bool), f"Expected bool for done but got {type(done)}" + + assert isinstance( + info, dict + ), f"Expected info to be a dictionary but got {type(info)}" + + assert done, "Expected done to be True after 100 steps" diff --git a/tests/test_wrappers/test_gymnasium_vec_wrapper.py b/tests/test_wrappers/test_gymnasium_vec_wrapper.py new file mode 100644 index 00000000..93cb4b1c --- /dev/null +++ b/tests/test_wrappers/test_gymnasium_vec_wrapper.py @@ -0,0 +1,118 @@ +# Copyright (c) 2024. +# ProrokLab (https://www.proroklab.org/) +# All rights reserved. + +import gymnasium as gym +import numpy as np +import pytest +import torch +from vmas import make_env +from vmas.simulator.environment import Environment + +from test_wrappers.test_gym_wrapper import _check_obs_type, TEST_SCENARIOS + + +@pytest.mark.parametrize("scenario", TEST_SCENARIOS) +@pytest.mark.parametrize("return_numpy", [True, False]) +@pytest.mark.parametrize("continuous_actions", [True, False]) +@pytest.mark.parametrize("dict_space", [True, False]) +@pytest.mark.parametrize("num_envs", [1, 10]) +def test_gymnasium_wrapper( + scenario, return_numpy, continuous_actions, dict_space, num_envs, max_steps=10 +): + env = make_env( + scenario=scenario, + num_envs=num_envs, + device="cpu", + continuous_actions=continuous_actions, + dict_spaces=dict_space, + wrapper="gymnasium_vec", + terminated_truncated=True, + wrapper_kwargs={"return_numpy": return_numpy}, + max_steps=max_steps, + ) + + assert isinstance( + env.unwrapped, Environment + ), "The unwrapped attribute of the Gym wrapper should be a VMAS Environment" + + assert ( + len(env.observation_space) == env.unwrapped.n_agents + ), "Expected one observation per agent" + assert ( + len(env.action_space) == env.unwrapped.n_agents + ), "Expected one action per agent" + if dict_space: + assert isinstance( + env.observation_space, gym.spaces.Dict + ), "Expected Dict observation space" + assert isinstance( + env.action_space, gym.spaces.Dict + ), "Expected Dict action space" + obs_shapes = { + k: obs_space.shape for k, obs_space in env.observation_space.spaces.items() + } + else: + assert isinstance( + env.observation_space, gym.spaces.Tuple + ), "Expected Tuple observation space" + assert isinstance( + env.action_space, gym.spaces.Tuple + ), "Expected Tuple action space" + obs_shapes = [obs_space.shape for obs_space in env.observation_space.spaces] + + obss, info = env.reset() + _check_obs_type(obss, obs_shapes, dict_space, return_numpy=return_numpy) + assert isinstance( + info, dict + ), f"Expected info to be a dictionary but got {type(info)}" + + for _ in range(max_steps): + actions = [ + env.unwrapped.get_random_action(agent).numpy() + for agent in env.unwrapped.agents + ] + obss, rews, terminated, truncated, info = env.step(actions) + _check_obs_type(obss, obs_shapes, dict_space, return_numpy=return_numpy) + + assert len(rews) == env.unwrapped.n_agents, "Expected one reward per agent" + if not dict_space: + assert isinstance( + rews, list + ), f"Expected list of rewards but got {type(rews)}" + + rew_values = rews + else: + assert isinstance( + rews, dict + ), f"Expected dictionary of rewards but got {type(rews)}" + rew_values = list(rews.values()) + if return_numpy: + assert all( + isinstance(rew, np.ndarray) for rew in rew_values + ), f"Expected np.array rewards but got {type(rew_values[0])}" + else: + assert all( + isinstance(rew, torch.Tensor) for rew in rew_values + ), f"Expected torch tensor rewards but got {type(rew_values[0])}" + + if 
return_numpy: + assert isinstance( + terminated, np.ndarray + ), f"Expected np.array for terminated but got {type(terminated)}" + assert isinstance( + truncated, np.ndarray + ), f"Expected np.array for truncated but got {type(truncated)}" + else: + assert isinstance( + terminated, torch.Tensor + ), f"Expected torch tensor for terminated but got {type(terminated)}" + assert isinstance( + truncated, torch.Tensor + ), f"Expected torch tensor for truncated but got {type(truncated)}" + + assert isinstance( + info, dict + ), f"Expected info to be a dictionary but got {type(info)}" + + assert all(truncated), "Expected done to be True after 100 steps" diff --git a/tests/test_wrappers/test_gymnasium_wrapper.py b/tests/test_wrappers/test_gymnasium_wrapper.py new file mode 100644 index 00000000..108110a0 --- /dev/null +++ b/tests/test_wrappers/test_gymnasium_wrapper.py @@ -0,0 +1,102 @@ +# Copyright (c) 2024. +# ProrokLab (https://www.proroklab.org/) +# All rights reserved. + +import gymnasium as gym +import pytest +from vmas import make_env +from vmas.simulator.environment import Environment + +from test_wrappers.test_gym_wrapper import _check_obs_type, TEST_SCENARIOS + + +@pytest.mark.parametrize("scenario", TEST_SCENARIOS) +@pytest.mark.parametrize("return_numpy", [True, False]) +@pytest.mark.parametrize("continuous_actions", [True, False]) +@pytest.mark.parametrize("dict_space", [True, False]) +def test_gymnasium_wrapper( + scenario, return_numpy, continuous_actions, dict_space, max_steps=10 +): + env = make_env( + scenario=scenario, + num_envs=1, + device="cpu", + continuous_actions=continuous_actions, + dict_spaces=dict_space, + wrapper="gymnasium", + terminated_truncated=True, + wrapper_kwargs={"return_numpy": return_numpy}, + max_steps=max_steps, + ) + + assert ( + len(env.observation_space) == env.unwrapped.n_agents + ), "Expected one observation per agent" + assert ( + len(env.action_space) == env.unwrapped.n_agents + ), "Expected one action per agent" + if dict_space: + assert isinstance( + env.observation_space, gym.spaces.Dict + ), "Expected Dict observation space" + assert isinstance( + env.action_space, gym.spaces.Dict + ), "Expected Dict action space" + obs_shapes = { + k: obs_space.shape for k, obs_space in env.observation_space.spaces.items() + } + else: + assert isinstance( + env.observation_space, gym.spaces.Tuple + ), "Expected Tuple observation space" + assert isinstance( + env.action_space, gym.spaces.Tuple + ), "Expected Tuple action space" + obs_shapes = [obs_space.shape for obs_space in env.observation_space.spaces] + + assert isinstance( + env.unwrapped, Environment + ), "The unwrapped attribute of the Gym wrapper should be a VMAS Environment" + + obss, info = env.reset() + _check_obs_type(obss, obs_shapes, dict_space, return_numpy=return_numpy) + assert isinstance( + info, dict + ), f"Expected info to be a dictionary but got {type(info)}" + + for _ in range(max_steps): + actions = [ + env.unwrapped.get_random_action(agent).numpy() + for agent in env.unwrapped.agents + ] + obss, rews, terminated, truncated, info = env.step(actions) + _check_obs_type(obss, obs_shapes, dict_space, return_numpy=return_numpy) + + assert len(rews) == env.unwrapped.n_agents, "Expected one reward per agent" + if not dict_space: + assert isinstance( + rews, list + ), f"Expected list of rewards but got {type(rews)}" + + rew_values = rews + else: + assert isinstance( + rews, dict + ), f"Expected dictionary of rewards but got {type(rews)}" + rew_values = list(rews.values()) + assert all( + 
isinstance(rew, float) for rew in rew_values + ), f"Expected float rewards but got {type(rew_values[0])}" + + assert isinstance( + terminated, bool + ), f"Expected bool for terminated but got {type(terminated)}" + assert isinstance( + truncated, bool + ), f"Expected bool for truncated but got {type(truncated)}" + + assert isinstance( + info, dict + ), f"Expected info to be a dictionary but got {type(info)}" + + assert truncated, "Expected done to be True after 100 steps" diff --git a/vmas/interactive_rendering.py b/vmas/interactive_rendering.py index e343a86d..e1467ae1 100644 --- a/vmas/interactive_rendering.py +++ b/vmas/interactive_rendering.py @@ -10,6 +10,8 @@ If you have more than 1 agent, you can control another one with W,A,S,D and switch the agent with these controls using LSHIFT """ + +from argparse import ArgumentParser, BooleanOptionalAction from operator import add from typing import Dict, Union @@ -17,7 +19,6 @@ from torch import Tensor from vmas.make_env import make_env -from vmas.simulator.environment import Wrapper from vmas.simulator.environment.gym import GymWrapper from vmas.simulator.scenario import BaseScenario from vmas.simulator.utils import save_video @@ -49,9 +50,9 @@ def __init__( # hard-coded keyboard events self.current_agent_index = 0 self.current_agent_index2 = 1 - self.n_agents = self.env.unwrapped().n_agents - self.agents = self.env.unwrapped().agents - self.continuous = self.env.unwrapped().continuous_actions + self.n_agents = self.env.unwrapped.n_agents + self.agents = self.env.unwrapped.agents + self.continuous = self.env.unwrapped.continuous_actions self.reset = False self.keys = np.array( [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] @@ -74,10 +75,10 @@ def __init__( self.text_lines = [] self.font_size = 15 self.env.render() - self.text_idx = len(self.env.unwrapped().text_lines) + self.text_idx = len(self.env.unwrapped.text_lines) self._init_text() - self.env.unwrapped().viewer.window.on_key_press = self._key_press - self.env.unwrapped().viewer.window.on_key_release = self._key_release + self.env.unwrapped.viewer.window.on_key_press = self._key_press + self.env.unwrapped.viewer.window.on_key_release = self._key_release self._cycle() @@ -95,7 +96,7 @@ def _cycle(self): save_video( self.render_name, self.frame_list, - fps=1 / self.env.env.world.dt, + fps=1 / self.env.unwrapped.world.dt, ) self.env.reset() self.reset = False @@ -117,6 +118,7 @@ def _cycle(self): ] = self.u2[ : self.agents[self.current_agent_index2].dynamics.needed_action_size ] + obs, rew, done, info = self.env.step(action_list) if self.display_info and self.n_agents > 0: @@ -137,7 +139,7 @@ def _cycle(self): message = f"Done: {done}" self._write_values(4, message) - message = f"Selected: {self.env.unwrapped().agents[self.current_agent_index].name}" + message = f"Selected: {self.env.unwrapped.agents[self.current_agent_index].name}" self._write_values(5, message) frame = self.env.render( @@ -157,7 +159,7 @@ def _init_text(self): text_line = rendering.TextLine( y=(self.text_idx + i) * 40, font_size=self.font_size ) - self.env.unwrapped().viewer.add_geom(text_line) + self.env.unwrapped.viewer.add_geom(text_line) self.text_lines.append(text_line) def _write_values(self, index: int, message: str): @@ -292,8 +294,8 @@ def set_u(self): @staticmethod def format_obs(obs): - if isinstance(obs, Tensor): - return list(np.around(obs.cpu().tolist(), decimals=2)) + if isinstance(obs, (Tensor, np.ndarray)): + return list(np.around(obs.tolist(), decimals=2)) elif isinstance(obs, Dict): return {key: 
InteractiveEnv.format_obs(value) for key, value in obs.items()} else: @@ -343,8 +345,9 @@ def render_interactively( num_envs=1, device="cpu", continuous_actions=True, - wrapper=Wrapper.GYM, + wrapper="gym", seed=0, + wrapper_kwargs={"return_numpy": False}, # Environment specific variables **kwargs, ), @@ -357,6 +360,34 @@ def render_interactively( ) +def parse_args(): + parser = ArgumentParser(description="Interactive rendering") + parser.add_argument( + "--scenario", + type=str, + default="waterfall", + help="Scenario to load. Can be the name of a file in `vmas.scenarios` folder or a :class:`~vmas.simulator.scenario.BaseScenario` class", + ) + parser.add_argument( + "--control_two_agents", + action=BooleanOptionalAction, + default=True, + help="Whether to control two agents or just one", + ) + parser.add_argument( + "--display_info", + action=BooleanOptionalAction, + default=True, + help="Whether to display on the screen the following info from the first controlled agent: name, reward, total reward, done, and observation", + ) + parser.add_argument( + "--save_render", + action="store_true", + help="Whether to save a video of the render up to the first reset", + ) + return parser.parse_args() + + if __name__ == "__main__": # Use this script to interactively play with scenarios # @@ -365,14 +396,11 @@ def render_interactively( # You can control agent actions with the arrow keys and M/N (left/right control the first action, up/down control the second, M/N controls the third) # If you have more than 1 agent, you can control another one with W,A,S,D and Q,E in the same way. # and switch the agent with these controls using LSHIFT - - scenario_name = "waterfall" - - # Scenario specific variables + args = parse_args() render_interactively( - scenario_name, - control_two_agents=True, - save_render=False, - display_info=True, + scenario=args.scenario, + control_two_agents=args.control_two_agents, + save_render=args.save_render, + display_info=args.display_info, ) diff --git a/vmas/make_env.py b/vmas/make_env.py index 44cc81af..8ed6a3d9 100644 --- a/vmas/make_env.py +++ b/vmas/make_env.py @@ -15,15 +15,15 @@ def make_env( num_envs: int, device: DEVICE_TYPING = "cpu", continuous_actions: bool = True, - wrapper: Optional[ - Wrapper - ] = None, # One of: None, vmas.Wrapper.RLLIB, and vmas.Wrapper.GYM + wrapper: Optional[Union[Wrapper, str]] = None, max_steps: Optional[int] = None, seed: Optional[int] = None, dict_spaces: bool = False, multidiscrete_actions: bool = False, clamp_actions: bool = False, grad_enabled: bool = False, + terminated_truncated: bool = False, + wrapper_kwargs: Optional[dict] = None, **kwargs, ): """Create a vmas environment. @@ -38,8 +38,8 @@ def make_env( will be placed on this device. Default is ``"cpu"``, continuous_actions (bool, optional): Whether to use continuous actions. If ``False``, actions will be discrete. The number of actions and their size will depend on the chosen scenario. Default is ``True``, - wrapper (:class:`~vmas.simulator.environment.Wrapper`, optional): Wrapper class to use. For example can be Wrapper.RLLIB. - Default is ``None``, + wrapper (Union[Wrapper, str], optional): Wrapper class to use. For example, it can be + ``"rllib"``, ``"gym"``, ``"gymnasium"``, ``"gymnasium_vec"``. Default is ``None``. max_steps (int, optional): Horizon of the task. Defaults to ``None`` (infinite horizon). Each VMAS scenario can be terminating or not. 
If ``max_steps`` is specified, the scenario is also terminated whenever this horizon is reached, @@ -53,6 +53,9 @@ make_env( an error when ``continuous_actions==True`` and actions are out of bounds, grad_enabled (bool, optional): If ``True`` the simulator will not call ``detach()`` on input actions and gradients can be taken from the simulator output. Default is ``False``. + terminated_truncated (bool, optional): Whether to use separate terminated and truncated flags in the output of the step method (instead of a single done). + Default is ``False``. + wrapper_kwargs (dict, optional): Keyword arguments to pass to the wrapper class. Default is ``None`` (treated as ``{}``). **kwargs (dict, optional): Keyword arguments to pass to the :class:`~vmas.simulator.scenario.BaseScenario` class. Examples: @@ -72,6 +75,7 @@ make_env( if not scenario.endswith(".py"): scenario += ".py" scenario = scenarios.load(scenario).Scenario() + env = Environment( scenario, num_envs=num_envs, @@ -83,7 +87,14 @@ make_env( multidiscrete_actions=multidiscrete_actions, clamp_actions=clamp_actions, grad_enabled=grad_enabled, + terminated_truncated=terminated_truncated, **kwargs, ) - return wrapper.get_env(env) if wrapper is not None else env + if wrapper is not None and isinstance(wrapper, str): + wrapper = Wrapper[wrapper.upper()] + + if wrapper_kwargs is None: + wrapper_kwargs = {} + + return wrapper.get_env(env, **wrapper_kwargs) if wrapper is not None else env diff --git a/vmas/simulator/environment/__init__.py b/vmas/simulator/environment/__init__.py index bd0ff306..eff37a66 100644 --- a/vmas/simulator/environment/__init__.py +++ b/vmas/simulator/environment/__init__.py @@ -9,13 +9,25 @@ class Wrapper(Enum): RLLIB = 0 GYM = 1 + GYMNASIUM = 2 + GYMNASIUM_VEC = 3 - def get_env(self, env: Environment): + def get_env(self, env: Environment, **kwargs): if self is self.RLLIB: from vmas.simulator.environment.rllib import VectorEnvWrapper - return VectorEnvWrapper(env) + return VectorEnvWrapper(env, **kwargs) elif self is self.GYM: from vmas.simulator.environment.gym import GymWrapper - return GymWrapper(env) + return GymWrapper(env, **kwargs) + elif self is self.GYMNASIUM: + from vmas.simulator.environment.gym.gymnasium import GymnasiumWrapper + + return GymnasiumWrapper(env, **kwargs) + elif self is self.GYMNASIUM_VEC: + from vmas.simulator.environment.gym.gymnasium_vec import ( + GymnasiumVectorizedWrapper, + ) + + return GymnasiumVectorizedWrapper(env, **kwargs) diff --git a/vmas/simulator/environment/environment.py b/vmas/simulator/environment/environment.py index 9779a150..3ef6c385 100644 --- a/vmas/simulator/environment/environment.py +++ b/vmas/simulator/environment/environment.py @@ -8,6 +8,7 @@ import numpy as np import torch + from gym import spaces from torch import Tensor @@ -45,6 +46,7 @@ def __init__( multidiscrete_actions: bool = False, clamp_actions: bool = False, grad_enabled: bool = False, + terminated_truncated: bool = False, **kwargs, ): if multidiscrete_actions: @@ -64,6 +66,7 @@ def __init__( self.dict_spaces = dict_spaces self.clamp_action = clamp_actions self.grad_enabled = grad_enabled + self.terminated_truncated = terminated_truncated observations = self.reset(seed=seed) @@ -140,7 +143,7 @@ def get_from_scenario( if dict_agent_names is None: dict_agent_names = self.dict_spaces - obs = rewards = infos = dones = None + obs = rewards = infos = terminated = truncated = dones = None if get_observations: obs = {} if dict_agent_names else [] @@ -173,10 +176,15 @@ def get_from_scenario( else: infos.append(info) - if get_dones:
dones = self.done() + if self.terminated_truncated: + if get_dones: + terminated, truncated = self.done() + result = [obs, rewards, terminated, truncated, infos] + else: + if get_dones: + dones = self.done() + result = [obs, rewards, dones, infos] - result = [obs, rewards, dones, infos] return [data for data in result if data is not None] def seed(self, seed=None): @@ -260,20 +268,25 @@ def step(self, actions: Union[List, Dict]): self.scenario.post_step() self.steps += 1 - obs, rewards, dones, infos = self.get_from_scenario( + + return self.get_from_scenario( get_observations=True, get_infos=True, get_rewards=True, get_dones=True, ) - return obs, rewards, dones, infos - def done(self): - dones = self.scenario.done().clone() + terminated = self.scenario.done().clone() if self.max_steps is not None: - dones += self.steps >= self.max_steps - return dones + truncated = self.steps >= self.max_steps + else: + truncated = torch.zeros_like(terminated) + + if self.terminated_truncated: + return terminated, truncated + else: + return terminated + truncated def get_action_space(self): if not self.dict_spaces: @@ -839,10 +852,13 @@ def plot_function( if plot_range is None: assert self.viewer.bounds is not None, "Set viewer bounds before plotting" x_min, x_max, y_min, y_max = self.viewer.bounds.tolist() - plot_range = [x_min - precision, x_max - precision], [ - y_min - precision, - y_max + precision, - ] + plot_range = ( + [x_min - precision, x_max - precision], + [ + y_min - precision, + y_max + precision, + ], + ) geom = render_function_util( f=f, diff --git a/vmas/simulator/environment/gym.py b/vmas/simulator/environment/gym.py deleted file mode 100644 index 745c7b6e..00000000 --- a/vmas/simulator/environment/gym.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright (c) 2022-2024. -# ProrokLab (https://www.proroklab.org/) -# All rights reserved. 
-from typing import List, Optional - -import gym -import numpy as np -import torch - -from vmas.simulator.environment.environment import Environment -from vmas.simulator.utils import extract_nested_with_index - - -class GymWrapper(gym.Env): - metadata = Environment.metadata - - def __init__( - self, - env: Environment, - ): - assert ( - env.num_envs == 1 - ), f"GymEnv wrapper is not vectorised, got env.num_envs: {env.num_envs}" - - self._env = env - self.observation_space = self._env.observation_space - self.action_space = self._env.action_space - - def unwrapped(self) -> Environment: - return self._env - - @property - def env(self): - return self._env - - def step(self, action): - action = self._action_list_to_tensor(action) - obs, rews, done, info = self._env.step(action) - done = done[0].item() - if self._env.dict_spaces: - for agent in obs.keys(): - obs[agent] = extract_nested_with_index(obs[agent], index=0) - info[agent] = extract_nested_with_index(info[agent], index=0) - rews[agent] = rews[agent][0].item() - else: - for i in range(self._env.n_agents): - obs[i] = extract_nested_with_index(obs[i], index=0) - info[i] = extract_nested_with_index(info[i], index=0) - rews[i] = rews[i][0].item() - return obs, rews, done, info - - def reset( - self, - *, - seed: Optional[int] = None, - return_info: bool = False, - options: Optional[dict] = None, - ): - if seed is not None: - self._env.seed(seed) - obs = self._env.reset_at(index=0) - if self._env.dict_spaces: - for agent in obs.keys(): - obs[agent] = extract_nested_with_index(obs[agent], index=0) - else: - for i in range(self._env.n_agents): - obs[i] = extract_nested_with_index(obs[i], index=0) - return obs - - def render( - self, - mode="human", - agent_index_focus: Optional[int] = None, - visualize_when_rgb: bool = False, - **kwargs, - ) -> Optional[np.ndarray]: - return self._env.render( - mode=mode, - env_index=0, - agent_index_focus=agent_index_focus, - visualize_when_rgb=visualize_when_rgb, - **kwargs, - ) - - def _action_list_to_tensor(self, list_in: List) -> List: - assert ( - len(list_in) == self._env.n_agents - ), f"Expecting actions for {self._env.n_agents} agents, got {len(list_in)} actions" - actions = [] - for agent in self._env.agents: - actions.append( - torch.zeros( - 1, - self._env.get_agent_action_size(agent), - device=self._env.device, - dtype=torch.float32, - ) - ) - - for i in range(self._env.n_agents): - act = torch.tensor(list_in[i], dtype=torch.float32, device=self._env.device) - if len(act.shape) == 0: - assert ( - self._env.get_agent_action_size(self._env.agents[i]) == 1 - ), f"Action of agent {i} is supposed to be an scalar int" - else: - assert len(act.shape) == 1 and act.shape[ - 0 - ] == self._env.get_agent_action_size(self._env.agents[i]), ( - f"Action of agent {i} hase wrong shape: " - f"expected {self._env.get_agent_action_size(self._env.agents[i])}, got {act.shape[0]}" - ) - actions[i][0] = act - return actions diff --git a/vmas/simulator/environment/gym/__init__.py b/vmas/simulator/environment/gym/__init__.py new file mode 100644 index 00000000..9eed47da --- /dev/null +++ b/vmas/simulator/environment/gym/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) 2024. +# ProrokLab (https://www.proroklab.org/) +# All rights reserved. + +from .gym import GymWrapper diff --git a/vmas/simulator/environment/gym/base.py b/vmas/simulator/environment/gym/base.py new file mode 100644 index 00000000..4d2bb78b --- /dev/null +++ b/vmas/simulator/environment/gym/base.py @@ -0,0 +1,131 @@ +# Copyright (c) 2024. 
+# ProrokLab (https://www.proroklab.org/) +# All rights reserved. + +from abc import ABC, abstractmethod +from collections import namedtuple +from typing import List, Optional + +import numpy as np +import torch + +from vmas.simulator.environment import Environment + +from vmas.simulator.utils import extract_nested_with_index, TorchUtils + + +EnvData = namedtuple( + "EnvData", ["obs", "rews", "terminated", "truncated", "done", "info"] +) + + +class BaseGymWrapper(ABC): + def __init__(self, env: Environment, return_numpy: bool, vectorized: bool): + self._env = env + self.return_numpy = return_numpy + self.dict_spaces = env.dict_spaces + self.vectorized = vectorized + + @property + def env(self): + return self._env + + def _maybe_to_numpy(self, tensor): + return TorchUtils.to_numpy(tensor) if self.return_numpy else tensor + + def _convert_output(self, data, item: bool = False): + if not self.vectorized: + data = extract_nested_with_index(data, index=0) + if item: + return data.item() + return self._maybe_to_numpy(data) + + def _compress_infos(self, infos): + if isinstance(infos, dict): + return infos + elif isinstance(infos, list): + return {self._env.agents[i].name: info for i, info in enumerate(infos)} + else: + raise ValueError( + f"Expected list or dictionary for infos but got {type(infos)}" + ) + + def _convert_env_data( + self, obs=None, rews=None, info=None, terminated=None, truncated=None, done=None + ): + if self.dict_spaces: + for agent in obs.keys(): + if obs is not None: + obs[agent] = self._convert_output(obs[agent]) + if info is not None: + info[agent] = self._convert_output(info[agent]) + if rews is not None: + rews[agent] = self._convert_output(rews[agent], item=True) + else: + for i in range(self._env.n_agents): + if obs is not None: + obs[i] = self._convert_output(obs[i]) + if info is not None: + info[i] = self._convert_output(info[i]) + if rews is not None: + rews[i] = self._convert_output(rews[i], item=True) + terminated = ( + self._convert_output(terminated, item=True) + if terminated is not None + else None + ) + truncated = ( + self._convert_output(truncated, item=True) + if truncated is not None + else None + ) + done = self._convert_output(done, item=True) if done is not None else None + info = self._compress_infos(info) if info is not None else None + return EnvData( + obs=obs, + rews=rews, + terminated=terminated, + truncated=truncated, + done=done, + info=info, + ) + + def _action_list_to_tensor(self, list_in: List) -> List: + assert ( + len(list_in) == self._env.n_agents + ), f"Expecting actions for {self._env.n_agents} agents, got {len(list_in)} actions" + + dtype = torch.float32 if self._env.continuous_actions else torch.long + + return [ + torch.tensor(act, device=self._env.device, dtype=dtype).reshape( + self._env.num_envs, self._env.get_agent_action_size(agent) + ) + if not isinstance(act, torch.Tensor) + else act.to(dtype=dtype, device=self._env.device).reshape( + self._env.num_envs, self._env.get_agent_action_size(agent) + ) + for agent, act in zip(self._env.agents, list_in) + ] + + @abstractmethod + def step(self, action): + raise NotImplementedError + + @abstractmethod + def reset( + self, + *, + seed: Optional[int] = None, + options: Optional[dict] = None, + ): + raise NotImplementedError + + @abstractmethod + def render( + self, + agent_index_focus: Optional[int] = None, + visualize_when_rgb: bool = False, + **kwargs, + ) -> Optional[np.ndarray]: + raise NotImplementedError diff --git a/vmas/simulator/environment/gym/gym.py 
b/vmas/simulator/environment/gym/gym.py new file mode 100644 index 00000000..bbcb25b8 --- /dev/null +++ b/vmas/simulator/environment/gym/gym.py @@ -0,0 +1,73 @@ +# Copyright (c) 2022-2024. +# ProrokLab (https://www.proroklab.org/) +# All rights reserved. +from typing import Optional + +import gym +import numpy as np + +from vmas.simulator.environment.environment import Environment +from vmas.simulator.environment.gym.base import BaseGymWrapper + + +class GymWrapper(gym.Env, BaseGymWrapper): + metadata = Environment.metadata + + def __init__( + self, + env: Environment, + return_numpy: bool = True, + ): + super().__init__(env, return_numpy=return_numpy, vectorized=False) + assert ( + env.num_envs == 1 + ), f"GymEnv wrapper is not vectorised, got env.num_envs: {env.num_envs}" + + assert ( + not self._env.terminated_truncated + ), "GymWrapper is not compatible with termination and truncation flags. Please set `terminated_truncated=False` in the VMAS environment." + self.observation_space = self._env.observation_space + self.action_space = self._env.action_space + + @property + def unwrapped(self) -> Environment: + return self._env + + def step(self, action): + action = self._action_list_to_tensor(action) + obs, rews, done, info = self._env.step(action) + env_data = self._convert_env_data( + obs=obs, + rews=rews, + info=info, + done=done, + ) + return env_data.obs, env_data.rews, env_data.done, env_data.info + + def reset( + self, + *, + seed: Optional[int] = None, + return_info: bool = False, + options: Optional[dict] = None, + ): + if seed is not None: + self._env.seed(seed) + obs = self._env.reset_at(index=0) + env_data = self._convert_env_data(obs=obs) + return env_data.obs + + def render( + self, + mode="human", + agent_index_focus: Optional[int] = None, + visualize_when_rgb: bool = False, + **kwargs, + ) -> Optional[np.ndarray]: + return self._env.render( + mode=mode, + env_index=0, + agent_index_focus=agent_index_focus, + visualize_when_rgb=visualize_when_rgb, + **kwargs, + ) diff --git a/vmas/simulator/environment/gym/gymnasium.py b/vmas/simulator/environment/gym/gymnasium.py new file mode 100644 index 00000000..db797857 --- /dev/null +++ b/vmas/simulator/environment/gym/gymnasium.py @@ -0,0 +1,88 @@ +# Copyright (c) 2022-2024. +# ProrokLab (https://www.proroklab.org/) +# All rights reserved. +import importlib +from typing import Optional + +import numpy as np + +from vmas.simulator.environment.environment import Environment +from vmas.simulator.environment.gym.base import BaseGymWrapper + + +if ( + importlib.util.find_spec("gymnasium") is not None + and importlib.util.find_spec("shimmy") is not None +): + import gymnasium as gym + from shimmy.openai_gym_compatibility import _convert_space +else: + raise ImportError( + "Gymnasium or shimmy is not installed. Please install it with `pip install gymnasium shimmy`." + ) + + +class GymnasiumWrapper(gym.Env, BaseGymWrapper): + metadata = Environment.metadata + + def __init__( + self, + env: Environment, + return_numpy: bool = True, + render_mode: str = "human", + ): + super().__init__(env, return_numpy=return_numpy, vectorized=False) + assert ( + env.num_envs == 1 + ), "GymnasiumEnv wrapper only supports singleton VMAS environment! For vectorized environments, use vectorized wrapper with `wrapper=gymnasium_vec`." + + assert ( + self._env.terminated_truncated + ), "GymnasiumWrapper is only compatible with termination and truncation flags. Please set `terminated_truncated=True` in the VMAS environment." 
+ self.observation_space = _convert_space(self._env.observation_space) + self.action_space = _convert_space(self._env.action_space) + self.render_mode = render_mode + + @property + def unwrapped(self) -> Environment: + return self._env + + def step(self, action): + action = self._action_list_to_tensor(action) + obs, rews, terminated, truncated, info = self._env.step(action) + env_data = self._convert_env_data( + obs=obs, rews=rews, info=info, terminated=terminated, truncated=truncated + ) + return ( + env_data.obs, + env_data.rews, + env_data.terminated, + env_data.truncated, + env_data.info, + ) + + def reset( + self, + *, + seed: Optional[int] = None, + options: Optional[dict] = None, + ): + if seed is not None: + self._env.seed(seed) + obs, info = self._env.reset_at(index=0, return_info=True) + env_data = self._convert_env_data(obs=obs, info=info) + return env_data.obs, env_data.info + + def render( + self, + agent_index_focus: Optional[int] = None, + visualize_when_rgb: bool = False, + **kwargs, + ) -> Optional[np.ndarray]: + return self._env.render( + mode=self.render_mode, + env_index=0, + agent_index_focus=agent_index_focus, + visualize_when_rgb=visualize_when_rgb, + **kwargs, + ) diff --git a/vmas/simulator/environment/gym/gymnasium_vec.py b/vmas/simulator/environment/gym/gymnasium_vec.py new file mode 100644 index 00000000..bc9c6423 --- /dev/null +++ b/vmas/simulator/environment/gym/gymnasium_vec.py @@ -0,0 +1,89 @@ +# Copyright (c) 2022-2024. +# ProrokLab (https://www.proroklab.org/) +# All rights reserved. +import importlib +from typing import Optional + +import numpy as np + +from vmas.simulator.environment.environment import Environment +from vmas.simulator.environment.gym.base import BaseGymWrapper + + +if ( + importlib.util.find_spec("gymnasium") is not None + and importlib.util.find_spec("shimmy") is not None +): + import gymnasium as gym + from gymnasium.vector.utils import batch_space + from shimmy.openai_gym_compatibility import _convert_space +else: + raise ImportError( + "Gymnasium or shimmy is not installed. Please install it with `pip install gymnasium shimmy`." + ) + + +class GymnasiumVectorizedWrapper(gym.Env, BaseGymWrapper): + metadata = Environment.metadata + + def __init__( + self, + env: Environment, + return_numpy: bool = True, + render_mode: str = "human", + ): + super().__init__(env, return_numpy=return_numpy, vectorized=True) + self._num_envs = self._env.num_envs + assert ( + self._env.terminated_truncated + ), "GymnasiumWrapper is only compatible with termination and truncation flags. Please set `terminated_truncated=True` in the VMAS environment." 
+ self.single_observation_space = _convert_space(self._env.observation_space) + self.single_action_space = _convert_space(self._env.action_space) + self.observation_space = batch_space( + self.single_observation_space, n=self._num_envs + ) + self.action_space = batch_space(self.single_action_space, n=self._num_envs) + self.render_mode = render_mode + + @property + def unwrapped(self) -> Environment: + return self._env + + def step(self, action): + action = self._action_list_to_tensor(action) + obs, rews, terminated, truncated, info = self._env.step(action) + env_data = self._convert_env_data( + obs=obs, rews=rews, info=info, terminated=terminated, truncated=truncated + ) + return ( + env_data.obs, + env_data.rews, + env_data.terminated, + env_data.truncated, + env_data.info, + ) + + def reset( + self, + *, + seed: Optional[int] = None, + options: Optional[dict] = None, + ): + if seed is not None: + self._env.seed(seed) + obs, info = self._env.reset(return_info=True) + env_data = self._convert_env_data(obs=obs, info=info) + return env_data.obs, env_data.info + + def render( + self, + agent_index_focus: Optional[int] = None, + visualize_when_rgb: bool = False, + **kwargs, + ) -> Optional[np.ndarray]: + return self._env.render( + mode=self.render_mode, + agent_index_focus=agent_index_focus, + visualize_when_rgb=visualize_when_rgb, + **kwargs, + ) diff --git a/vmas/simulator/environment/rllib.py b/vmas/simulator/environment/rllib.py index 861a46d4..07381312 100644 --- a/vmas/simulator/environment/rllib.py +++ b/vmas/simulator/environment/rllib.py @@ -1,19 +1,28 @@ # Copyright (c) 2022-2024. # ProrokLab (https://www.proroklab.org/) # All rights reserved. +import importlib from typing import Dict, List, Optional, Tuple import numpy as np import torch from numpy import ndarray -from ray import rllib -from ray.rllib.utils.typing import EnvActionType, EnvInfoDict, EnvObsType + from torch import Tensor from vmas.simulator.environment.environment import Environment from vmas.simulator.utils import INFO_TYPE, OBS_TYPE, REWARD_TYPE, TorchUtils +if importlib.util.find_spec("ray") is not None: + from ray import rllib + from ray.rllib.utils.typing import EnvActionType, EnvInfoDict, EnvObsType +else: + raise ImportError( + "RLLib is not installed. Please install it with `pip install ray[rllib]<=2.2`." + ) + + class VectorEnvWrapper(rllib.VectorEnv): """ Vector environment wrapper for rllib @@ -23,6 +32,10 @@ def __init__( self, env: Environment, ): + assert ( + not env.terminated_truncated + ), "Rllib wrapper is not compatible with termination and truncation flags. Please set `terminated_truncated=False` in the VMAS environment." + self._env = env super().__init__( observation_space=self._env.observation_space,
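
For reference, a minimal usage sketch of the Gymnasium wrapper enabled by this patch. It mirrors the calls exercised in `tests/test_wrappers/test_gymnasium_wrapper.py`; the `waterfall` scenario, seed, and step count are arbitrary example choices, not part of the patch:

```python
# Minimal sketch (assumptions noted above): single-environment Gymnasium wrapper.
# Requires the optional extra added by this patch: pip install vmas[gymnasium]
from vmas import make_env

env = make_env(
    scenario="waterfall",  # any scenario from the `vmas/scenarios` folder
    num_envs=1,  # the non-vectorized Gymnasium wrapper requires num_envs=1
    device="cpu",
    continuous_actions=True,
    wrapper="gymnasium",  # or "gymnasium_vec" for the vectorized wrapper
    terminated_truncated=True,  # the Gymnasium wrappers require separate flags
    wrapper_kwargs={"return_numpy": True},
    max_steps=100,
)

obs, info = env.reset(seed=0)
for _ in range(100):
    # Random per-agent actions, as done in the new wrapper tests
    actions = [
        env.unwrapped.get_random_action(agent).numpy()
        for agent in env.unwrapped.agents
    ]
    obs, rews, terminated, truncated, info = env.step(actions)
    if terminated or truncated:
        obs, info = env.reset()
```

The vectorized variant works the same way with `wrapper="gymnasium_vec"` and `num_envs > 1`, returning batched per-agent observations, rewards, and termination flags.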