Major improvements and bugfixes
- Fixed bugs in the array backend, improved tests
- Fixed bugs in the vectorized dataset, added more functionality
- Fixed a bug in the collection of parameters
- Minor improvements in setup.py
boris-il-forte committed Dec 7, 2023
1 parent 1348147 commit a668767
Showing 10 changed files with 168 additions and 69 deletions.
12 changes: 8 additions & 4 deletions examples/segway_test_bbo.py
@@ -41,15 +41,19 @@ def experiment(alg, params, n_epochs, n_episodes, n_ep_per_fit):
dataset_callback = CollectDataset()
core = Core(agent, mdp, callbacks_fit=[dataset_callback])

dataset = core.evaluate(n_episodes=n_episodes)
J = np.mean(dataset.discounted_return)
p = dist.get_parameters()
logger.epoch_info(0, J=J, mu=p[:n_weights], sigma=p[n_weights:])

for i in trange(n_epochs, leave=False):
core.learn(n_episodes=n_episodes,
n_episodes_per_fit=n_ep_per_fit, render=False)
J = dataset_callback.get().discounted_return
core.learn(n_episodes=n_episodes, n_episodes_per_fit=n_ep_per_fit, render=False)
J = np.mean(dataset_callback.get().discounted_return)
dataset_callback.clean()

p = dist.get_parameters()

logger.epoch_info(i+1, J=np.mean(J), mu=p[:n_weights], sigma=p[n_weights:])
logger.epoch_info(i+1, J=J, mu=p[:n_weights], sigma=p[n_weights:])

logger.info('Press a button to visualize the segway...')
input()
11 changes: 7 additions & 4 deletions examples/vectorized_core/segway_bbo.py
@@ -41,11 +41,14 @@ def experiment(alg, params, n_epochs, n_episodes, n_ep_per_fit):
dataset_callback = CollectDataset()
core = VectorCore(agent, mdp, callbacks_fit=[dataset_callback])

dataset = core.evaluate(n_episodes=n_episodes)
J = np.mean(dataset.discounted_return)
p = dist.get_parameters()
logger.epoch_info(0, J=J, mu=p[:n_weights], sigma=p[n_weights:])

for i in trange(n_epochs, leave=False):
core.learn(n_episodes=n_episodes,
n_episodes_per_fit=n_ep_per_fit, render=False)
dataset = dataset_callback.get()
J = np.mean(dataset.discounted_return)
core.learn(n_episodes=n_episodes, n_episodes_per_fit=n_ep_per_fit, render=False)
J = np.mean(dataset_callback.get().discounted_return)
dataset_callback.clean()

p = dist.get_parameters()
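Note on the two example changes above: both scripts now log J = np.mean(dataset.discounted_return), once from an initial evaluation (epoch 0) and then from the fit callback after each epoch. As a quick reference, here is a self-contained sketch of what that average discounted return computes; the helper below is illustrative only and not part of mushroom_rl.

```python
import numpy as np

def discounted_return(rewards, gamma):
    """Discounted sum of one episode's rewards: sum_t gamma^t * r_t."""
    discounts = gamma ** np.arange(len(rewards))
    return float(np.sum(discounts * rewards))

# Two toy episodes; J is the mean discounted return over episodes,
# which is the quantity the examples log via logger.epoch_info(..., J=J, ...).
episodes = [np.array([1.0, 0.0, -1.0]), np.array([0.5, 0.5])]
gamma = 0.99
J = np.mean([discounted_return(r, gamma) for r in episodes])
print(J)
```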
17 changes: 8 additions & 9 deletions mushroom_rl/core/_impl/array_backend.py
@@ -60,7 +60,7 @@ def copy(array)
raise NotImplementedError

@staticmethod
def pack_padded_sequence(array, lengths):
def pack_padded_sequence(array, mask):
raise NotImplementedError


@@ -94,12 +94,11 @@ def copy(array):
return array.copy()

@staticmethod
def pack_padded_sequence(array, lengths):
def pack_padded_sequence(array, mask):
shape = array.shape

new_shape = (shape[0] * shape[1],) + shape[2:]
mask = (np.arange(len(array))[:, None] < lengths[None, :]).flatten(order='F')
return array.reshape(new_shape, order='F')[mask]
return array.reshape(new_shape, order='F')[mask.flatten(order='F')]


class TorchBackend(ArrayBackend):
@@ -132,12 +131,12 @@ def copy(array):
return array.clone()

@staticmethod
def pack_padded_sequence(array, lengths):
def pack_padded_sequence(array, mask):
shape = array.shape

new_shape = (shape[0]*shape[1], ) + shape[2:]
mask = (torch.arange(len(array), device=TorchUtils.get_device())[None, :] < lengths[:, None]).flatten()
return array.transpose(0,1).reshape(new_shape)[mask]

return array.transpose(0, 1).reshape(new_shape)[mask.transpose(0, 1).flatten()]


class ListBackend(ArrayBackend):
@@ -170,8 +169,8 @@ def copy(array):
return array.copy()

@staticmethod
def pack_padded_sequence(array, lengths):
return NumpyBackend.pack_padded_sequence(array, lengths)
def pack_padded_sequence(array, mask):
return NumpyBackend.pack_padded_sequence(array, np.array(mask))



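Note on the backend change above: pack_padded_sequence now takes an explicit boolean mask instead of per-environment lengths. A small self-contained demonstration of the NumPy variant; the function body mirrors the new implementation in the diff, while the toy data and the (n_steps, n_envs) layout are assumptions for illustration.

```python
import numpy as np

def pack_padded_sequence(array, mask):
    # array: (n_steps, n_envs, ...) padded data; mask: (n_steps, n_envs) validity flags.
    shape = array.shape
    new_shape = (shape[0] * shape[1],) + shape[2:]
    # Fortran-order flattening groups the steps per environment, so indexing with the
    # mask (flattened the same way) keeps each environment's valid steps in order.
    return array.reshape(new_shape, order='F')[mask.flatten(order='F')]

# Three padded steps for two environments; env 1 only produced two valid steps.
states = np.arange(6).reshape(3, 2)        # states[t, env]: env 0 -> 0, 2, 4; env 1 -> 1, 3, 5
mask = np.array([[True, True],
                 [True, True],
                 [True, False]])
print(pack_padded_sequence(states, mask))  # [0 2 4 1 3]: env 0's steps, then env 1's valid steps
```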
18 changes: 15 additions & 3 deletions mushroom_rl/core/_impl/list_dataset.py
@@ -4,11 +4,16 @@


class ListDataset(Serializable):
def __init__(self, is_stateful):
def __init__(self, is_stateful, is_vectorized):
self._dataset = list()
self._policy_dataset = list()
self._is_stateful = is_stateful

if is_vectorized:
self._mask = list()
else:
self._mask = None

self._add_save_attr(
_dataset='pickle',
_policy_dataset='pickle',
@@ -20,7 +25,7 @@ def from_array(cls, states, actions, rewards, next_states, absorbings, lasts, po
policy_next_states=None):
is_stateful = (policy_states is not None) and (policy_next_states is not None)

dataset = cls(is_stateful)
dataset = cls(is_stateful, False)

if dataset._is_stateful:
for s, a, r, ss, ab, last, ps, pss in zip(states, actions, rewards, next_states,
@@ -37,12 +42,15 @@ def __len__(self):
def __len__(self):
return len(self._dataset)

def append(self, *step):
def append(self, *step, mask=None):
step_copy = deepcopy(step)
self._dataset.append(step_copy[:6])
if self._is_stateful:
self._policy_dataset.append(step_copy[6:])

if mask is not None:
self._mask.append(mask)

def clear(self):
self._dataset = list()

@@ -105,6 +113,10 @@ def policy_next_state(self):
def is_stateful(self):
return self._is_stateful

@property
def mask(self):
return self._mask

@property
def n_episodes(self):
n_episodes = 0
18 changes: 16 additions & 2 deletions mushroom_rl/core/_impl/numpy_dataset.py
@@ -5,7 +5,7 @@

class NumpyDataset(Serializable):
def __init__(self, state_type, state_shape, action_type, action_shape, reward_shape, flag_shape,
policy_state_shape):
policy_state_shape, mask_shape):

self._state_type = state_type
self._action_type = action_type
@@ -25,6 +25,11 @@ def __init__(self, state_type, state_shape, action_type, action_shape, reward_sh
self._policy_states = np.empty(policy_state_shape, dtype=float)
self._policy_next_states = np.empty(policy_state_shape, dtype=float)

if mask_shape is None:
self._mask = None
else:
self._mask = np.empty(mask_shape, dtype=bool)

self._add_save_attr(
_state_type='primitive',
_action_type='primitive',
@@ -36,6 +41,7 @@ def __init__(self, state_type, state_shape, action_type, action_shape, reward_sh
_last='numpy',
_policy_states='numpy',
_policy_next_states='numpy',
_mask='numpy',
_len='primitive'
)

@@ -93,7 +99,8 @@ def from_array(cls, states, actions, rewards, next_states, absorbings, lasts,
def __len__(self):
return self._len

def append(self, state, action, reward, next_state, absorbing, last, policy_state=None, policy_next_state=None):
def append(self, state, action, reward, next_state, absorbing, last, policy_state=None, policy_next_state=None,
mask=None):
i = self._len

self._states[i] = state
@@ -107,6 +114,9 @@ def append(self, state, action, reward, next_state, absorbing, last, policy_stat
self._policy_states[i] = policy_state
self._policy_next_states[i] = policy_next_state

if mask is not None:
self._mask[i] = mask

self._len += 1

def clear(self):
@@ -194,6 +204,10 @@ def policy_state(self):
def policy_next_state(self):
return self._policy_next_states[:len(self)]

@property
def mask(self):
return self._mask[:len(self)]

@property
def is_stateful(self):
return self._policy_states is not None
35 changes: 25 additions & 10 deletions mushroom_rl/core/_impl/torch_dataset.py
@@ -6,24 +6,31 @@

class TorchDataset(Serializable):
def __init__(self, state_type, state_shape, action_type, action_shape, reward_shape, flag_shape,
policy_state_shape):
policy_state_shape, mask_shape):

device = TorchUtils.get_device()
self._state_type = state_type
self._action_type = action_type

self._states = torch.empty(*state_shape, dtype=self._state_type, device=TorchUtils.get_device())
self._actions = torch.empty(*action_shape, dtype=self._action_type, device=TorchUtils.get_device())
self._rewards = torch.empty(*reward_shape, dtype=torch.float, device=TorchUtils.get_device())
self._next_states = torch.empty(*state_shape, dtype=self._state_type, device=TorchUtils.get_device())
self._absorbing = torch.empty(flag_shape, dtype=torch.bool, device=TorchUtils.get_device())
self._last = torch.empty(flag_shape, dtype=torch.bool, device=TorchUtils.get_device())
self._states = torch.empty(*state_shape, dtype=self._state_type, device=device)
self._actions = torch.empty(*action_shape, dtype=self._action_type, device=device)
self._rewards = torch.empty(*reward_shape, dtype=torch.float, device=device)
self._next_states = torch.empty(*state_shape, dtype=self._state_type, device=device)
self._absorbing = torch.empty(flag_shape, dtype=torch.bool, device=device)
self._last = torch.empty(flag_shape, dtype=torch.bool, device=device)
self._len = 0

if policy_state_shape is None:
self._policy_states = None
self._policy_next_states = None
else:
self._policy_states = torch.empty(policy_state_shape, dtype=torch.float)
self._policy_next_states = torch.empty(policy_state_shape, dtype=torch.float)
self._policy_states = torch.empty(policy_state_shape, dtype=torch.float, device=device)
self._policy_next_states = torch.empty(policy_state_shape, dtype=torch.float, device=device)

if mask_shape is None:
self._mask = None
else:
self._mask = torch.empty(mask_shape, dtype=torch.bool, device=device)

self._add_save_attr(
_state_type='primitive',
@@ -93,7 +100,8 @@ def from_array(cls, states, actions, rewards, next_states, absorbings, lasts,
def __len__(self):
return self._len

def append(self, state, action, reward, next_state, absorbing, last, policy_state=None, policy_next_state=None):
def append(self, state, action, reward, next_state, absorbing, last, policy_state=None, policy_next_state=None,
mask=None):
i = self._len

self._states[i] = state
@@ -107,6 +115,9 @@ def append(self, state, action, reward, next_state, absorbing, last, policy_stat
self._policy_states[i] = policy_state
self._policy_next_states[i] = policy_next_state

if mask is not None:
self._mask[i] = mask

self._len += 1

def clear(self):
@@ -199,6 +210,10 @@ def policy_next_state(self):
def is_stateful(self):
return self._policy_states is not None

@property
def mask(self):
return self._mask[:len(self)]

@property
def n_episodes(self):
n_episodes = self.last.sum()
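Note on the dataset changes above: NumpyDataset and TorchDataset now take a mask_shape argument and keep a per-step boolean mask next to the other preallocated buffers, filled by append(..., mask=...). A stripped-down sketch of the pattern; TinyDataset is a hypothetical class for illustration and only tracks a reward buffer.

```python
import numpy as np

class TinyDataset:
    """Preallocated buffers plus an optional per-step mask, as in the vectorized datasets."""

    def __init__(self, n_samples, n_envs=None):
        shape = (n_samples,) if n_envs is None else (n_samples, n_envs)
        self._rewards = np.empty(shape, dtype=float)
        # The mask buffer is only allocated for vectorized (multi-environment) datasets.
        self._mask = None if n_envs is None else np.empty(shape, dtype=bool)
        self._len = 0

    def append(self, reward, mask=None):
        self._rewards[self._len] = reward
        if mask is not None:
            self._mask[self._len] = mask
        self._len += 1

    @property
    def mask(self):
        return self._mask[:self._len]

dataset = TinyDataset(n_samples=4, n_envs=3)
dataset.append(np.zeros(3), mask=np.array([True, True, False]))
print(dataset.mask)  # [[ True  True False]]
```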
47 changes: 24 additions & 23 deletions mushroom_rl/core/dataset.py
@@ -23,8 +23,10 @@ def __init__(self, mdp_info, agent_info, n_steps=None, n_episodes=None, n_envs=1

if n_envs == 1:
base_shape = (n_samples,)
mask_shape = None
else:
base_shape = (n_samples, n_envs)
mask_shape = base_shape

state_shape = base_shape + mdp_info.observation_space.shape
action_shape = base_shape + mdp_info.action_space.shape
@@ -44,12 +46,12 @@

if mdp_info.backend == 'numpy':
self._data = NumpyDataset(state_type, state_shape, action_type, action_shape, reward_shape, base_shape,
policy_state_shape)
policy_state_shape, mask_shape)
elif mdp_info.backend == 'torch':
self._data = TorchDataset(state_type, state_shape, action_type, action_shape, reward_shape, base_shape,
policy_state_shape)
policy_state_shape, mask_shape)
else:
self._data = ListDataset(policy_state_shape is not None)
self._data = ListDataset(policy_state_shape is not None, mask_shape is not None)

self._gamma = mdp_info.gamma

@@ -435,18 +437,14 @@ class VectorizedDataset(Dataset):
def __init__(self, mdp_info, agent_info, n_envs, n_steps=None, n_episodes=None):
super().__init__(mdp_info, agent_info, n_steps, n_episodes, n_envs)

self._length = self._array_backend.zeros(n_envs, dtype=int)

self._add_save_attr(
_length=mdp_info.backend
)

self._initialize_theta_list(n_envs)

def append_vectorized(self, step, info, mask):
self.append(step, {}) # FIXME: handle properly info
def append(self, step, info):
raise RuntimeError("Trying to use append on a vectorized dataset")

self._length[mask] += 1
def append_vectorized(self, step, info, mask):
self._data.append(*step, mask=mask)
self._append_info(self._info, {}) # FIXME: handle properly info

def append_theta_vectorized(self, theta, mask):
for i in range(len(theta)):
@@ -456,30 +454,28 @@ def append_theta_vectorized(self, theta, mask):
def clear(self):
n_envs = len(self._theta_list)
super().clear()

self._length = self._array_backend.zeros(len(self._length), dtype=int)
self._initialize_theta_list(n_envs)

def flatten(self):
if len(self) == 0:
return None

states = self._array_backend.pack_padded_sequence(self._data.state, self._length)
actions = self._array_backend.pack_padded_sequence(self._data.action, self._length)
rewards = self._array_backend.pack_padded_sequence(self._data.reward, self._length)
next_states = self._array_backend.pack_padded_sequence(self._data.next_state, self._length)
absorbings = self._array_backend.pack_padded_sequence(self._data.absorbing, self._length)
states = self._array_backend.pack_padded_sequence(self._data.state, self._data.mask)
actions = self._array_backend.pack_padded_sequence(self._data.action, self._data.mask)
rewards = self._array_backend.pack_padded_sequence(self._data.reward, self._data.mask)
next_states = self._array_backend.pack_padded_sequence(self._data.next_state, self._data.mask)
absorbings = self._array_backend.pack_padded_sequence(self._data.absorbing, self._data.mask)

last_padded = self._data.last
last_padded[self._length-1, :] = True
lasts = self._array_backend.pack_padded_sequence(last_padded, self._length)
last_padded[-1, :] = True
lasts = self._array_backend.pack_padded_sequence(last_padded, self._data.mask)

policy_state = None
policy_next_state = None

if self._data.is_stateful:
policy_state = self._array_backend.pack_padded_sequence(self._data.policy_state, self._length)
policy_next_state = self._array_backend.pack_padded_sequence(self._data.policy_next_state, self._length)
policy_state = self._array_backend.pack_padded_sequence(self._data.policy_state, self._data.mask)
policy_next_state = self._array_backend.pack_padded_sequence(self._data.policy_next_state, self._data.mask)

flat_theta_list = self._flatten_theta_list()

@@ -497,9 +493,14 @@ def _flatten_theta_list(self):
return flat_theta_list

def _initialize_theta_list(self, n_envs):
self._theta_list = list()
for i in range(n_envs):
self._theta_list.append(list())

@property
def mask(self):
return self._data.mask
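
Note on the flatten() change above: every field is now packed with the stored mask rather than the old per-environment length counter, and the final padded row of the last flags is forced to True (last_padded[-1, :] = True) before packing. A plain-NumPy sketch of what that forcing does, assuming the (n_steps, n_envs) layout; this is illustrative only, not the mushroom_rl API.

```python
import numpy as np

# last[t, env]: episode-end flags; mask[t, env]: valid-step flags.
last = np.array([[False, False],
                 [False, True],     # env 1 ended an episode at step 1 ...
                 [False, False]])
mask = np.array([[True, True],
                 [True, True],
                 [True, False]])    # ... and collected nothing at step 2

last[-1, :] = True  # mark the padding boundary as an episode end
packed = last.reshape(6, order='F')[mask.flatten(order='F')]
print(packed)  # [False False  True False  True]
```

After packing, env 0 (still running when collection stopped) ends with last=True, while the forced flag on env 1's masked-out final row is simply dropped.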



