-
Notifications
You must be signed in to change notification settings - Fork 72
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Terminated/truncated support and Gymnasium wrapper (#143)
* add gymnasium integration but maintain openai gym support * update documentation (default being gym) * by default preserve original interface of all functions * Update gymnasium/ gym integration - base VMAS environment uses OpenAI gym spaces - base VMAS environment has new flag `terminated_truncated` (default: False) that determines whether `done()` and `step()` return the default `done` value or separate values for `terminated` and `truncated` - update `gymnasium` wrapper to convert gym spaces of base environment to gymnasium spaces - add `gymnasium_vec` wrapper that can wrap vectorized VMAS environment as gymnasium environment - add new installation options of VMAS for optional dependencies (used for features like rllib, torchrl, gymnasium, rendering, testing) - add `return_numpy` flag in gymnasium wrappers (default: True) to determine whether to convert torch tensors to numpy --> passed through by `make_env` function - add `render_mode` flag in gymnasium wrappers (default: "human") to determine mode to render --> passed through by `make_env` function * use gymnasium and shimmy tools to convert spaces + use vmas to_numpy conversion * update VMAS wrappers - add base VMAS wrapper class for type conversion between tensors and np for singleton and vectorized envs - change default of gym wrapper to return np data - update interactive rendering to be compatible with non gym wrapper class (to preserve tensor types) - add error messages for gymnasium and rllib wrappers without installing first * update vmas wrapper base class, move wrappers and add wrapper tests * incorporate feedback - update github dependency installation - unify get scenario test function and limit wrapper tests to fewer scenarios - allow import of all gym wrappers from `vmas.simulator.environment.gym` - consider env continuous_actions for action type conversion in wrappers - compress info to single nested info if needed rather than combining keys * remove import error * Revert "remove import error" This reverts commit 2d0ad62. * import optional deps only when needed * relative imports * installation docs * interactive render * docs * more docs * various * small nits * gym wrapper tests for dict spaces check obs shapes matching obs key --------- Co-authored-by: Matteo Bettini <[email protected]>
- Loading branch information
1 parent
132d97b
commit 659c390
Showing
21 changed files
with
940 additions
and
190 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,5 +30,11 @@ def get_version(): | |
author_email="[email protected]", | ||
packages=find_packages(), | ||
install_requires=["numpy", "torch", "pyglet<=1.5.27", "gym", "six"], | ||
extras_require={ | ||
"gymnasium": ["gymnasium", "shimmy"], | ||
"rllib": ["ray[rllib]<=2.2"], | ||
"render": ["opencv-python", "moviepy", "matplotlib", "opencv-python"], | ||
"test": ["pytest", "pytest-instafail", "pyyaml", "tqdm"], | ||
}, | ||
include_package_data=True, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Copyright (c) 2024. | ||
# ProrokLab (https://www.proroklab.org/) | ||
# All rights reserved. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
# Copyright (c) 2024. | ||
# ProrokLab (https://www.proroklab.org/) | ||
# All rights reserved. | ||
|
||
import gym | ||
import numpy as np | ||
import pytest | ||
from torch import Tensor | ||
|
||
from vmas import make_env | ||
from vmas.simulator.environment import Environment | ||
|
||
|
||
TEST_SCENARIOS = [ | ||
"balance", | ||
"discovery", | ||
"give_way", | ||
"joint_passage", | ||
"navigation", | ||
"passage", | ||
"transport", | ||
"waterfall", | ||
"simple_world_comm", | ||
] | ||
|
||
|
||
def _check_obs_type(obss, obs_shapes, dict_space, return_numpy): | ||
if dict_space: | ||
assert isinstance( | ||
obss, dict | ||
), f"Expected dictionary of observations, got {type(obss)}" | ||
for k, obs in obss.items(): | ||
obs_shape = obs_shapes[k] | ||
assert ( | ||
obs.shape == obs_shape | ||
), f"Expected shape {obs_shape}, got {obs.shape}" | ||
if return_numpy: | ||
assert isinstance( | ||
obs, np.ndarray | ||
), f"Expected numpy array, got {type(obs)}" | ||
else: | ||
assert isinstance( | ||
obs, Tensor | ||
), f"Expected torch tensor, got {type(obs)}" | ||
else: | ||
assert isinstance( | ||
obss, list | ||
), f"Expected list of observations, got {type(obss)}" | ||
for obs, shape in zip(obss, obs_shapes): | ||
assert obs.shape == shape, f"Expected shape {shape}, got {obs.shape}" | ||
if return_numpy: | ||
assert isinstance( | ||
obs, np.ndarray | ||
), f"Expected numpy array, got {type(obs)}" | ||
else: | ||
assert isinstance( | ||
obs, Tensor | ||
), f"Expected torch tensor, got {type(obs)}" | ||
|
||
|
||
@pytest.mark.parametrize("scenario", TEST_SCENARIOS) | ||
@pytest.mark.parametrize("return_numpy", [True, False]) | ||
@pytest.mark.parametrize("continuous_actions", [True, False]) | ||
@pytest.mark.parametrize("dict_space", [True, False]) | ||
def test_gym_wrapper( | ||
scenario, return_numpy, continuous_actions, dict_space, max_steps=10 | ||
): | ||
env = make_env( | ||
scenario=scenario, | ||
num_envs=1, | ||
device="cpu", | ||
continuous_actions=continuous_actions, | ||
dict_spaces=dict_space, | ||
wrapper="gym", | ||
wrapper_kwargs={"return_numpy": return_numpy}, | ||
max_steps=max_steps, | ||
) | ||
|
||
assert ( | ||
len(env.observation_space) == env.unwrapped.n_agents | ||
), "Expected one observation per agent" | ||
assert ( | ||
len(env.action_space) == env.unwrapped.n_agents | ||
), "Expected one action per agent" | ||
if dict_space: | ||
assert isinstance( | ||
env.observation_space, gym.spaces.Dict | ||
), "Expected Dict observation space" | ||
assert isinstance( | ||
env.action_space, gym.spaces.Dict | ||
), "Expected Dict action space" | ||
obs_shapes = { | ||
k: obs_space.shape for k, obs_space in env.observation_space.spaces.items() | ||
} | ||
else: | ||
assert isinstance( | ||
env.observation_space, gym.spaces.Tuple | ||
), "Expected Tuple observation space" | ||
assert isinstance( | ||
env.action_space, gym.spaces.Tuple | ||
), "Expected Tuple action space" | ||
obs_shapes = [obs_space.shape for obs_space in env.observation_space.spaces] | ||
|
||
assert isinstance( | ||
env.unwrapped, Environment | ||
), "The unwrapped attribute of the Gym wrapper should be a VMAS Environment" | ||
|
||
obss = env.reset() | ||
_check_obs_type(obss, obs_shapes, dict_space, return_numpy=return_numpy) | ||
|
||
for _ in range(max_steps): | ||
actions = [ | ||
env.unwrapped.get_random_action(agent).numpy() | ||
for agent in env.unwrapped.agents | ||
] | ||
obss, rews, done, info = env.step(actions) | ||
_check_obs_type(obss, obs_shapes, dict_space, return_numpy=return_numpy) | ||
|
||
assert len(rews) == env.unwrapped.n_agents, "Expected one reward per agent" | ||
if not dict_space: | ||
assert isinstance( | ||
rews, list | ||
), f"Expected list of rewards but got {type(rews)}" | ||
|
||
rew_values = rews | ||
else: | ||
assert isinstance( | ||
rews, dict | ||
), f"Expected dictionary of rewards but got {type(rews)}" | ||
rew_values = list(rews.values()) | ||
assert all( | ||
isinstance(rew, float) for rew in rew_values | ||
), f"Expected float rewards but got {type(rew_values[0])}" | ||
|
||
assert isinstance(done, bool), f"Expected bool for done but got {type(done)}" | ||
|
||
assert isinstance( | ||
info, dict | ||
), f"Expected info to be a dictionary but got {type(info)}" | ||
|
||
assert done, "Expected done to be True after 100 steps" |
Oops, something went wrong.