From e7d637f375d809f6a83b7e3e3d5d698151edfd56 Mon Sep 17 00:00:00 2001 From: Baruch Tabanpour Date: Fri, 8 Dec 2023 16:43:31 -0800 Subject: [PATCH] Fix issue with foot vel in MJX notebook PiperOrigin-RevId: 589276253 Change-Id: I2c8c24e42d037d8c59667364a298f8b669d779b6 --- mjx/tutorial.ipynb | 46 +++++++++++++++++++++++++--------------------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/mjx/tutorial.ipynb b/mjx/tutorial.ipynb index 16aa9c3243..522b020812 100644 --- a/mjx/tutorial.ipynb +++ b/mjx/tutorial.ipynb @@ -923,7 +923,9 @@ " 'physics_steps_per_control_step', physics_steps_per_control_step)\n", " super().__init__(mj_model=mj_model, **kwargs)\n", "\n", - " self.torso_idx = 1\n", + " self.torso_idx = mujoco.mj_name2id(\n", + " mj_model, mujoco.mjtObj.mjOBJ_BODY.value, 'torso'\n", + " )\n", " self._action_scale = action_scale\n", " self._obs_noise = obs_noise\n", " self._reset_horizon = 500\n", @@ -940,6 +942,7 @@ " self.reward_config = get_config()\n", " self.lowers = self._default_ap_pose - jp.array([0.2, 0.8, 0.8] * 4)\n", " self.uppers = self._default_ap_pose + jp.array([0.2, 0.8, 0.8] * 4)\n", + " self._foot_radius = 0.014\n", "\n", " def sample_command(self, rng: jax.Array) -\u003e jax.Array:\n", " lin_vel_x = [-0.6, 1.0] # min max [m/s]\n", @@ -1023,12 +1026,15 @@ " joint_vel = qvel[6:]\n", "\n", " # foot contact data based on z-position\n", - " foot_contact = 0.017 - self._get_feet_pos_vel(x, xd)[0][:, 2]\n", - " contact = foot_contact \u003e -1e-3 # a mm or less off the floor\n", + " foot_contact_pos = (\n", + " self._get_feet_pos_vel(x, xd)[0][:, 2]\n", + " - self._foot_radius\n", + " )\n", + " contact = foot_contact_pos \u003c 1e-3 # a mm or less off the floor\n", " contact_filt_mm = jp.logical_or(contact, state.info['last_contact'])\n", " contact_filt_cm = jp.logical_or(\n", - " foot_contact \u003e -1e-2, state.info['last_contact']\n", - " )\n", + " foot_contact_pos \u003c 3e-2, state.info['last_contact']\n", + " ) # 3cm or less off the floor\n", " first_contact = (state.info['feet_air_time'] \u003e 0) * (contact_filt_mm)\n", " state.info['feet_air_time'] += self.dt\n", "\n", @@ -1100,29 +1106,26 @@ " state.info.update(rng=rng)\n", "\n", " # resetting logic if joint limits are reached or robot is falling\n", - " done = 0.0\n", " up = jp.array([0.0, 0.0, 1.0])\n", - " done = jp.where(jp.dot(math.rotate(up, x.rot[0]), up) \u003c 0, 1.0, done)\n", - " done = jp.where(jp.logical_or(\n", - " jp.any(joint_angles \u003c .98 * self.lowers),\n", - " jp.any(joint_angles \u003e .98 * self.uppers)), 1.0, done)\n", - " done = jp.where(x.pos[self.torso_idx, 2] \u003c 0.18, 1.0, done)\n", + " done = jp.dot(math.rotate(up, x.rot[0]), up) \u003c 0\n", + " done |= jp.any(joint_angles \u003c 0.98 * self.lowers)\n", + " done |= jp.any(joint_angles \u003e 0.98 * self.uppers)\n", + " done |= x.pos[0, 2] \u003c 0.18\n", "\n", " # termination reward\n", - " reward += jp.where(\n", - " (done == 1.0) \u0026 (state.info['step'] \u003c self._reset_horizon),\n", - " self.reward_config.rewards.scales.termination,\n", - " 0.0,\n", + " reward += (\n", + " done * (state.info['step'] \u003c self._reset_horizon) *\n", + " self.reward_config.rewards.scales.termination\n", " )\n", "\n", " # when done, sample new command if more than _reset_horizon timesteps\n", " # achieved\n", " state.info['command'] = jp.where(\n", - " (done == 1.0) \u0026 (state.info['step'] \u003e self._reset_horizon),\n", + " done \u0026 (state.info['step'] \u003e self._reset_horizon),\n", " self.sample_command(cmd_rng), state.info['command'])\n", " # reset the step counter when done\n", " state.info['step'] = jp.where(\n", - " (done == 1.0) | (state.info['step'] \u003e self._reset_horizon), 0,\n", + " done | (state.info['step'] \u003e self._reset_horizon), 0,\n", " state.info['step']\n", " )\n", "\n", @@ -1133,7 +1136,7 @@ "\n", " state = state.replace(\n", " pipeline_state=data, obs=obs + obs_noise, reward=reward,\n", - " done=done)\n", + " done=done * 1.0)\n", " return state\n", "\n", " def _get_obs(self, qpos: jax.Array, x: Transform, xd: Motion,\n", @@ -1233,7 +1236,8 @@ " self, x: Transform, xd: Motion) -\u003e Tuple[jax.Array, jax.Array]:\n", " offset = Transform.create(pos=self._feet_pos)\n", " pos = x.take(self._feet_index).vmap().do(offset).pos\n", - " vel = offset.vmap().do(xd.take(self._feet_index)).vel\n", + " world_offset = Transform.create(pos=pos - x.take(self._feet_index).pos)\n", + " vel = world_offset.vmap().do(xd.take(self._feet_index)).vel\n", " return pos, vel\n", "\n", " def _reward_foot_slip(\n", @@ -1405,8 +1409,8 @@ "private_outputs": true, "provenance": [ { - "file_id": "1brcF4_qCRS2ASc-QQw1rsEwl5IjzGvq2", - "timestamp": 1697763780236 + "file_id": "1QsuS7EJhdPEHxxAu9XwozvA7eb4ZnlAb", + "timestamp": 1701993737024 } ], "toc_visible": true