Skip to content

Commit

Permalink
Work in progress on vectorized dataset
Browse files Browse the repository at this point in the history
- preliminary support for vectorized dataset
- episode infos are not supported yet unfortunately
- major refactoring of DataCoverter, now it's ArrayBackend
- still some issues in the vectorized dataset conversion to serial dataset, needs to be fixed
- fixed some issues of tensor features not using proper device
- fixed issue in test for vectorized environments
  • Loading branch information
boris-il-forte committed Dec 5, 2023
1 parent 42872ec commit a5dfa3a
Show file tree
Hide file tree
Showing 15 changed files with 187 additions and 100 deletions.
2 changes: 1 addition & 1 deletion mushroom_rl/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .core import Core
from .dataset import Dataset
from .dataset import Dataset, VectorizedDataset
from .environment import Environment, MDPInfo
from .agent import Agent, AgentInfo
from .serialization import Serializable
Expand Down
2 changes: 1 addition & 1 deletion mushroom_rl/core/_impl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .numpy_dataset import NumpyDataset
from .torch_dataset import TorchDataset
from .list_dataset import ListDataset
from .type_conversions import DataConversion, NumpyConversion, TorchConversion, ListConversion
from .array_backend import ArrayBackend, NumpyBackend, TorchBackend, ListBackend
from .core_logic import CoreLogic
from .vectorized_core_logic import VectorizedCoreLogic
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,19 @@
from mushroom_rl.utils.torch import TorchUtils


class DataConversion(object):
class ArrayBackend(object):
@staticmethod
def get_converter(backend):
def get_backend_name():
raise NotImplementedError

@staticmethod
def get_array_backend(backend):
if backend == 'numpy':
return NumpyConversion
return NumpyBackend
elif backend == 'torch':
return TorchConversion
return TorchBackend
else:
return ListConversion
return ListBackend

@classmethod
def convert(cls, *arrays, to='numpy'):
Expand Down Expand Up @@ -55,8 +59,16 @@ def ones(*dims, dtype):
def copy(array):
raise NotImplementedError

@staticmethod
def pack_padded_sequence(array, lengths):
raise NotImplementedError


class NumpyBackend(ArrayBackend):
@staticmethod
def get_backend_name():
return 'numpy'

class NumpyConversion(DataConversion):
@staticmethod
def to_numpy(array):
return array
Expand All @@ -81,8 +93,20 @@ def ones(*dims, dtype=float):
def copy(array):
return array.copy()

@staticmethod
def pack_padded_sequence(array, lengths):
shape = array.shape

new_shape = (shape[0] * shape[1],) + shape[2:]
mask = (np.arange(len(array))[:, None] < lengths[None, :]).flatten()
return array.reshape(new_shape)[mask]


class TorchBackend(ArrayBackend):
@staticmethod
def get_backend_name():
return 'torch'

class TorchConversion(DataConversion):
@staticmethod
def to_numpy(array):
return None if array is None else array.detach().cpu().numpy()
Expand All @@ -107,8 +131,20 @@ def ones(*dims, dtype=torch.float32):
def copy(array):
return array.clone()

@staticmethod
def pack_padded_sequence(array, lengths):
shape = array.shape

new_shape = (shape[0]*shape[1], ) + shape[2:]
mask = (torch.arange(len(array), device=TorchUtils.get_device())[None, :] < lengths[:, None]).flatten()
return array.reshape(new_shape)[mask]


class ListBackend(ArrayBackend):
@staticmethod
def get_backend_name():
return 'list'

class ListConversion(DataConversion):
@staticmethod
def to_numpy(array):
return np.array(array)
Expand All @@ -133,6 +169,10 @@ def ones(*dims, dtype=float):
def copy(array):
return array.copy()

@staticmethod
def pack_padded_sequence(array, lengths):
return NumpyBackend.pack_padded_sequence(array, lengths)




4 changes: 4 additions & 0 deletions mushroom_rl/core/_impl/list_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ def policy_state(self):
def policy_next_state(self):
return [step[7] for step in self._dataset]

@property
def is_stateful(self):
return self._is_stateful

@property
def n_episodes(self):
n_episodes = 0
Expand Down
18 changes: 9 additions & 9 deletions mushroom_rl/core/_impl/numpy_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@


class NumpyDataset(Serializable):
def __init__(self, state_type, state_shape, action_type, action_shape, reward_shape, policy_state_shape):
flags_len = action_shape[0]
def __init__(self, state_type, state_shape, action_type, action_shape, reward_shape, flag_shape,
policy_state_shape):

