From 33c87797ea627ab2825c34de3c657f1e3eaa0396 Mon Sep 17 00:00:00 2001 From: boris-il-forte Date: Sat, 12 Oct 2024 18:29:04 +0200 Subject: [PATCH] Added BODY_VEL_WORLD as observation type in Mujoco - now we provide also the opportunity to get the body velocity info in world frame --- .../mujoco_envs/air_hockey/base.py | 4 ++-- .../environments/mujoco_envs/ball_in_a_cup.py | 2 +- .../utils/mujoco/observation_helper.py | 19 +++++++++++-------- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/mushroom_rl/environments/mujoco_envs/air_hockey/base.py b/mushroom_rl/environments/mujoco_envs/air_hockey/base.py index 06de5564..020f81c5 100644 --- a/mushroom_rl/environments/mujoco_envs/air_hockey/base.py +++ b/mushroom_rl/environments/mujoco_envs/air_hockey/base.py @@ -57,7 +57,7 @@ def __init__(self, n_agents=1, env_noise=False, obs_noise=False, gamma=0.99, hor ("robot_1/joint_3_vel", "planar_robot_1/joint_3", ObservationType.JOINT_VEL)] additional_data += [("robot_1/ee_pos", "planar_robot_1/body_ee", ObservationType.BODY_POS), - ("robot_1/ee_vel", "planar_robot_1/body_ee", ObservationType.BODY_VEL)] + ("robot_1/ee_vel", "planar_robot_1/body_ee", ObservationType.BODY_VEL_WORLD)] collision_spec += [("robot_1/ee", ["planar_robot_1/ee"])] @@ -76,7 +76,7 @@ def __init__(self, n_agents=1, env_noise=False, obs_noise=False, gamma=0.99, hor ("robot_2/joint_3_vel", "planar_robot_2/joint_3", ObservationType.JOINT_VEL)] additional_data += [("robot_2/ee_pos", "planar_robot_2/body_ee", ObservationType.BODY_POS), - ("robot_2/ee_vel", "planar_robot_2/body_ee", ObservationType.BODY_VEL)] + ("robot_2/ee_vel", "planar_robot_2/body_ee", ObservationType.BODY_VEL_WORLD)] collision_spec += [("robot_2/ee", ["planar_robot_2/ee"])] else: diff --git a/mushroom_rl/environments/mujoco_envs/ball_in_a_cup.py b/mushroom_rl/environments/mujoco_envs/ball_in_a_cup.py index c8cf6b46..a0ada7d7 100644 --- a/mushroom_rl/environments/mujoco_envs/ball_in_a_cup.py +++ b/mushroom_rl/environments/mujoco_envs/ball_in_a_cup.py @@ -36,7 +36,7 @@ def __init__(self): ("palm_yaw_pos", "wam/palm_yaw_joint", ObservationType.JOINT_POS), ("palm_yaw_vel", "wam/palm_yaw_joint", ObservationType.JOINT_VEL), ("ball_pos", "ball", ObservationType.BODY_POS), - ("ball_vel", "ball", ObservationType.BODY_VEL)] + ("ball_vel", "ball", ObservationType.BODY_VEL_WORLD)] additional_data_spec = [("ball_pos", "ball", ObservationType.BODY_POS), ("goal_pos", "cup_goal_final", ObservationType.SITE_POS)] diff --git a/mushroom_rl/utils/mujoco/observation_helper.py b/mushroom_rl/utils/mujoco/observation_helper.py index cf1c3f88..77429e85 100644 --- a/mushroom_rl/utils/mujoco/observation_helper.py +++ b/mushroom_rl/utils/mujoco/observation_helper.py @@ -11,20 +11,22 @@ class ObservationType(Enum): The Observation have the following returns: BODY_POS: (3,) x, y, z position of the body BODY_ROT: (4,) quaternion of the body - BODY_VEL: (6,) first angular velocity around x, y, z. Then linear velocity for x, y, z + BODY_VEL: (6,) first angular velocity around x, y, z. Then linear velocity for x, y, z, in local frame + BODY_VEL_WORLD: (6,) first angular velocity around x, y, z. Then linear velocity for x, y, z, in world frame JOINT_POS: (1,) rotation of the joint OR (7,) position, quaternion of a free joint JOINT_VEL: (1,) velocity of the joint OR (6,) FIRST linear then angular velocity !different to BODY_VEL! SITE_POS: (3,) x, y, z position of the body SITE_ROT: (9,) rotation matrix of the site """ - __order__ = "BODY_POS BODY_ROT BODY_VEL JOINT_POS JOINT_VEL SITE_POS SITE_ROT" + __order__ = "BODY_POS BODY_ROT BODY_VEL BODY_VEL_WORLD JOINT_POS JOINT_VEL SITE_POS SITE_ROT" BODY_POS = 0 BODY_ROT = 1 BODY_VEL = 2 - JOINT_POS = 3 - JOINT_VEL = 4 - SITE_POS = 5 - SITE_ROT = 6 + BODY_VEL_WORLD = 3 + JOINT_POS = 4 + JOINT_VEL = 5 + SITE_POS = 6 + SITE_ROT = 7 class ObservationHelper: @@ -190,9 +192,10 @@ def get_state(self, model, data, name, o_type): obs = data.body(name).xpos elif o_type == ObservationType.BODY_ROT: obs = data.body(name).xquat - elif o_type == ObservationType.BODY_VEL: + elif o_type == ObservationType.BODY_VEL or o_type == ObservationType.BODY_VEL_WORLD: + local = o_type == ObservationType.BODY_VEL obs = np.empty(6) - mujoco.mj_objectVelocity(model, data, mujoco.mjtObj.mjOBJ_XBODY, data.body(name).id, obs, True) + mujoco.mj_objectVelocity(model, data, mujoco.mjtObj.mjOBJ_XBODY, data.body(name).id, obs, local) elif o_type == ObservationType.JOINT_POS: obs = data.joint(name).qpos elif o_type == ObservationType.JOINT_VEL: