Skip to content

Commit

Permalink
Add ExtraInfo
Browse files Browse the repository at this point in the history
  • Loading branch information
Bjarne-55 committed Oct 15, 2024
1 parent cab144d commit de95b84
Show file tree
Hide file tree
Showing 7 changed files with 983 additions and 23 deletions.
4 changes: 3 additions & 1 deletion mushroom_rl/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
from .serialization import Serializable
from .logger import Logger

from .extra_info import ExtraInfo

from .vectorized_core import VectorCore
from .vectorized_env import VectorizedEnvironment
from .multiprocess_environment import MultiprocessEnvironment

import mushroom_rl.environments

__all__ = ['ArrayBackend', 'Core', 'DatasetInfo', 'Dataset', 'Environment', 'MDPInfo', 'Agent', 'AgentInfo',
'Serializable', 'Logger', 'VectorCore', 'VectorizedEnvironment', 'MultiprocessEnvironment']
'Serializable', 'Logger', 'ExtraInfo', 'VectorCore', 'VectorizedEnvironment', 'MultiprocessEnvironment']
86 changes: 85 additions & 1 deletion mushroom_rl/core/array_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,26 @@ def from_list(array):
@staticmethod
def pack_padded_sequence(array, mask):
raise NotImplementedError

@staticmethod
def flatten(array):
raise NotImplementedError

@staticmethod
def empty(shape, device=None):
raise NotImplementedError

@staticmethod
def none():
raise NotImplementedError

@staticmethod
def shape(array):
raise NotImplementedError

@staticmethod
def full(shape, value):
raise NotImplementedError


class NumpyBackend(ArrayBackend):
Expand Down Expand Up @@ -253,6 +273,28 @@ def pack_padded_sequence(array, mask):

new_shape = (shape[0] * shape[1],) + shape[2:]
return array.reshape(new_shape, order='F')[mask.flatten(order='F')]

@staticmethod
def flatten(array):
shape = array.shape
new_shape = (shape[0] * shape[1],) + shape[2:]
return array.reshape(new_shape, order='F')

@staticmethod
def empty(shape, device=None):
return np.empty(shape)

@staticmethod
def none():
return np.nan

@staticmethod
def shape(array):
return array.shape

@staticmethod
def full(shape, value):
return np.full(shape, value)


class TorchBackend(ArrayBackend):
Expand Down Expand Up @@ -364,9 +406,31 @@ def pack_padded_sequence(array, mask):
shape = array.shape

new_shape = (shape[0]*shape[1], ) + shape[2:]

return array.transpose(0, 1).reshape(new_shape)[mask.transpose(0, 1).flatten()]

@staticmethod
def flatten(array):
shape = array.shape
new_shape = (shape[0]*shape[1], ) + shape[2:]
return array.transpose(0, 1).reshape(new_shape)

@staticmethod
def empty(shape, device=None):
device = TorchUtils.get_device() if device is None else device
return torch.empty(shape, device=device)

@staticmethod
def none():
return torch.nan

@staticmethod
def shape(array):
return array.shape

@staticmethod
def full(shape, value):
return torch.full(shape, value)

class ListBackend(ArrayBackend):

Expand Down Expand Up @@ -421,3 +485,23 @@ def from_list(array):
@staticmethod
def pack_padded_sequence(array, mask):
return NumpyBackend.pack_padded_sequence(array, np.array(mask))

@staticmethod
def flatten(array):
return NumpyBackend.flatten(array)

@staticmethod
def empty(shape, device=None):
return np.empty(shape)

@staticmethod
def none():
return None

@staticmethod
def shape(array):
return np.array(array).shape

