Skip to content

Commit

Permalink
Clean-up headless mode Mujoco Viewer.
Browse files Browse the repository at this point in the history
- fixed some bugs.
- cleaned-up code.
  • Loading branch information
robfiras committed Oct 25, 2023
1 parent 544b6a3 commit 656b2db
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 43 deletions.
6 changes: 3 additions & 3 deletions mushroom_rl/environments/mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(self, xml_file, actuation_spec, observation_spec, gamma, horizon, t
The list has to define a maximum velocity for every occurrence of JOINT_VEL in the observation_spec. The
velocity will not be limited in mujoco
**viewer_params: other parameters to be passed to the viewer.
See MujocoGlfwViewer documentation for the available options.
See MujocoViewer documentation for the available options.
"""
# Create the simulation
Expand Down Expand Up @@ -172,7 +172,7 @@ def step(self, action):

def render(self, record=False):
if self._viewer is None:
self._viewer = MujocoGlfwViewer(self._model, self.dt, record=record, **self._viewer_params)
self._viewer = MujocoViewer(self._model, self.dt, record=record, **self._viewer_params)

return self._viewer.render(self._data, record)

Expand Down Expand Up @@ -600,7 +600,7 @@ def __init__(self, xml_files, actuation_spec, observation_spec, gamma, horizon,
random_env_reset (bool): If True, a random environment/model is chosen after each episode. If False, it is
sequentially iterated through the environment/model list.
**viewer_params: other parameters to be passed to the viewer.
See MujocoGlfwViewer documentation for the available options.
See MujocoViewer documentation for the available options.
"""
# Create the simulation
Expand Down
2 changes: 1 addition & 1 deletion mushroom_rl/utils/mujoco/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .viewer import MujocoGlfwViewer
from .viewer import MujocoViewer
from .observation_helper import ObservationHelper, ObservationType
from .kinematics import forward_kinematics
96 changes: 57 additions & 39 deletions mushroom_rl/utils/mujoco/viewer.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,48 @@
import os
import glfw
import mujoco
import time
import collections
from itertools import cycle
import numpy as np
import os


class MujocoGlfwViewer:
def _import_egl(width, height):
from mujoco.egl import GLContext

return GLContext(width, height)


def _import_glfw(width, height):
from mujoco.glfw import GLContext

return GLContext(width, height)


def _import_osmesa(width, height):
from mujoco.osmesa import GLContext

return GLContext(width, height)


_ALL_RENDERERS = collections.OrderedDict(
[
("glfw", _import_glfw),
("egl", _import_egl),
("osmesa", _import_osmesa),
]
)


class MujocoViewer:
"""
Class that creates a Glfw viewer for mujoco environments.
Class that creates a viewer for mujoco environments.
"""

def __init__(self, model, dt, width=1920, height=1080, start_paused=False,
custom_render_callback=None, record=False, camera_params=None,
default_camera_mode="static", hide_menu_on_startup=False,
default_camera_mode="static", hide_menu_on_startup=None,
geom_group_visualization_on_startup=None, headless=False):
"""
Constructor.
Expand All @@ -34,9 +62,15 @@ def __init__(self, model, dt, width=1920, height=1080, start_paused=False,
hide_menu_on_startup (bool): If True, the menu is hidden on startup.
geom_group_visualization_on_startup (int/list): int or list defining which geom group_ids should be
visualized on startup. If None, all are visualized.
headless (bool): If True, render will be done in headless mode.
"""

if hide_menu_on_startup is None and headless:
hide_menu_on_startup = True
elif hide_menu_on_startup is None and not headless:
hide_menu_on_startup = False

self.button_left = False
self.button_right = False
self.button_middle = False
Expand All @@ -52,22 +86,26 @@ def __init__(self, model, dt, width=1920, height=1080, start_paused=False,
self._font_scale = 100

if headless:
self._opengl_context = self.get_opengl_backend(width, height)
# use the OpenGL render that is available on the machine
self._opengl_context = self.setup_opengl_backend_headless(width, height)
self._opengl_context.make_current()
self._width, self._height = self.update_headless_size(width, height)
else:
# use glfw
self._width, self._height = width, height
glfw.init()
glfw.window_hint(glfw.COCOA_RETINA_FRAMEBUFFER, 0)
self._window = glfw.create_window(width=self._width, height=self._height, title="MuJoCo", monitor=None, share=None)
self._window = glfw.create_window(width=self._width, height=self._height,
title="MuJoCo", monitor=None, share=None)
glfw.make_context_current(self._window)
glfw.set_mouse_button_callback(self._window, self.mouse_button)
glfw.set_cursor_pos_callback(self._window, self.mouse_move)
glfw.set_key_callback(self._window, self.keyboard)
glfw.set_scroll_callback(self._window, self.scroll)
self._set_mujoco_buffers()

self._set_mujoco_buffers()

if record:
if record and not headless:
# dont allow to change the window size to have equal frame size during recording
glfw.window_hint(glfw.RESIZABLE, False)

Expand Down Expand Up @@ -124,6 +162,7 @@ def load_new_model(self, model):
self._scene = mujoco.MjvScene(model, 1000)
self._context = mujoco.MjrContext(model, mujoco.mjtFontScale(self._font_scale))


def mouse_button(self, window, button, act, mods):
"""
Mouse button callback for glfw.
Expand Down Expand Up @@ -285,8 +324,7 @@ def update_headless_size(self, width, height):
if width != _context.offWidth or height != _context.offHeight:
self._model.vis.global_.offwidth = width
self._model.vis.global_.offheight = height

self._set_mujoco_buffers()

return width, height

def render(self, data, record):
Expand Down Expand Up @@ -337,16 +375,12 @@ def render_inner_loop(self):
if not self._headless:
glfw.swap_buffers(self._window)
glfw.poll_events()

self.frames += 1

self._overlay.clear()

if not self._headless:
if glfw.window_should_close(self._window):
self.stop()
exit(0)

self.frames += 1
self._overlay.clear()
self._time_per_render = 0.9 * self._time_per_render + 0.1 * (time.time() - render_start)

if self._paused:
Expand Down Expand Up @@ -379,7 +413,7 @@ def read_pixels(self, depth=False):
"""

if self._headless:
shape = [self._width, self._height]
shape = (self._width, self._height)
else:
shape = glfw.get_framebuffer_size(self._window)

Expand All @@ -398,8 +432,8 @@ def stop(self):
Destroys the glfw image.
"""

glfw.destroy_window(self._window)
if not self._headless:
glfw.destroy_window(self._window)

def _create_overlay(self):
"""
Expand Down Expand Up @@ -565,25 +599,7 @@ def get_default_camera_params():
top_static=dict(distance=5.0, elevation=-90.0, azimuth=90.0, lookat=np.array([0.0, 0.0, 0.0])))


def get_opengl_backend(self, width, height):
# Reference: https://github.com/openai/gym/blob/master/gym/envs/mujoco/mujoco_rendering.py
def _import_egl(width, height):
from mujoco.egl import GLContext
return GLContext(width, height)

def _import_glfw(width, height):
from mujoco.glfw import GLContext
return GLContext(width, height)

def _import_osmesa(width, height):
from mujoco.osmesa import GLContext
return GLContext(width, height)

_ALL_RENDERERS = {
"glfw" : _import_glfw,
"osmesa" : _import_osmesa,
"egl" : _import_egl
}
def setup_opengl_backend_headless(self, width, height):

backend = os.environ.get("MUJOCO_GL")
if backend is not None:
Expand All @@ -597,6 +613,7 @@ def _import_osmesa(width, height):
)

else:
# iterate through all OpenGL backends to see which one is available
for name, _ in _ALL_RENDERERS.items():
try:
opengl_context = _ALL_RENDERERS[name](width, height)
Expand All @@ -609,4 +626,5 @@ def _import_osmesa(width, height):
"No OpenGL backend could be imported. Attempting to create a "
"rendering context will result in a RuntimeError."
)
return opengl_context

return opengl_context

0 comments on commit 656b2db

Please sign in to comment.