Skip to content

Commit

Permalink
Doc update: custom envs, IsaacLab, Brax and dm_control (#2072)
Browse files Browse the repository at this point in the history
* Add note about start!=0 for Discrete spaces

* Update doc for IsaacLab and dm_control

* Fix test due to rounding error
  • Loading branch information
araffin authored Jan 26, 2025
1 parent d055a2e commit f8ea299
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 25 deletions.
18 changes: 18 additions & 0 deletions docs/guide/custom_env.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,24 @@ That is to say, your environment must implement the following methods (and inher
Under the hood, when a channel-last image is passed, SB3 uses a ``VecTransposeImage`` wrapper to re-order the channels.


.. note::

SB3 doesn't support ``Discrete`` and ``MultiDiscrete`` spaces with ``start!=0``. However, you can update your environment or use a wrapper to make your env compatible with SB3:

.. code-block:: python
import gymnasium as gym
class ShiftWrapper(gym.Wrapper):
"""Allow to use Discrete() action spaces with start!=0"""
def __init__(self, env: gym.Env) -> None:
super().__init__(env)
assert isinstance(env.action_space, gym.spaces.Discrete)
self.action_space = gym.spaces.Discrete(env.action_space.n, start=0)
def step(self, action: int):
return self.env.step(action + self.env.action_space.start)
.. code-block:: python
Expand Down
45 changes: 22 additions & 23 deletions docs/guide/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -735,41 +735,40 @@ A2C policy gradient updates on the model.
print(f"Best fitness: {top_candidates[0][1]:.2f}")
SB3 and ProcgenEnv
------------------
SB3 with Isaac Lab, Brax, Procgen, EnvPool
------------------------------------------

Some environments like `Procgen <https://github.com/openai/procgen>`_ already produce a vectorized
environment (see discussion in `issue #314 <https://github.com/DLR-RM/stable-baselines3/issues/314>`_). In order to use it with SB3, you must wrap it in a ``VecMonitor`` wrapper which will also allow
to keep track of the agent progress.
Some massively parallel simulations such as `EnvPool <https://github.com/sail-sg/envpool>`_, `Isaac Lab <https://github.com/isaac-sim/IsaacLab>`_, `Brax <https://github.com/google/brax>`_ or `ProcGen <https://github.com/Farama-Foundation/Procgen2>`_ already produce a vectorized environment to speed up data collection (see discussion in `issue #314 <https://github.com/DLR-RM/stable-baselines3/issues/314>`_).

.. code-block:: python
To use SB3 with these tools, you need to wrap the env with tool-specific ``VecEnvWrapper`` that pre-processes the data for SB3,
you can find links to some of these wrappers in `issue #772 <https://github.com/DLR-RM/stable-baselines3/issues/772#issuecomment-1048657002>`_.

from procgen import ProcgenEnv
- Isaac Lab wrapper: `link <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/utils/wrappers/sb3.py>`__
- Brax: `link <https://gist.github.com/araffin/a7a576ec1453e74d9bb93120918ef7e7>`__
- EnvPool: `link <https://github.com/sail-sg/envpool/blob/main/examples/sb3_examples/ppo.py>`__

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecExtractDictObs, VecMonitor

# ProcgenEnv is already vectorized
venv = ProcgenEnv(num_envs=2, env_name="starpilot")
SB3 with DeepMind Control (dm_control)
--------------------------------------

# To use only part of the observation:
# venv = VecExtractDictObs(venv, "rgb")
If you want to use SB3 with `dm_control <https://github.com/google-deepmind/dm_control>`_, you need to use two wrappers (one from `shimmy <https://github.com/Farama-Foundation/Shimmy>`_, one pre-built one) to convert it to a Gymnasium compatible environment:

# Wrap with a VecMonitor to collect stats and avoid errors
venv = VecMonitor(venv=venv)
.. code-block:: python
model = PPO("MultiInputPolicy", venv, verbose=1)
model.learn(10_000)
import shimmy
import stable_baselines3 as sb3
from dm_control import suite
from gymnasium.wrappers import FlattenObservation
# Available envs:
# suite._DOMAINS and suite.dog.SUITE
SB3 with EnvPool or Isaac Gym
-----------------------------
env = suite.load(domain_name="dog", task_name="run")
gym_env = FlattenObservation(shimmy.DmControlCompatibilityV0(env))
Just like Procgen (see above), `EnvPool <https://github.com/sail-sg/envpool>`_ and `Isaac Gym <https://github.com/NVIDIA-Omniverse/IsaacGymEnvs>`_ accelerate the environment by
already providing a vectorized implementation.
model = sb3.PPO("MlpPolicy", gym_env, verbose=1)
model.learn(10_000, progress_bar=True)
To use SB3 with those tools, you must wrap the env with tool's specific ``VecEnvWrapper`` that will pre-process the data for SB3,
you can find links to those wrappers in `issue #772 <https://github.com/DLR-RM/stable-baselines3/issues/772#issuecomment-1048657002>`_.
Record a Video
Expand Down
1 change: 1 addition & 0 deletions docs/guide/sbx.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Implemented algorithms:
- Twin Delayed DDPG (TD3)
- Deep Deterministic Policy Gradient (DDPG)
- Batch Normalization in Deep Reinforcement Learning (CrossQ)
- Simplicity Bias for Scaling Up Parameters in Deep Reinforcement Learning (SimBa)


As SBX follows SB3 API, it is also compatible with the `RL Zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_.
Expand Down
3 changes: 3 additions & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ Documentation:
- Add FootstepNet Envs to the project page (@cgaspard3333)
- Added FRASA to the project page (@MarcDcls)
- Fixed atari example (@chrisgao99)
- Add a note about ``Discrete`` action spaces with ``start!=0``
- Update doc for massively parallel simulators (Isaac Lab, Brax, ...)
- Add dm_control example

Release 2.4.1 (2024-12-20)
--------------------------
Expand Down
3 changes: 2 additions & 1 deletion stable_baselines3/common/env_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ def _check_non_zero_start(space: spaces.Space, space_type: str = "observation",
warnings.warn(
f"{type(space).__name__} {space_type} space {maybe_key} with a non-zero start (start={space.start}) "
"is not supported by Stable-Baselines3. "
f"You can use a wrapper or update your {space_type} space."
"You can use a wrapper (see https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html) "
f"or update your {space_type} space."
)


Expand Down
2 changes: 1 addition & 1 deletion tests/test_vec_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def test_get_original():
assert not np.array_equal(orig_obs, obs)
assert not np.array_equal(orig_rewards, rewards)
np.testing.assert_allclose(venv.normalize_obs(orig_obs), obs)
np.testing.assert_allclose(venv.normalize_reward(orig_rewards), rewards)
np.testing.assert_allclose(venv.normalize_reward(orig_rewards), rewards, atol=1e-6)


def test_get_original_dict():
Expand Down

0 comments on commit f8ea299

Please sign in to comment.