@staticmethod
def full(shape, value):
return np.full(shape, value)
2 changes: 2 additions & 0 deletions mushroom_rl/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ def _run(self, dataset, n_steps, n_episodes, render, quiet, record, initial_stat

self._end(record)

dataset.info.parse()
dataset.episode_info.parse()
return dataset

def _step(self, render, record):
Expand Down
45 changes: 24 additions & 21 deletions mushroom_rl/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from mushroom_rl.core.serialization import Serializable
from .array_backend import ArrayBackend
from .extra_info import ExtraInfo

from ._impl import *

Expand Down Expand Up @@ -103,8 +104,8 @@ def __init__(self, dataset_info, n_steps=None, n_episodes=None):
else:
policy_state_shape = None

self._info = defaultdict(list)
self._episode_info = defaultdict(list)
self._info = ExtraInfo(dataset_info.n_envs, dataset_info.backend, dataset_info.device)
self._episode_info = ExtraInfo(dataset_info.n_envs, dataset_info.backend, dataset_info.device)
self._theta_list = list()

if dataset_info.backend == 'numpy':
Expand Down Expand Up @@ -195,12 +196,12 @@ def from_array(cls, states, actions, rewards, next_states, absorbings, lasts,
dataset = cls.create_raw_instance()

if info is None:
dataset._info = defaultdict(list)
dataset._info = ExtraInfo(1, backend)
else:
dataset._info = info.copy()

if episode_info is None:
dataset._episode_info = defaultdict(list)
dataset._episode_info = ExtraInfo(1, backend)
else:
dataset._episode_info = episode_info.copy()

Expand Down Expand Up @@ -228,7 +229,7 @@ def from_array(cls, states, actions, rewards, next_states, absorbings, lasts,

def append(self, step, info):
self._data.append(*step)
self._append_info(self._info, info)
self._info.append(info)

def append_episode_info(self, info):
self._append_info(self._episode_info, info)
Expand All @@ -243,21 +244,17 @@ def get_info(self, field, index=None):
return self._info[field][index]

def clear(self):
self._episode_info = defaultdict(list)
self._episode_info.clear()
self._theta_list = list()
self._info = defaultdict(list)
self._info.clear()

self._data.clear()

def get_view(self, index, copy=False):
dataset = self.create_raw_instance(dataset=self)

info_slice = defaultdict(list)
for key in self._info.keys():
info_slice[key] = self._info[key][index]

dataset._info = info_slice
dataset._episode_info = defaultdict(list)
dataset._info = self._info.get_view(index, copy)
dataset._episode_info = self._episode_info.get_view(index, copy)
dataset._data = self._data.get_view(index, copy)

return dataset
Expand All @@ -276,11 +273,9 @@ def __getitem__(self, index):

def __add__(self, other):
result = self.create_raw_instance(dataset=self)
new_info = self._merge_info(self.info, other.info)
new_episode_info = self._merge_info(self.episode_info, other.episode_info)

result._info = new_info
result._episode_info = new_episode_info
result._info = self._info + other._info
result._episode_info = self._episode_info + other._episode_info
result._theta_list = self._theta_list + other._theta_list
result._data = self._data + other._data

Expand Down Expand Up @@ -525,8 +520,8 @@ def _convert(self, *arrays, to='numpy'):

def _add_all_save_attr(self):
self._add_save_attr(
_info='pickle',
_episode_info='pickle',
_info='mushroom',
_episode_info='mushroom',
_theta_list='pickle',
_data='mushroom',
_array_backend='primitive',
Expand Down Expand Up @@ -557,7 +552,7 @@ def append(self, step, info):

def append_vectorized(self, step, info, mask):
self._data.append(*step, mask=mask)
self._append_info(self._info, {}) # FIXME: handle properly info
self._info.append(info)

def append_theta_vectorized(self, theta, mask):
for i in range(len(theta)):
Expand All @@ -581,11 +576,16 @@ def clear(self, n_steps_per_fit=None):
mask.flatten()[n_extra_steps:] = False
residual_data.mask = mask.reshape(original_shape)

residual_info = self._info.get_view(view_size, copy=True)
residual_episode_info = self._episode_info.get_view(view_size, copy=True)

super().clear()
self._initialize_theta_list(n_envs)

if n_steps_per_fit is not None and residual_data is not None:
self._data = residual_data
self._info = residual_info
self._episode_info = residual_episode_info

def flatten(self, n_steps_per_fit=None):
if len(self) == 0:
Expand Down Expand Up @@ -622,9 +622,12 @@ def flatten(self, n_steps_per_fit=None):

flat_theta_list = self._flatten_theta_list()

flat_info = self._info.flatten(self.mask)
flat_episode_info = self._episode_info.flatten(self.mask)

return Dataset.from_array(states, actions, rewards, next_states, absorbings, lasts,
policy_state=policy_state, policy_next_state=policy_next_state,
info=None, episode_info=None, theta_list=flat_theta_list, # FIXME: handle properly info
info=flat_info, episode_info=flat_episode_info, theta_list=flat_theta_list,
horizon=self._dataset_info.horizon, gamma=self._dataset_info.gamma,
backend=self._array_backend.get_backend_name())

Expand Down
Loading

0 comments on commit de95b84

Please sign in to comment.