From 656b2db9c1a50dd68c837a942a37f306468a0028 Mon Sep 17 00:00:00 2001 From: robfiras Date: Wed, 25 Oct 2023 15:58:24 +0200 Subject: [PATCH] Clean-up headless mode Mujoco Viewer. - fixed some bugs. - cleaned-up code. --- mushroom_rl/environments/mujoco.py | 6 +- mushroom_rl/utils/mujoco/__init__.py | 2 +- mushroom_rl/utils/mujoco/viewer.py | 96 +++++++++++++++++----------- 3 files changed, 61 insertions(+), 43 deletions(-) diff --git a/mushroom_rl/environments/mujoco.py b/mushroom_rl/environments/mujoco.py index 39527412..cbc59802 100644 --- a/mushroom_rl/environments/mujoco.py +++ b/mushroom_rl/environments/mujoco.py @@ -56,7 +56,7 @@ def __init__(self, xml_file, actuation_spec, observation_spec, gamma, horizon, t The list has to define a maximum velocity for every occurrence of JOINT_VEL in the observation_spec. The velocity will not be limited in mujoco **viewer_params: other parameters to be passed to the viewer. - See MujocoGlfwViewer documentation for the available options. + See MujocoViewer documentation for the available options. """ # Create the simulation @@ -172,7 +172,7 @@ def step(self, action): def render(self, record=False): if self._viewer is None: - self._viewer = MujocoGlfwViewer(self._model, self.dt, record=record, **self._viewer_params) + self._viewer = MujocoViewer(self._model, self.dt, record=record, **self._viewer_params) return self._viewer.render(self._data, record) @@ -600,7 +600,7 @@ def __init__(self, xml_files, actuation_spec, observation_spec, gamma, horizon, random_env_reset (bool): If True, a random environment/model is chosen after each episode. If False, it is sequentially iterated through the environment/model list. **viewer_params: other parameters to be passed to the viewer. - See MujocoGlfwViewer documentation for the available options. + See MujocoViewer documentation for the available options. """ # Create the simulation diff --git a/mushroom_rl/utils/mujoco/__init__.py b/mushroom_rl/utils/mujoco/__init__.py index 4fca999d..0b5917b5 100644 --- a/mushroom_rl/utils/mujoco/__init__.py +++ b/mushroom_rl/utils/mujoco/__init__.py @@ -1,3 +1,3 @@ -from .viewer import MujocoGlfwViewer +from .viewer import MujocoViewer from .observation_helper import ObservationHelper, ObservationType from .kinematics import forward_kinematics diff --git a/mushroom_rl/utils/mujoco/viewer.py b/mushroom_rl/utils/mujoco/viewer.py index 8bab2c30..fcdda9de 100644 --- a/mushroom_rl/utils/mujoco/viewer.py +++ b/mushroom_rl/utils/mujoco/viewer.py @@ -1,20 +1,48 @@ +import os import glfw import mujoco import time +import collections from itertools import cycle import numpy as np -import os -class MujocoGlfwViewer: +def _import_egl(width, height): + from mujoco.egl import GLContext + + return GLContext(width, height) + + +def _import_glfw(width, height): + from mujoco.glfw import GLContext + + return GLContext(width, height) + + +def _import_osmesa(width, height): + from mujoco.osmesa import GLContext + + return GLContext(width, height) + + +_ALL_RENDERERS = collections.OrderedDict( + [ + ("glfw", _import_glfw), + ("egl", _import_egl), + ("osmesa", _import_osmesa), + ] +) + + +class MujocoViewer: """ - Class that creates a Glfw viewer for mujoco environments. + Class that creates a viewer for mujoco environments. """ def __init__(self, model, dt, width=1920, height=1080, start_paused=False, custom_render_callback=None, record=False, camera_params=None, - default_camera_mode="static", hide_menu_on_startup=False, + default_camera_mode="static", hide_menu_on_startup=None, geom_group_visualization_on_startup=None, headless=False): """ Constructor. @@ -34,9 +62,15 @@ def __init__(self, model, dt, width=1920, height=1080, start_paused=False, hide_menu_on_startup (bool): If True, the menu is hidden on startup. geom_group_visualization_on_startup (int/list): int or list defining which geom group_ids should be visualized on startup. If None, all are visualized. + headless (bool): If True, render will be done in headless mode. """ + if hide_menu_on_startup is None and headless: + hide_menu_on_startup = True + elif hide_menu_on_startup is None and not headless: + hide_menu_on_startup = False + self.button_left = False self.button_right = False self.button_middle = False @@ -52,22 +86,26 @@ def __init__(self, model, dt, width=1920, height=1080, start_paused=False, self._font_scale = 100 if headless: - self._opengl_context = self.get_opengl_backend(width, height) + # use the OpenGL render that is available on the machine + self._opengl_context = self.setup_opengl_backend_headless(width, height) self._opengl_context.make_current() self._width, self._height = self.update_headless_size(width, height) else: + # use glfw self._width, self._height = width, height glfw.init() glfw.window_hint(glfw.COCOA_RETINA_FRAMEBUFFER, 0) - self._window = glfw.create_window(width=self._width, height=self._height, title="MuJoCo", monitor=None, share=None) + self._window = glfw.create_window(width=self._width, height=self._height, + title="MuJoCo", monitor=None, share=None) glfw.make_context_current(self._window) glfw.set_mouse_button_callback(self._window, self.mouse_button) glfw.set_cursor_pos_callback(self._window, self.mouse_move) glfw.set_key_callback(self._window, self.keyboard) glfw.set_scroll_callback(self._window, self.scroll) - self._set_mujoco_buffers() + + self._set_mujoco_buffers() - if record: + if record and not headless: # dont allow to change the window size to have equal frame size during recording glfw.window_hint(glfw.RESIZABLE, False) @@ -124,6 +162,7 @@ def load_new_model(self, model): self._scene = mujoco.MjvScene(model, 1000) self._context = mujoco.MjrContext(model, mujoco.mjtFontScale(self._font_scale)) + def mouse_button(self, window, button, act, mods): """ Mouse button callback for glfw. @@ -285,8 +324,7 @@ def update_headless_size(self, width, height): if width != _context.offWidth or height != _context.offHeight: self._model.vis.global_.offwidth = width self._model.vis.global_.offheight = height - - self._set_mujoco_buffers() + return width, height def render(self, data, record): @@ -337,16 +375,12 @@ def render_inner_loop(self): if not self._headless: glfw.swap_buffers(self._window) glfw.poll_events() - - self.frames += 1 - - self._overlay.clear() - - if not self._headless: if glfw.window_should_close(self._window): self.stop() exit(0) + self.frames += 1 + self._overlay.clear() self._time_per_render = 0.9 * self._time_per_render + 0.1 * (time.time() - render_start) if self._paused: @@ -379,7 +413,7 @@ def read_pixels(self, depth=False): """ if self._headless: - shape = [self._width, self._height] + shape = (self._width, self._height) else: shape = glfw.get_framebuffer_size(self._window) @@ -398,8 +432,8 @@ def stop(self): Destroys the glfw image. """ - - glfw.destroy_window(self._window) + if not self._headless: + glfw.destroy_window(self._window) def _create_overlay(self): """ @@ -565,25 +599,7 @@ def get_default_camera_params(): top_static=dict(distance=5.0, elevation=-90.0, azimuth=90.0, lookat=np.array([0.0, 0.0, 0.0]))) - def get_opengl_backend(self, width, height): - # Reference: https://github.com/openai/gym/blob/master/gym/envs/mujoco/mujoco_rendering.py - def _import_egl(width, height): - from mujoco.egl import GLContext - return GLContext(width, height) - - def _import_glfw(width, height): - from mujoco.glfw import GLContext - return GLContext(width, height) - - def _import_osmesa(width, height): - from mujoco.osmesa import GLContext - return GLContext(width, height) - - _ALL_RENDERERS = { - "glfw" : _import_glfw, - "osmesa" : _import_osmesa, - "egl" : _import_egl - } + def setup_opengl_backend_headless(self, width, height): backend = os.environ.get("MUJOCO_GL") if backend is not None: @@ -597,6 +613,7 @@ def _import_osmesa(width, height): ) else: + # iterate through all OpenGL backends to see which one is available for name, _ in _ALL_RENDERERS.items(): try: opengl_context = _ALL_RENDERERS[name](width, height) @@ -609,4 +626,5 @@ def _import_osmesa(width, height): "No OpenGL backend could be imported. Attempting to create a " "rendering context will result in a RuntimeError." ) - return opengl_context \ No newline at end of file + + return opengl_context