Skip to content

Commit

Permalink
Isaac datatype fixes and convert task observation (#138)
Browse files Browse the repository at this point in the history
* fixed J, R calculation from dataset

* fixed reset_all datatype and dimension check

* convert observation from Task-type (dict) to Tensor

* proper naming of isaac_util function

* convert task_obs also in reset

* added default episode length
  • Loading branch information
RiicK3d authored Dec 29, 2023
1 parent 6a87f3e commit 724f017
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 9 deletions.
8 changes: 4 additions & 4 deletions examples/isaac_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ def experiment(cfg_dict, headless, alg, n_epochs, n_steps, n_steps_per_fit, n_ep

dataset = core.evaluate(n_episodes=n_episodes_test, render=False)

J = np.mean(dataset.discounted_return)
R = np.mean(dataset.undiscounted_return)
J = torch.mean(torch.stack(dataset.discounted_return))
R = torch.mean(torch.stack(dataset.undiscounted_return))
E = agent.policy.entropy()

logger.epoch_info(0, J=J, R=R, entropy=E)
Expand All @@ -89,8 +89,8 @@ def experiment(cfg_dict, headless, alg, n_epochs, n_steps, n_steps_per_fit, n_ep
core.learn(n_steps=n_steps, n_steps_per_fit=n_steps_per_fit)
dataset = core.evaluate(n_episodes=n_episodes_test, render=False)

J = np.mean(dataset.discounted_return)
R = np.mean(dataset.undiscounted_return)
J = torch.mean(torch.stack(dataset.discounted_return))
R = torch.mean(torch.stack(dataset.undiscounted_return))
E = agent.policy.entropy()

logger.epoch_info(it+1, J=J, R=R, entropy=E)
Expand Down
20 changes: 15 additions & 5 deletions mushroom_rl/environments/isaac_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from mushroom_rl.core import VectorizedEnvironment, MDPInfo
from mushroom_rl.utils.viewer import ImageViewer
from mushroom_rl.utils.isaac_utils import convert_task_observation
from mushroom_rl.rl_utils.spaces import *

# import carb
Expand Down Expand Up @@ -49,8 +50,12 @@ def __init__(self, cfg=None, headless=False, backend='torch'):
observation_space = self._convert_gym_space(self._task.observation_space)

# Create MDP info for mushroom
# default episode length
max_e_lenght = 1000
if hasattr(self._task, '_max_episode_length'):
max_e_lenght = self._task._max_episode_length
mdp_info = MDPInfo(observation_space, action_space, 0.99,
self._task._max_episode_length, dt=RENDER_DT, backend=backend)
max_e_lenght, dt=RENDER_DT, backend=backend)

super().__init__(mdp_info, self._task.num_envs)

Expand Down Expand Up @@ -80,10 +85,13 @@ def seed(self, seed=-1):
return set_seed(seed)

def reset_all(self, env_mask, state=None):
    """
    Reset the environments selected by ``env_mask`` and return the current
    observation.

    Args:
        env_mask: boolean mask over the environments; entries that are True
            are reset (assumed to be one-dimensional, one entry per
            environment -- TODO confirm against the caller);
        state: not used by this implementation.

    Returns:
        The task observation converted through ``convert_task_observation``
        and a list containing one empty info dictionary per environment.

    """
    # Flatten instead of squeeze: squeezing the (k, 1) result of argwhere
    # yields a 0-dim tensor when exactly one environment is selected, which
    # previously caused that single reset to be skipped, while an empty
    # selection kept dim 1 and was still forwarded to the task.
    idxs = torch.argwhere(torch.as_tensor(env_mask)).flatten()
    if idxs.numel() > 0:  # only reset when at least one env is selected
        self._task.reset_idx(idxs)
    # self._world.step(render=self._render)  # TODO Check if we can do otherwise
    task_obs = self._task.get_observations()
    observation = convert_task_observation(task_obs)
    # Build independent dicts so mutating one env's info does not leak into
    # the others (``[{}] * n`` would alias a single dict n times).
    return observation, [{} for _ in range(self._n_envs)]

def step_all(self, env_mask, action):
self._task.pre_physics_step(action)
Expand All @@ -93,7 +101,9 @@ def step_all(self, env_mask, action):
self._world.step(render=self._render)

observation, reward, done, info = self._task.post_physics_step()

# converts task obs from dictionary to tensor
observation = convert_task_observation(observation)

env_mask_cuda = torch.as_tensor(env_mask).cuda()

return observation, reward, torch.logical_and(done, env_mask_cuda), [info]*self._n_envs
Expand Down
10 changes: 10 additions & 0 deletions mushroom_rl/utils/isaac_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import torch


def convert_task_observation(observation, max_depth=5):
    """
    Unwrap a (possibly nested) task observation until a tensor is reached.

    At every level, an object that is not a tensor is indexed with its first
    key, and the search continues on the value found there.

    Args:
        observation: a tensor, or a nested mapping (anything exposing
            ``keys()`` and item access) that contains a tensor along its
            chain of first keys;
        max_depth (int, 5): maximum number of nesting levels to unwrap.

    Returns:
        The first tensor met while following the first key of each level.
        If no tensor is reached within ``max_depth`` levels, the object
        reached at that depth is returned unchanged (it is not a tensor).

    """
    obs_t = observation
    for _ in range(max_depth):
        if torch.is_tensor(obs_t):
            break
        # Descend through the first key of the current level.
        first_key = list(obs_t.keys())[0]
        obs_t = obs_t[first_key]
    return obs_t

0 comments on commit 724f017

Please sign in to comment.