Skip to content

Commit

Permalink
Added BODY_VEL_WORLD as observation type in Mujoco
Browse files Browse the repository at this point in the history
- now we provide also the opportunity to get the body velocity info in world frame
  • Loading branch information
boris-il-forte committed Oct 12, 2024
1 parent 495af65 commit 33c8779
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 11 deletions.
4 changes: 2 additions & 2 deletions mushroom_rl/environments/mujoco_envs/air_hockey/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(self, n_agents=1, env_noise=False, obs_noise=False, gamma=0.99, hor
("robot_1/joint_3_vel", "planar_robot_1/joint_3", ObservationType.JOINT_VEL)]

additional_data += [("robot_1/ee_pos", "planar_robot_1/body_ee", ObservationType.BODY_POS),
("robot_1/ee_vel", "planar_robot_1/body_ee", ObservationType.BODY_VEL)]
("robot_1/ee_vel", "planar_robot_1/body_ee", ObservationType.BODY_VEL_WORLD)]

collision_spec += [("robot_1/ee", ["planar_robot_1/ee"])]

Expand All @@ -76,7 +76,7 @@ def __init__(self, n_agents=1, env_noise=False, obs_noise=False, gamma=0.99, hor
("robot_2/joint_3_vel", "planar_robot_2/joint_3", ObservationType.JOINT_VEL)]

additional_data += [("robot_2/ee_pos", "planar_robot_2/body_ee", ObservationType.BODY_POS),
("robot_2/ee_vel", "planar_robot_2/body_ee", ObservationType.BODY_VEL)]
("robot_2/ee_vel", "planar_robot_2/body_ee", ObservationType.BODY_VEL_WORLD)]

collision_spec += [("robot_2/ee", ["planar_robot_2/ee"])]
else:
Expand Down
2 changes: 1 addition & 1 deletion mushroom_rl/environments/mujoco_envs/ball_in_a_cup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self):
("palm_yaw_pos", "wam/palm_yaw_joint", ObservationType.JOINT_POS),
("palm_yaw_vel", "wam/palm_yaw_joint", ObservationType.JOINT_VEL),
("ball_pos", "ball", ObservationType.BODY_POS),
("ball_vel", "ball", ObservationType.BODY_VEL)]
("ball_vel", "ball", ObservationType.BODY_VEL_WORLD)]

additional_data_spec = [("ball_pos", "ball", ObservationType.BODY_POS),
("goal_pos", "cup_goal_final", ObservationType.SITE_POS)]
Expand Down
19 changes: 11 additions & 8 deletions mushroom_rl/utils/mujoco/observation_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,22 @@ class ObservationType(Enum):
The Observation have the following returns:
BODY_POS: (3,) x, y, z position of the body
BODY_ROT: (4,) quaternion of the body
BODY_VEL: (6,) first angular velocity around x, y, z. Then linear velocity for x, y, z
BODY_VEL: (6,) first angular velocity around x, y, z. Then linear velocity for x, y, z, in local frame
BODY_VEL_WORLD: (6,) first angular velocity around x, y, z. Then linear velocity for x, y, z, in world frame
JOINT_POS: (1,) rotation of the joint OR (7,) position, quaternion of a free joint
JOINT_VEL: (1,) velocity of the joint OR (6,) FIRST linear then angular velocity !different to BODY_VEL!
SITE_POS: (3,) x, y, z position of the body
SITE_ROT: (9,) rotation matrix of the site
"""
__order__ = "BODY_POS BODY_ROT BODY_VEL JOINT_POS JOINT_VEL SITE_POS SITE_ROT"
__order__ = "BODY_POS BODY_ROT BODY_VEL BODY_VEL_WORLD JOINT_POS JOINT_VEL SITE_POS SITE_ROT"
BODY_POS = 0
BODY_ROT = 1
BODY_VEL = 2
JOINT_POS = 3
JOINT_VEL = 4
SITE_POS = 5
SITE_ROT = 6
BODY_VEL_WORLD = 3
JOINT_POS = 4
JOINT_VEL = 5
SITE_POS = 6
SITE_ROT = 7


class ObservationHelper:
Expand Down Expand Up @@ -190,9 +192,10 @@ def get_state(self, model, data, name, o_type):
obs = data.body(name).xpos
elif o_type == ObservationType.BODY_ROT:
obs = data.body(name).xquat
elif o_type == ObservationType.BODY_VEL:
elif o_type == ObservationType.BODY_VEL or o_type == ObservationType.BODY_VEL_WORLD:
local = o_type == ObservationType.BODY_VEL
obs = np.empty(6)
mujoco.mj_objectVelocity(model, data, mujoco.mjtObj.mjOBJ_XBODY, data.body(name).id, obs, True)
mujoco.mj_objectVelocity(model, data, mujoco.mjtObj.mjOBJ_XBODY, data.body(name).id, obs, local)
elif o_type == ObservationType.JOINT_POS:
obs = data.joint(name).qpos
elif o_type == ObservationType.JOINT_VEL:
Expand Down

0 comments on commit 33c8779

Please sign in to comment.