Skip to content

Commit

Permalink
Allow resetting the state of MuJoCo environment from observation (#121)
Browse files Browse the repository at this point in the history
* Made it possible to pass an observation to the MuJoCo reset function
  • Loading branch information
cube1324 authored Feb 12, 2023
1 parent 649e311 commit 699acc0
Show file tree
Hide file tree
Showing 14 changed files with 89 additions and 45 deletions.
12 changes: 7 additions & 5 deletions mushroom_rl/environments/mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,9 @@ def seed(self, seed):

def reset(self, obs=None):
mujoco.mj_resetData(self._model, self._data)
self.setup()
self.setup(obs)

self._obs = self._create_observation(self.obs_helper.build_obs(self._data))
self._obs = self._create_observation(self.obs_helper._build_obs(self._data))
return self._modify_observation(self._obs)

def step(self, action):
Expand All @@ -154,7 +154,7 @@ def step(self, action):

self._simulation_post_step()

cur_obs = self._create_observation(self.obs_helper.build_obs(self._data))
cur_obs = self._create_observation(self.obs_helper._build_obs(self._data))

self._step_finalize()

Expand Down Expand Up @@ -427,13 +427,15 @@ def is_absorbing(self, obs):
"""
raise NotImplementedError

def setup(self):
def setup(self, obs):
"""
A function that allows to execute setup code after an environment
reset.
"""
pass
if obs is not None:
self.obs_helper._modify_data(self._data, obs)


def get_all_observation_keys(self):
"""
Expand Down
8 changes: 4 additions & 4 deletions mushroom_rl/environments/mujoco_envs/air_hockey/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def __init__(self, n_agents=1, env_noise=False, obs_noise=False, gamma=0.99, hor
self.agents = []

action_spec = []
observation_spec = [("puck_pos", "puck", ObservationType.BODY_POS),
("puck_vel", "puck", ObservationType.BODY_VEL)]
observation_spec = [("puck_pos", "puck", ObservationType.JOINT_POS),
("puck_vel", "puck", ObservationType.JOINT_VEL)]

additional_data = [("puck_pos", "puck", ObservationType.JOINT_POS),
("puck_vel", "puck", ObservationType.JOINT_VEL)]
Expand Down Expand Up @@ -65,8 +65,8 @@ def __init__(self, n_agents=1, env_noise=False, obs_noise=False, gamma=0.99, hor

action_spec += ["planar_robot_2/joint_1", "planar_robot_2/joint_2", "planar_robot_2/joint_3"]
# Add puck pos/vel again to transform into second agents frame
observation_spec += [("robot_2/puck_pos", "puck", ObservationType.BODY_POS),
("robot_2/puck_vel", "puck", ObservationType.BODY_VEL),
observation_spec += [("robot_2/puck_pos", "puck", ObservationType.JOINT_POS),
("robot_2/puck_vel", "puck", ObservationType.JOINT_VEL),
("robot_2/joint_1_pos", "planar_robot_2/joint_1", ObservationType.JOINT_POS),
("robot_2/joint_2_pos", "planar_robot_2/joint_2", ObservationType.JOINT_POS),
("robot_2/joint_3_pos", "planar_robot_2/joint_3", ObservationType.JOINT_POS),
Expand Down
6 changes: 3 additions & 3 deletions mushroom_rl/environments/mujoco_envs/air_hockey/defend.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(self, random_init=False, action_penalty=1e-3, init_velocity_range=(
super().__init__(gamma=gamma, horizon=horizon, timestep=timestep, n_intermediate_steps=n_intermediate_steps,
env_noise=env_noise, obs_noise=obs_noise, **viewer_params)

def setup(self, state=None):
def setup(self, obs):
# Set initial puck parameters
if self.random_init:
puck_pos = np.random.rand(2) * (self.start_range[:, 1] - self.start_range[:, 0]) + self.start_range[:, 0]
Expand All @@ -52,8 +52,8 @@ def setup(self, state=None):

self._write_data("puck_pos", np.concatenate([puck_pos, [0, 0, 0, 0, 1]]))
self._write_data("puck_vel", np.concatenate([puck_lin_vel, puck_ang_vel]))
super(AirHockeyDefend, self).setup()

super(AirHockeyDefend, self).setup(obs)

def reward(self, state, action, next_state, absorbing):
r = 0
Expand Down
46 changes: 31 additions & 15 deletions mushroom_rl/environments/mujoco_envs/air_hockey/double.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import mujoco

from mushroom_rl.utils.spaces import Box
from mushroom_rl.environments.mujoco_envs.air_hockey.base import AirHockeyBase
Expand All @@ -20,15 +21,27 @@ def __init__(self, gamma=0.99, horizon=120, env_noise=False, obs_noise=False, ti
super().__init__(gamma=gamma, horizon=horizon, env_noise=env_noise, n_agents=2, obs_noise=obs_noise,
timestep=timestep, n_intermediate_steps=n_intermediate_steps, **viewer_params)

# Remove z position and quaternion from puck pos
self.obs_helper.remove_obs("puck_pos", 2)
self.obs_helper.remove_obs("puck_vel", 0)
self.obs_helper.remove_obs("puck_vel", 1)
self.obs_helper.remove_obs("puck_vel", 5)
self.obs_helper.remove_obs("puck_pos", 3)
self.obs_helper.remove_obs("puck_pos", 4)
self.obs_helper.remove_obs("puck_pos", 5)
self.obs_helper.remove_obs("puck_pos", 6)

self.obs_helper.remove_obs("robot_2/puck_pos", 2)
self.obs_helper.remove_obs("robot_2/puck_vel", 0)
self.obs_helper.remove_obs("robot_2/puck_vel", 1)
self.obs_helper.remove_obs("robot_2/puck_vel", 5)
self.obs_helper.remove_obs("robot_2/puck_pos", 3)
self.obs_helper.remove_obs("robot_2/puck_pos", 4)
self.obs_helper.remove_obs("robot_2/puck_pos", 5)
self.obs_helper.remove_obs("robot_2/puck_pos", 6)

# Remove linear z velocity and angular velocity around x and y
self.obs_helper.remove_obs("puck_vel", 2)
self.obs_helper.remove_obs("puck_vel", 3)
self.obs_helper.remove_obs("puck_vel", 4)

self.obs_helper.remove_obs("robot_2/puck_vel", 2)
self.obs_helper.remove_obs("robot_2/puck_vel", 3)
self.obs_helper.remove_obs("robot_2/puck_vel", 4)

self.obs_helper.add_obs("collision_robot_1_puck", 1, 0, 1)
self.obs_helper.add_obs("collision_robot_2_puck", 1, 0, 1)
Expand All @@ -54,26 +67,27 @@ def get_ee(self, robot=1):
return ee_pos, ee_vel

def _modify_observation(self, obs):
self._puck_2d_in_robot_frame(self.obs_helper.get_from_obs(obs, "puck_pos"), self.agents[0]["frame"])
new_obs = obs.copy()
self._puck_2d_in_robot_frame(self.obs_helper.get_from_obs(new_obs, "puck_pos"), self.agents[0]["frame"])

self._puck_2d_in_robot_frame(self.obs_helper.get_from_obs(obs, "puck_vel"), self.agents[0]["frame"], type='vel')
self._puck_2d_in_robot_frame(self.obs_helper.get_from_obs(new_obs, "puck_vel"), self.agents[0]["frame"], type='vel')

self._puck_2d_in_robot_frame(self.obs_helper.get_from_obs(obs, "robot_2/puck_pos"), self.agents[1]["frame"])
self._puck_2d_in_robot_frame(self.obs_helper.get_from_obs(new_obs, "robot_2/puck_pos"), self.agents[1]["frame"])

self._puck_2d_in_robot_frame(self.obs_helper.get_from_obs(obs, "robot_2/puck_vel"), self.agents[1]["frame"],
self._puck_2d_in_robot_frame(self.obs_helper.get_from_obs(new_obs, "robot_2/puck_vel"), self.agents[1]["frame"],
type='vel')

if self.obs_noise:
noise = np.random.randn(2) * 0.001
self.obs_helper.get_from_obs(obs, "puck_pos")[:] += noise
self.obs_helper.get_from_obs(obs, "robot_2/puck_pos")[:] += noise
self.obs_helper.get_from_obs(new_obs, "puck_pos")[:] += noise
self.obs_helper.get_from_obs(new_obs, "robot_2/puck_pos")[:] += noise

return obs
return new_obs

def reward(self, state, action, next_state, absorbing):
return 0

def setup(self):
def setup(self, obs):
self.robot_1_hit = False
self.robot_2_hit = False
self.has_bounce = False
Expand All @@ -84,6 +98,9 @@ def setup(self):
for i in range(3):
self._data.joint("planar_robot_2/joint_" + str(i+1)).qpos = self.init_state[i]

super().setup(obs)
mujoco.mj_fwdPosition(self._model, self._data)

def _simulation_post_step(self):
if not self.robot_1_hit:
self.robot_1_hit = self._check_collision("puck", "robot_1/ee")
Expand All @@ -98,7 +115,6 @@ def _create_observation(self, state):
obs = super(AirHockeyDouble, self)._create_observation(state)
return np.append(obs, [self.robot_1_hit, self.robot_2_hit, self.has_bounce])


def _create_info_dictionary(self, obs):
constraints = {"agent-1": {}, "agent-2":{}}

Expand Down
4 changes: 2 additions & 2 deletions mushroom_rl/environments/mujoco_envs/air_hockey/hit.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, random_init=False, action_penalty=1e-3, init_robot_state="rig
super().__init__(gamma=gamma, horizon=horizon, timestep=timestep, n_intermediate_steps=n_intermediate_steps,
env_noise=env_noise, obs_noise=obs_noise, **viewer_params)

def setup(self):
def setup(self, obs):
# Initial position of the puck
if self.random_init:
puck_pos = np.random.rand(2) * (self.hit_range[:, 1] - self.hit_range[:, 0]) + self.hit_range[:, 0]
Expand Down Expand Up @@ -63,7 +63,7 @@ def setup(self):

self.vec_puck_side = (side_point - puck_pos) / np.linalg.norm(side_point - puck_pos)

super(AirHockeyHit, self).setup()
super(AirHockeyHit, self).setup(obs)

def reward(self, state, action, next_state, absorbing):
r = 0
Expand Down
6 changes: 3 additions & 3 deletions mushroom_rl/environments/mujoco_envs/air_hockey/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, random_init=False, action_penalty=1e-3, sub_problem="side", g
super().__init__(gamma=gamma, horizon=horizon, timestep=timestep, n_intermediate_steps=n_intermediate_steps,
env_noise=env_noise, obs_noise=obs_noise, **viewer_params)

def setup(self):
def setup(self, obs):
if self.random_init:
puck_pos = np.random.rand(2) * (self.start_range[:, 1] - self.start_range[:, 0]) + self.start_range[:, 0]
puck_pos *= [1, [1, -1][np.random.randint(2)]]
Expand All @@ -48,8 +48,8 @@ def setup(self):
self.desired_point = [puck_pos[0], 0]

self._write_data("puck_pos", np.concatenate([puck_pos, [0, 0, 0, 0, 1]]))
super(AirHockeyPrepare, self).setup()

super(AirHockeyPrepare, self).setup(obs)


def reward(self, state, action, next_state, absorbing):
Expand Down
6 changes: 3 additions & 3 deletions mushroom_rl/environments/mujoco_envs/air_hockey/repel.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, random_init=False, action_penalty=1e-3, init_velocity_range=(
super().__init__(gamma=gamma, horizon=horizon, timestep=timestep, n_intermediate_steps=n_intermediate_steps,
env_noise=env_noise, obs_noise=obs_noise, **viewer_params)

def setup(self, state=None):
def setup(self, obs):
# Set initial puck parameters
if self.random_init:
puck_pos = np.random.rand(2) * (self.start_range[:, 1] - self.start_range[:, 0]) + self.start_range[:, 0]
Expand All @@ -54,8 +54,8 @@ def setup(self, state=None):

self._write_data("puck_pos", np.concatenate([puck_pos, [0, 0, 0, 0, 1]]))
self._write_data("puck_vel", np.concatenate([puck_lin_vel, puck_ang_vel]))
super(AirHockeyRepel, self).setup()

super(AirHockeyRepel, self).setup(obs)

# Very flawed needs a lot of tuning
def reward(self, state, action, next_state, absorbing):
Expand Down
28 changes: 20 additions & 8 deletions mushroom_rl/environments/mujoco_envs/air_hockey/single.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import mujoco

from mushroom_rl.utils.spaces import Box
from mushroom_rl.environments.mujoco_envs.air_hockey.base import AirHockeyBase
Expand All @@ -21,10 +22,17 @@ def __init__(self, gamma=0.99, horizon=120, env_noise=False, obs_noise=False, ti
super().__init__(gamma=gamma, horizon=horizon, env_noise=env_noise, n_agents=1, obs_noise=obs_noise,
timestep=timestep, n_intermediate_steps=n_intermediate_steps, **viewer_params)

# Remove z position and quaternion from puck pos
self.obs_helper.remove_obs("puck_pos", 2)
self.obs_helper.remove_obs("puck_vel", 0)
self.obs_helper.remove_obs("puck_vel", 1)
self.obs_helper.remove_obs("puck_vel", 5)
self.obs_helper.remove_obs("puck_pos", 3)
self.obs_helper.remove_obs("puck_pos", 4)
self.obs_helper.remove_obs("puck_pos", 5)
self.obs_helper.remove_obs("puck_pos", 6)

# Remove linear z velocity and angular velocity around x and y
self.obs_helper.remove_obs("puck_vel", 2)
self.obs_helper.remove_obs("puck_vel", 3)
self.obs_helper.remove_obs("puck_vel", 4)

self.obs_helper.add_obs("collision_robot_1_puck", 1, 0, 1)
self.obs_helper.add_obs("collision_short_sides_rim_puck", 1, 0, 1)
Expand Down Expand Up @@ -64,22 +72,26 @@ def get_ee(self):
return ee_pos, ee_vel

def _modify_observation(self, obs):
self._puck_2d_in_robot_frame(self.obs_helper.get_from_obs(obs, "puck_pos"), self.agents[0]["frame"])
new_obs = obs.copy()
self._puck_2d_in_robot_frame(self.obs_helper.get_from_obs(new_obs, "puck_pos"), self.agents[0]["frame"])

self._puck_2d_in_robot_frame(self.obs_helper.get_from_obs(obs, "puck_vel"), self.agents[0]["frame"], type='vel')
self._puck_2d_in_robot_frame(self.obs_helper.get_from_obs(new_obs, "puck_vel"), self.agents[0]["frame"], type='vel')

if self.obs_noise:
self.obs_helper.get_from_obs(obs, "puck_pos")[:] += np.random.randn(2) * 0.001
self.obs_helper.get_from_obs(new_obs, "puck_pos")[:] += np.random.randn(2) * 0.001

return obs
return new_obs

def setup(self):
def setup(self, obs):
self.has_hit = False
self.has_bounce = False

for i in range(3):
self._data.joint("planar_robot_1/joint_" + str(i+1)).qpos = self.init_state[i]

super().setup(obs)
mujoco.mj_fwdPosition(self._model, self._data)

def _simulation_post_step(self):
if not self.has_hit:
self.has_hit = self._check_collision("puck", "robot_1/ee")
Expand Down
2 changes: 1 addition & 1 deletion mushroom_rl/environments/mujoco_envs/ball_in_a_cup.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def is_absorbing(self, state):
dist = self._read_data("goal_pos") - self._read_data("ball_pos")
return np.linalg.norm(dist) < 0.05 or self._check_collision("ball", "robot")

def setup(self):
def setup(self, obs):
if self._reset_state is None:
# Copy the initial position after the reset
init_pos = self._data.qpos[0:7].copy()
Expand Down
16 changes: 15 additions & 1 deletion mushroom_rl/utils/mujoco/observation_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def get_joint_pos_limits(self):
def get_joint_vel_limits(self):
return self.obs_low[self.joint_vel_idx], self.obs_high[self.joint_vel_idx]

def build_obs(self, data):
def _build_obs(self, data):
"""
Builds the observation given the true state of the simulation. The ObservationType documentation
describes the different returns in detail
Expand All @@ -165,6 +165,20 @@ def build_obs(self, data):
observations.append(obs)
return np.concatenate(observations)

def _modify_data(self, data, obs):
"""
Write the values of the observation into the provided mujoco data object. ONLY joint_pos / joint_vel
observations will have an effect on the simulation when overwritten. Everything else is just discarded by mujoco
"""
current_idx = 0
for key, name, o_type in self.observation_spec:
omit = np.array(self.build_omit_idx[key])
current_obs = self.get_state(data, name, o_type)
for i in range(len(current_obs)):
if i not in omit:
current_obs[i] = obs[current_idx]
current_idx += 1

def get_state(self, data, name, o_type):
"""
Get a single observation from data, given it's name and observation type. The ObservationType documentation
Expand Down
Binary file modified tests/environments/mujoco_envs/air_hockey_defend_data.npy
Binary file not shown.
Binary file modified tests/environments/mujoco_envs/air_hockey_hit_data.npy
Binary file not shown.
Binary file modified tests/environments/mujoco_envs/air_hockey_prepare_data.npy
Binary file not shown.
Binary file modified tests/environments/mujoco_envs/air_hockey_repel_data.npy
Binary file not shown.

0 comments on commit 699acc0

Please sign in to comment.