self._state_type = state_type
self._action_type = action_type
Expand All @@ -14,8 +14,8 @@ def __init__(self, state_type, state_shape, action_type, action_shape, reward_sh
self._actions = np.empty(action_shape, dtype=self._action_type)
self._rewards = np.empty(reward_shape, dtype=float)
self._next_states = np.empty(state_shape, dtype=self._state_type)
self._absorbing = np.empty(flags_len, dtype=bool)
self._last = np.empty(flags_len, dtype=bool)
self._absorbing = np.empty(flag_shape, dtype=bool)
self._last = np.empty(flag_shape, dtype=bool)
self._len = 0

if policy_state_shape is None:
Expand Down Expand Up @@ -100,7 +100,7 @@ def append(self, state, action, reward, next_state, absorbing, last, policy_stat
self._absorbing[i] = absorbing
self._last[i] = last

if self._is_stateful:
if self.is_stateful:
self._policy_states[i] = policy_state
self._policy_next_states[i] = policy_next_state

Expand All @@ -114,7 +114,7 @@ def clear(self):
self._absorbing = np.empty_like(self._absorbing)
self._last = np.empty_like(self._last)

if self._is_stateful:
if self.is_stateful:
self._policy_states = np.empty_like(self._policy_states)
self._policy_next_states = np.empty_like(self._policy_next_states)

Expand All @@ -131,7 +131,7 @@ def get_view(self, index):
view._last = self.last[index, ...]
view._len = view._states.shape[0]

if self._is_stateful:
if self.is_stateful:
view._policy_states = self._policy_states[index, ...]
view._policy_next_states = self._policy_next_states[index, ...]

Expand All @@ -153,7 +153,7 @@ def __add__(self, other):
result._last[len(self)-1] = True
result._len = len(self) + len(other)

if self._is_stateful:
if self.is_stateful:
result._policy_states = np.concatenate((self.policy_state, other.policy_state))
result._policy_next_states = np.concatenate((self.policy_next_state, other.policy_next_state))

Expand Down Expand Up @@ -192,7 +192,7 @@ def policy_next_state(self):
return self._policy_next_states[:len(self)]

@property
def _is_stateful(self):
def is_stateful(self):
return self._policy_states is not None

@property
Expand Down
20 changes: 10 additions & 10 deletions mushroom_rl/core/_impl/torch_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,17 @@


class TorchDataset(Serializable):
def __init__(self, state_type, state_shape, action_type, action_shape, reward_shape, policy_state_shape):
flags_len = action_shape[0]

def __init__(self, state_type, state_shape, action_type, action_shape, reward_shape, flag_shape,
policy_state_shape):
self._state_type = state_type
self._action_type = action_type

self._states = torch.empty(*state_shape, dtype=self._state_type, device=TorchUtils.get_device())
self._actions = torch.empty(*action_shape, dtype=self._action_type, device=TorchUtils.get_device())
self._rewards = torch.empty(*reward_shape, dtype=torch.float, device=TorchUtils.get_device())
self._next_states = torch.empty(*state_shape, dtype=self._state_type, device=TorchUtils.get_device())
self._absorbing = torch.empty(flags_len, dtype=torch.bool, device=TorchUtils.get_device())
self._last = torch.empty(flags_len, dtype=torch.bool, device=TorchUtils.get_device())
self._absorbing = torch.empty(flag_shape, dtype=torch.bool, device=TorchUtils.get_device())
self._last = torch.empty(flag_shape, dtype=torch.bool, device=TorchUtils.get_device())
self._len = 0

if policy_state_shape is None:
Expand Down Expand Up @@ -101,7 +100,7 @@ def append(self, state, action, reward, next_state, absorbing, last, policy_stat
self._absorbing[i] = absorbing
self._last[i] = last

if self._is_stateful:
if self.is_stateful:
self._policy_states[i] = policy_state
self._policy_next_states[i] = policy_next_state

Expand All @@ -115,7 +114,7 @@ def clear(self):
self._absorbing = torch.empty_like(self._absorbing)
self._last = torch.empty_like(self._last)

if self._is_stateful:
if self.is_stateful:
self._policy_states = torch.empty_like(self._policy_states)
self._policy_next_states = torch.empty_like(self._policy_next_states)

Expand All @@ -132,7 +131,7 @@ def get_view(self, index):
view._last = self._last[index, ...]
view._len = view._states.shape[0]

if self._is_stateful:
if self.is_stateful:
view._policy_states = self._policy_states[index, ...]
view._policy_next_states = self._policy_next_states[index, ...]

Expand All @@ -154,12 +153,13 @@ def __add__(self, other):
result._last[len(self) - 1] = True
result._len = len(self) + len(other)

if self._is_stateful:
if self.is_stateful:
result._policy_states = torch.concatenate((self.policy_state, other.policy_state))
result._policy_next_states = torch.concatenate((self.policy_next_state, other.policy_next_state))

return result


@property
def state(self):
return self._states[:len(self)]
Expand Down Expand Up @@ -193,7 +193,7 @@ def policy_next_state(self):
return self._policy_next_states[:len(self)]

@property
def _is_stateful(self):
def is_stateful(self):
return self._policy_states is not None

@property
Expand Down
18 changes: 9 additions & 9 deletions mushroom_rl/core/_impl/vectorized_core_logic.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from .type_conversions import DataConversion
from .array_backend import ArrayBackend
from .core_logic import CoreLogic


class VectorizedCoreLogic(CoreLogic):
def __init__(self, backend, n_envs):
self._converter = DataConversion.get_converter(backend)
self._array_backend = ArrayBackend.get_array_backend(backend)
self._n_envs = n_envs
self._running_envs = self._converter.zeros(n_envs, dtype=bool)
self._running_envs = self._array_backend.zeros(n_envs, dtype=bool)

super().__init__()

def get_mask(self, last):
mask = self._converter.ones(self._n_envs, dtype=bool)
mask = self._array_backend.ones(self._n_envs, dtype=bool)
terminated_episodes = (last & self._running_envs).sum()
running_episodes = (~last & self._running_envs).sum()

Expand All @@ -29,11 +29,11 @@ def get_mask(self, last):
missing_episodes_fit = self._n_episodes_per_fit - self._current_episodes_counter - running_episodes
max_runs = min(missing_episodes_fit, max_runs)

new_mask = self._converter.ones(terminated_episodes, dtype=bool)
new_mask = self._array_backend.ones(terminated_episodes, dtype=bool)
new_mask[max_runs:] = False
mask[last] = new_mask

self._running_envs = self._converter.copy(mask)
self._running_envs = self._array_backend.copy(mask)

return mask

Expand All @@ -59,12 +59,12 @@ def after_step(self, last):
def after_fit(self):
super().after_fit()
if self._n_episodes_per_fit is not None:
self._running_envs = self._converter.zeros(self._n_envs, dtype=bool)
self._running_envs = self._array_backend.zeros(self._n_envs, dtype=bool)

def _reset_counters(self):
super()._reset_counters()
self._running_envs = self._converter.zeros(self._n_envs, dtype=bool)
self._running_envs = self._array_backend.zeros(self._n_envs, dtype=bool)

@property
def converter(self):
return self._converter
return self._array_backend
12 changes: 6 additions & 6 deletions mushroom_rl/core/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def __init__(self, mdp_info, policy, is_episodic=False, backend='numpy'):

self.policy = policy
self.next_action = None
self._agent_converter = DataConversion.get_converter(backend)
self._env_converter = DataConversion.get_converter(self.mdp_info.backend)
self._agent_backend = ArrayBackend.get_array_backend(backend)
self._env_backend = ArrayBackend.get_array_backend(self.mdp_info.backend)

self._preprocessors = list()

Expand All @@ -60,8 +60,8 @@ def __init__(self, mdp_info, policy, is_episodic=False, backend='numpy'):
next_action='none',
mdp_info='mushroom',
_info='mushroom',
_agent_converter='primitive',
_env_converter='primitive',
_agent_backend='primitive',
_env_backend='primitive',
_preprocessors='mushroom',
_logger='none'
)
Expand Down Expand Up @@ -146,10 +146,10 @@ def preprocessors(self):
return self._preprocessors

def _convert_to_env_backend(self, array):
return self._env_converter.to_backend_array(self._agent_converter, array)
return self._env_backend.to_backend_array(self._agent_backend, array)

def _convert_to_agent_backend(self, array):
return self._agent_converter.to_backend_array(self._env_converter, array)
return self._agent_backend.to_backend_array(self._env_backend, array)

@property
def info(self):
Expand Down
Loading

0 comments on commit a5dfa3a

Please sign in to comment.