Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add infoclass #153

Merged
merged 7 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions examples/isaac_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ def experiment(cfg_dict, headless, alg, n_epochs, n_steps, n_steps_per_fit, n_ep

dataset = core.evaluate(n_episodes=n_episodes_test, render=False)

J = torch.mean(torch.stack(dataset.discounted_return))
R = torch.mean(torch.stack(dataset.undiscounted_return))
J = torch.mean(dataset.discounted_return)
R = torch.mean(dataset.undiscounted_return)
E = agent.policy.entropy()

logger.epoch_info(0, J=J, R=R, entropy=E)
Expand All @@ -89,8 +89,8 @@ def experiment(cfg_dict, headless, alg, n_epochs, n_steps, n_steps_per_fit, n_ep
core.learn(n_steps=n_steps, n_steps_per_fit=n_steps_per_fit)
dataset = core.evaluate(n_episodes=n_episodes_test, render=False)

J = torch.mean(torch.stack(dataset.discounted_return))
R = torch.mean(torch.stack(dataset.undiscounted_return))
J = torch.mean(dataset.discounted_return)
R = torch.mean(dataset.undiscounted_return)
E = agent.policy.entropy()

logger.epoch_info(it+1, J=J, R=R, entropy=E)
Expand Down
4 changes: 3 additions & 1 deletion mushroom_rl/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
from .serialization import Serializable
from .logger import Logger

from .extra_info import ExtraInfo

from .vectorized_core import VectorCore
from .vectorized_env import VectorizedEnvironment
from .multiprocess_environment import MultiprocessEnvironment

import mushroom_rl.environments

__all__ = ['ArrayBackend', 'Core', 'DatasetInfo', 'Dataset', 'Environment', 'MDPInfo', 'Agent', 'AgentInfo',
'Serializable', 'Logger', 'VectorCore', 'VectorizedEnvironment', 'MultiprocessEnvironment']
'Serializable', 'Logger', 'ExtraInfo', 'VectorCore', 'VectorizedEnvironment', 'MultiprocessEnvironment']
86 changes: 85 additions & 1 deletion mushroom_rl/core/array_backend.py
Bjarne-55 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,26 @@ def from_list(array):
@staticmethod
def pack_padded_sequence(array, mask):
    """
    Abstract hook with no implementation in the base backend.

    Args:
        array: the array to be packed;
        mask: the mask associated with ``array``.

    Raises:
        NotImplementedError: always; concrete backends override this method.

    """
    raise NotImplementedError

@staticmethod
def flatten(array):
    """
    Abstract hook with no implementation in the base backend.

    Args:
        array: the array to be flattened.

    Raises:
        NotImplementedError: always; concrete backends override this method.

    """
    raise NotImplementedError

@staticmethod
def empty(shape, device=None):
    """
    Abstract hook with no implementation in the base backend.

    Args:
        shape: the shape of the array to allocate;
        device (None): optional device specification.

    Raises:
        NotImplementedError: always; concrete backends override this method.

    """
    raise NotImplementedError

@staticmethod
def none():
    """
    Abstract hook with no implementation in the base backend.

    Raises:
        NotImplementedError: always; concrete backends override this method.

    """
    raise NotImplementedError

@staticmethod
def shape(array):
    """
    Abstract hook with no implementation in the base backend.

    Args:
        array: the array whose shape is requested.

    Raises:
        NotImplementedError: always; concrete backends override this method.

    """
    raise NotImplementedError

@staticmethod
def full(shape, value):
    """
    Abstract hook with no implementation in the base backend.

    Args:
        shape: the shape of the array to create;
        value: the value used to fill the array.

    Raises:
        NotImplementedError: always; concrete backends override this method.

    """
    raise NotImplementedError


class NumpyBackend(ArrayBackend):
Expand Down Expand Up @@ -253,6 +273,28 @@ def pack_padded_sequence(array, mask):

new_shape = (shape[0] * shape[1],) + shape[2:]
return array.reshape(new_shape, order='F')[mask.flatten(order='F')]

@staticmethod
def flatten(array):
    """
    Collapse the two leading axes of ``array`` into a single axis.

    The reshape is done in Fortran order, so along the merged axis the index
    of the original first axis varies fastest.

    Args:
        array: numpy array with at least two dimensions.

    Returns:
        The array reshaped to ``(shape[0] * shape[1],) + shape[2:]``.

    """
    dims = array.shape
    merged = dims[0] * dims[1]
    return array.reshape((merged,) + dims[2:], order='F')

@staticmethod
def empty(shape, device=None):
    """
    Allocate an uninitialized numpy array.

    Args:
        shape: the shape of the array to allocate;
        device (None): accepted for interface compatibility, not used here.

    Returns:
        The result of ``np.empty(shape)``.

    """
    buffer = np.empty(shape)
    return buffer

@staticmethod
def none():
    """
    Return the placeholder used by this backend for a missing value.

    Returns:
        ``np.nan``.

    """
    missing = np.nan
    return missing

@staticmethod
def shape(array):
    """
    Return the shape of a numpy array.

    Args:
        array: the array to inspect.

    Returns:
        The ``shape`` attribute of ``array``.

    """
    dims = array.shape
    return dims

@staticmethod
def full(shape, value):
    """
    Create a numpy array filled with a constant.

    Args:
        shape: the shape of the array to create;
        value: the value assigned to every element.

    Returns:
        The result of ``np.full(shape, value)``.

    """
    filled = np.full(shape, value)
    return filled


class TorchBackend(ArrayBackend):
Expand Down Expand Up @@ -364,9 +406,31 @@ def pack_padded_sequence(array, mask):
shape = array.shape

new_shape = (shape[0]*shape[1], ) + shape[2:]

return array.transpose(0, 1).reshape(new_shape)[mask.transpose(0, 1).flatten()]

@staticmethod
def flatten(array):
    """
    Collapse the two leading axes of ``array`` into a single axis.

    The first two axes are swapped before reshaping, so along the merged axis
    the index of the original first axis varies fastest.

    Args:
        array: torch tensor with at least two dimensions.

    Returns:
        The tensor reshaped to ``(shape[0] * shape[1],) + shape[2:]``.

    """
    dims = array.shape
    target = (dims[0] * dims[1],) + dims[2:]
    swapped = array.transpose(0, 1)
    return swapped.reshape(target)

@staticmethod
def empty(shape, device=None):
    """
    Allocate an uninitialized torch tensor.

    Args:
        shape: the shape of the tensor to allocate;
        device (None): the device on which to allocate the tensor. If None,
            the device returned by ``TorchUtils.get_device()`` is used.

    Returns:
        The result of ``torch.empty(shape, device=device)``.

    """
    if device is None:
        device = TorchUtils.get_device()
    return torch.empty(shape, device=device)

@staticmethod
def none():
    """
    Return the placeholder used by this backend for a missing value.

    Returns:
        ``torch.nan``.

    """
    missing = torch.nan
    return missing

@staticmethod
def shape(array):
    """
    Return the shape of a torch tensor.

    Args:
        array: the tensor to inspect.

    Returns:
        The ``shape`` attribute of ``array``.

    """
    dims = array.shape
    return dims

@staticmethod
def full(shape, value):
    """
    Create a torch tensor filled with a constant.

    Args:
        shape: the shape of the tensor to create;
        value: the value assigned to every element.

    Returns:
        The result of ``torch.full(shape, value)``.

    """
    filled = torch.full(shape, value)
    return filled

class ListBackend(ArrayBackend):

Expand Down Expand Up @@ -421,3 +485,23 @@ def from_list(array):
@staticmethod
def pack_padded_sequence(array, mask):
    """
    Pack a padded sequence by delegating to the numpy backend.

    Both arguments are converted to numpy arrays first: the numpy
    implementation reads ``array.shape``, which a plain Python list does not
    provide.

    Args:
        array: the padded data, as a list or numpy array;
        mask: the mask associated with ``array``, as a list or numpy array.

    Returns:
        The result of ``NumpyBackend.pack_padded_sequence`` on the converted
        inputs.

    """
    # np.asarray leaves an existing ndarray untouched, so callers that
    # already pass arrays see no change.
    return NumpyBackend.pack_padded_sequence(np.asarray(array), np.array(mask))

@staticmethod
def flatten(array):
    """
    Collapse the two leading axes of ``array`` by delegating to the numpy backend.

    The input is converted to a numpy array first: the numpy implementation
    reads ``array.shape``, which a plain Python list does not provide.

    Args:
        array: the data to flatten, as a (nested) list or numpy array.

    Returns:
        The result of ``NumpyBackend.flatten`` on the converted input.

    """
    # np.asarray leaves an existing ndarray untouched, so callers that
    # already pass arrays see no change.
    return NumpyBackend.flatten(np.asarray(array))

@staticmethod
def empty(shape, device=None):
    """
    Allocate an uninitialized numpy array.

    Args:
        shape: the shape of the array to allocate;
        device (None): accepted for interface compatibility, not used here.

    Returns:
        The result of ``np.empty(shape)``.

    """
    buffer = np.empty(shape)
    return buffer

@staticmethod
def none():
    """
    Return the placeholder used by this backend for a missing value.

    Returns:
        ``None``.

    """
    missing = None
    return missing

@staticmethod
def shape(array):
    """
    Return the shape that ``array`` has once converted to a numpy array.

    Args:
        array: list-like data to inspect.

    Returns:
        The ``shape`` of ``np.array(array)``.

    """
    as_numpy = np.array(array)
    return as_numpy.shape

@staticmethod
def full(shape, value):
    """
    Create a numpy array filled with a constant.

    Args:
        shape: the shape of the array to create;
        value: the value assigned to every element.

    Returns:
        The result of ``np.full(shape, value)``.

    """
    filled = np.full(shape, value)
    return filled
2 changes: 2 additions & 0 deletions mushroom_rl/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ def _run(self, dataset, n_steps, n_episodes, render, quiet, record, initial_stat

self._end(record)

dataset.info.parse()
dataset.episode_info.parse()
return dataset

def _step(self, render, record):
Expand Down
45 changes: 24 additions & 21 deletions mushroom_rl/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from mushroom_rl.core.serialization import Serializable
from .array_backend import ArrayBackend
from .extra_info import ExtraInfo

from ._impl import *

Expand Down Expand Up @@ -103,8 +104,8 @@ def __init__(self, dataset_info, n_steps=None, n_episodes=None):
else:
policy_state_shape = None

self._info = defaultdict(list)
self._episode_info = defaultdict(list)
self._info = ExtraInfo(dataset_info.n_envs, dataset_info.backend, dataset_info.device)
self._episode_info = ExtraInfo(dataset_info.n_envs, dataset_info.backend, dataset_info.device)
self._theta_list = list()

if dataset_info.backend == 'numpy':
Expand Down Expand Up @@ -195,12 +196,12 @@ def from_array(cls, states, actions, rewards, next_states, absorbings, lasts,
dataset = cls.create_raw_instance()

if info is None:
dataset._info = defaultdict(list)
dataset._info = ExtraInfo(1, backend)
else:
dataset._info = info.copy()

if episode_info is None:
dataset._episode_info = defaultdict(list)
dataset._episode_info = ExtraInfo(1, backend)
else:
dataset._episode_info = episode_info.copy()

Expand Down Expand Up @@ -228,7 +229,7 @@ def from_array(cls, states, actions, rewards, next_states, absorbings, lasts,

def append(self, step, info):
self._data.append(*step)
self._append_info(self._info, info)
self._info.append(info)

def append_episode_info(self, info):
    """
    Record the information collected for an episode.

    Args:
        info: the episode information to store; it is forwarded unchanged to
            the ``append`` method of the ``_episode_info`` container.

    """
    # _episode_info is built as an ExtraInfo container, so use its own append,
    # exactly as append() does for the per-step _info container.
    self._episode_info.append(info)
Expand All @@ -243,21 +244,17 @@ def get_info(self, field, index=None):
return self._info[field][index]

def clear(self):
self._episode_info = defaultdict(list)
self._episode_info.clear()
self._theta_list = list()
self._info = defaultdict(list)
self._info.clear()

self._data.clear()

def get_view(self, index, copy=False):
dataset = self.create_raw_instance(dataset=self)

info_slice = defaultdict(list)
for key in self._info.keys():
info_slice[key] = self._info[key][index]

dataset._info = info_slice
dataset._episode_info = defaultdict(list)
dataset._info = self._info.get_view(index, copy)
dataset._episode_info = self._episode_info.get_view(index, copy)
dataset._data = self._data.get_view(index, copy)

return dataset
Expand All @@ -276,11 +273,9 @@ def __getitem__(self, index):

def __add__(self, other):
result = self.create_raw_instance(dataset=self)
new_info = self._merge_info(self.info, other.info)
new_episode_info = self._merge_info(self.episode_info, other.episode_info)

result._info = new_info
result._episode_info = new_episode_info
result._info = self._info + other._info
result._episode_info = self._episode_info + other._episode_info
result._theta_list = self._theta_list + other._theta_list
result._data = self._data + other._data

Expand Down Expand Up @@ -525,8 +520,8 @@ def _convert(self, *arrays, to='numpy'):

def _add_all_save_attr(self):
self._add_save_attr(
_info='pickle',
_episode_info='pickle',
_info='mushroom',
_episode_info='mushroom',
_theta_list='pickle',
_data='mushroom',
_array_backend='primitive',
Expand Down Expand Up @@ -557,7 +552,7 @@ def append(self, step, info):

def append_vectorized(self, step, info, mask):
self._data.append(*step, mask=mask)
self._append_info(self._info, {}) # FIXME: handle properly info
self._info.append(info)

def append_theta_vectorized(self, theta, mask):
for i in range(len(theta)):
Expand All @@ -581,11 +576,16 @@ def clear(self, n_steps_per_fit=None):
mask.flatten()[n_extra_steps:] = False
residual_data.mask = mask.reshape(original_shape)

residual_info = self._info.get_view(view_size, copy=True)
residual_episode_info = self._episode_info.get_view(view_size, copy=True)

super().clear()
self._initialize_theta_list(n_envs)

if n_steps_per_fit is not None and residual_data is not None:
self._data = residual_data
self._info = residual_info
self._episode_info = residual_episode_info

def flatten(self, n_steps_per_fit=None):
if len(self) == 0:
Expand Down Expand Up @@ -622,9 +622,12 @@ def flatten(self, n_steps_per_fit=None):

flat_theta_list = self._flatten_theta_list()

flat_info = self._info.flatten(self.mask)
flat_episode_info = self._episode_info.flatten(self.mask)

return Dataset.from_array(states, actions, rewards, next_states, absorbings, lasts,
policy_state=policy_state, policy_next_state=policy_next_state,
info=None, episode_info=None, theta_list=flat_theta_list, # FIXME: handle properly info
info=flat_info, episode_info=flat_episode_info, theta_list=flat_theta_list,
horizon=self._dataset_info.horizon, gamma=self._dataset_info.gamma,
backend=self._array_backend.get_backend_name())

Expand Down
Loading
Loading