Skip to content

Commit

Permalink
Fix issue with foot vel in MJX notebook
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 589276253
Change-Id: I2c8c24e42d037d8c59667364a298f8b669d779b6
  • Loading branch information
btaba authored and copybara-github committed Dec 9, 2023
1 parent 263451a commit e7d637f
Showing 1 changed file with 25 additions and 21 deletions.
46 changes: 25 additions & 21 deletions mjx/tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -923,7 +923,9 @@
" 'physics_steps_per_control_step', physics_steps_per_control_step)\n",
" super().__init__(mj_model=mj_model, **kwargs)\n",
"\n",
" self.torso_idx = 1\n",
" self.torso_idx = mujoco.mj_name2id(\n",
" mj_model, mujoco.mjtObj.mjOBJ_BODY.value, 'torso'\n",
" )\n",
" self._action_scale = action_scale\n",
" self._obs_noise = obs_noise\n",
" self._reset_horizon = 500\n",
Expand All @@ -940,6 +942,7 @@
" self.reward_config = get_config()\n",
" self.lowers = self._default_ap_pose - jp.array([0.2, 0.8, 0.8] * 4)\n",
" self.uppers = self._default_ap_pose + jp.array([0.2, 0.8, 0.8] * 4)\n",
" self._foot_radius = 0.014\n",
"\n",
" def sample_command(self, rng: jax.Array) -\u003e jax.Array:\n",
" lin_vel_x = [-0.6, 1.0] # min max [m/s]\n",
Expand Down Expand Up @@ -1023,12 +1026,15 @@
" joint_vel = qvel[6:]\n",
"\n",
" # foot contact data based on z-position\n",
" foot_contact = 0.017 - self._get_feet_pos_vel(x, xd)[0][:, 2]\n",
" contact = foot_contact \u003e -1e-3 # a mm or less off the floor\n",
" foot_contact_pos = (\n",
" self._get_feet_pos_vel(x, xd)[0][:, 2]\n",
" - self._foot_radius\n",
" )\n",
" contact = foot_contact_pos \u003c 1e-3 # a mm or less off the floor\n",
" contact_filt_mm = jp.logical_or(contact, state.info['last_contact'])\n",
" contact_filt_cm = jp.logical_or(\n",
" foot_contact \u003e -1e-2, state.info['last_contact']\n",
" )\n",
" foot_contact_pos \u003c 3e-2, state.info['last_contact']\n",
" ) # 3cm or less off the floor\n",
" first_contact = (state.info['feet_air_time'] \u003e 0) * (contact_filt_mm)\n",
" state.info['feet_air_time'] += self.dt\n",
"\n",
Expand Down Expand Up @@ -1100,29 +1106,26 @@
" state.info.update(rng=rng)\n",
"\n",
" # resetting logic if joint limits are reached or robot is falling\n",
" done = 0.0\n",
" up = jp.array([0.0, 0.0, 1.0])\n",
" done = jp.where(jp.dot(math.rotate(up, x.rot[0]), up) \u003c 0, 1.0, done)\n",
" done = jp.where(jp.logical_or(\n",
" jp.any(joint_angles \u003c .98 * self.lowers),\n",
" jp.any(joint_angles \u003e .98 * self.uppers)), 1.0, done)\n",
" done = jp.where(x.pos[self.torso_idx, 2] \u003c 0.18, 1.0, done)\n",
" done = jp.dot(math.rotate(up, x.rot[0]), up) \u003c 0\n",
" done |= jp.any(joint_angles \u003c 0.98 * self.lowers)\n",
" done |= jp.any(joint_angles \u003e 0.98 * self.uppers)\n",
" done |= x.pos[0, 2] \u003c 0.18\n",
"\n",
" # termination reward\n",
" reward += jp.where(\n",
" (done == 1.0) \u0026 (state.info['step'] \u003c self._reset_horizon),\n",
" self.reward_config.rewards.scales.termination,\n",
" 0.0,\n",
" reward += (\n",
" done * (state.info['step'] \u003c self._reset_horizon) *\n",
" self.reward_config.rewards.scales.termination\n",
" )\n",
"\n",
" # when done, sample new command if more than _reset_horizon timesteps\n",
" # achieved\n",
" state.info['command'] = jp.where(\n",
" (done == 1.0) \u0026 (state.info['step'] \u003e self._reset_horizon),\n",
" done \u0026 (state.info['step'] \u003e self._reset_horizon),\n",
" self.sample_command(cmd_rng), state.info['command'])\n",
" # reset the step counter when done\n",
" state.info['step'] = jp.where(\n",
" (done == 1.0) | (state.info['step'] \u003e self._reset_horizon), 0,\n",
" done | (state.info['step'] \u003e self._reset_horizon), 0,\n",
" state.info['step']\n",
" )\n",
"\n",
Expand All @@ -1133,7 +1136,7 @@
"\n",
" state = state.replace(\n",
" pipeline_state=data, obs=obs + obs_noise, reward=reward,\n",
" done=done)\n",
" done=done * 1.0)\n",
" return state\n",
"\n",
" def _get_obs(self, qpos: jax.Array, x: Transform, xd: Motion,\n",
Expand Down Expand Up @@ -1233,7 +1236,8 @@
" self, x: Transform, xd: Motion) -\u003e Tuple[jax.Array, jax.Array]:\n",
" offset = Transform.create(pos=self._feet_pos)\n",
" pos = x.take(self._feet_index).vmap().do(offset).pos\n",
" vel = offset.vmap().do(xd.take(self._feet_index)).vel\n",
" world_offset = Transform.create(pos=pos - x.take(self._feet_index).pos)\n",
" vel = world_offset.vmap().do(xd.take(self._feet_index)).vel\n",
" return pos, vel\n",
"\n",
" def _reward_foot_slip(\n",
Expand Down Expand Up @@ -1405,8 +1409,8 @@
"private_outputs": true,
"provenance": [
{
"file_id": "1brcF4_qCRS2ASc-QQw1rsEwl5IjzGvq2",
"timestamp": 1697763780236
"file_id": "1QsuS7EJhdPEHxxAu9XwozvA7eb4ZnlAb",
"timestamp": 1701993737024
}
],
"toc_visible": true
Expand Down

0 comments on commit e7d637f

Please sign in to comment.