Skip to content

Commit

Permalink
Fix bug: wrong dataset shapes (#157)
Browse files Browse the repository at this point in the history
* Fix bug: wrong dataset shapes

* Adapt for using n_episodes
  • Loading branch information
Bjarne-55 authored Dec 24, 2024
1 parent 0a48399 commit edd0bd5
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 43 deletions.
4 changes: 2 additions & 2 deletions mushroom_rl/core/_impl/core_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ def after_step(self, last):
self._current_episodes_counter += 1
self._episodes_progress_bar.update(1)

def after_fit(self, n_carry_forward_steps=0):
    """
    Reset the per-fit progress counters after the agent has been fit.

    Args:
        n_carry_forward_steps (int, 0): number of residual steps already
            collected beyond the last fit boundary; they are carried forward
            as the starting step count of the next fit interval instead of
            being discarded.
    """
    self._current_episodes_counter = 0
    # Start the next interval from the leftover steps, not from zero.
    self._current_steps_counter = n_carry_forward_steps

def terminate_run(self):
self._steps_progress_bar.close()
Expand Down
39 changes: 27 additions & 12 deletions mushroom_rl/core/_impl/numpy_dataset.py
Original file line number Diff line number Diff line change
# NOTE(review): signature reconstructed from a truncated diff header — confirm
# the exact parameter list/defaults against the repository.
def append(self, state, action, reward, next_state, absorbing, last, policy_state, policy_next_state,
           mask=None):
    """
    Append one step of data to the pre-allocated numpy buffers.

    Args:
        state, action, reward, next_state, absorbing, last: per-step data.
            With a vectorized dataset each input carries a leading
            environment axis;
        policy_state, policy_next_state: policy internal states; stored only
            when the dataset is stateful, otherwise both must be None;
        mask (None): per-environment validity mask. When given, every input
            is truncated to the number of environment columns actually
            allocated in the buffers (inputs may come from more running
            environments than the dataset stores).
    """
    i = self._len

    if mask is None:
        self._states[i] = state
        self._actions[i] = action
        self._rewards[i] = reward
        self._next_states[i] = next_state
        self._absorbing[i] = absorbing
        self._last[i] = last

        if self.is_stateful:
            self._policy_states[i] = policy_state
            self._policy_next_states[i] = policy_next_state
        else:
            assert (policy_state is None) and (policy_next_state is None)
    else:
        # The buffers may hold fewer environment columns than the inputs
        # provide; keep only the columns that were allocated.
        n_active_envs = self._states.shape[1]

        self._states[i] = state[:n_active_envs]
        self._actions[i] = action[:n_active_envs]
        self._rewards[i] = reward[:n_active_envs]
        self._next_states[i] = next_state[:n_active_envs]
        self._absorbing[i] = absorbing[:n_active_envs]
        self._last[i] = last[:n_active_envs]

        if self.is_stateful:
            self._policy_states[i] = policy_state[:n_active_envs]
            self._policy_next_states[i] = policy_next_state[:n_active_envs]
        else:
            assert (policy_state is None) and (policy_next_state is None)

        # mask is not None exactly in this branch, so store it here instead
        # of re-testing it after the if/else.
        self._mask[i] = mask[:n_active_envs]

    self._len += 1

Expand Down
35 changes: 24 additions & 11 deletions mushroom_rl/core/_impl/torch_dataset.py
Original file line number Diff line number Diff line change
# NOTE(review): signature reconstructed from a truncated diff header — confirm
# the exact parameter list/defaults against the repository.
def append(self, state, action, reward, next_state, absorbing, last, policy_state, policy_next_state,
           mask=None):
    """
    Append one step of data to the pre-allocated torch buffers.

    Args:
        state, action, reward, next_state, absorbing, last: per-step data.
            With a vectorized dataset each input carries a leading
            environment axis;
        policy_state, policy_next_state: policy internal states; stored only
            when the dataset is stateful, otherwise both must be None;
        mask (None): per-environment validity mask. When given, every input
            is truncated to the number of environment columns actually
            allocated in the buffers.
    """
    i = self._len  # todo: handle index out of bounds?

    if mask is None:
        self._states[i] = state
        self._actions[i] = action
        self._rewards[i] = reward
        self._next_states[i] = next_state
        self._absorbing[i] = absorbing
        self._last[i] = last

        if self.is_stateful:
            self._policy_states[i] = policy_state
            self._policy_next_states[i] = policy_next_state
        else:
            # Consistency with the numpy backend: a stateless dataset must
            # not silently drop policy states.
            assert (policy_state is None) and (policy_next_state is None)
    else:
        # The buffers may hold fewer environment columns than the inputs
        # provide; keep only the columns that were allocated.
        n_active_envs = self._states.shape[1]

        self._states[i] = state[:n_active_envs]
        self._actions[i] = action[:n_active_envs]
        self._rewards[i] = reward[:n_active_envs]
        self._next_states[i] = next_state[:n_active_envs]
        self._absorbing[i] = absorbing[:n_active_envs]
        self._last[i] = last[:n_active_envs]

        if self.is_stateful:
            self._policy_states[i] = policy_state[:n_active_envs]
            self._policy_next_states[i] = policy_next_state[:n_active_envs]
        else:
            assert (policy_state is None) and (policy_next_state is None)

        # mask is not None exactly in this branch, so store it here instead
        # of re-testing it after the if/else.
        self._mask[i] = mask[:n_active_envs]

    self._len += 1

Expand Down
4 changes: 2 additions & 2 deletions mushroom_rl/core/_impl/vectorized_core_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def after_step(self, last):
self._current_episodes_counter += completed
self._episodes_progress_bar.update(completed)

def after_fit_vectorized(self, last, n_carry_forward_steps):
    """
    Reset the per-fit counters after a fit in the vectorized core, carrying
    forward any residual steps.

    Args:
        last: per-environment episode-end flags from the current step;
        n_carry_forward_steps (int): residual steps collected beyond the
            fit boundary, forwarded to the base-class counter reset.

    Returns:
        The per-environment flags to use as the new ``last`` array.
    """
    super().after_fit(n_carry_forward_steps)
    if self._n_episodes_per_fit is not None:
        # Episode-based fitting: restart all environments from scratch.
        self._running_envs = self._array_backend.zeros(self._n_envs, dtype=bool)
        return self._array_backend.ones(self._n_envs, dtype=bool)
    # NOTE(review): this hunk ends here in the diff view; confirm against the
    # full file whether a step-based branch (e.g. ``return last``) follows.
Expand Down
4 changes: 2 additions & 2 deletions mushroom_rl/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def learn(self, n_steps=None, n_episodes=None, n_steps_per_fit=None, n_episodes_
assert (render and record) or (not record), "To record, the render flag must be set to true"
self._core_logic.initialize_learn(n_steps_per_fit, n_episodes_per_fit)

dataset = Dataset.generate(self.env.info, self.agent.info, n_steps_per_fit, n_episodes_per_fit)
dataset = Dataset.generate(self.env.info, self.agent.info, n_steps_per_fit, n_episodes_per_fit, core_counts_episodes=n_episodes is not None)

self._run(dataset, n_steps, n_episodes, render, quiet, record)

Expand Down Expand Up @@ -91,7 +91,7 @@ def evaluate(self, initial_states=None, n_steps=None, n_episodes=None, render=Fa
self._core_logic.initialize_evaluate()

n_episodes_dataset = len(initial_states) if initial_states is not None else n_episodes
dataset = Dataset.generate(self.env.info, self.agent.info, n_steps, n_episodes_dataset)
dataset = Dataset.generate(self.env.info, self.agent.info, n_steps, n_episodes_dataset, core_counts_episodes=n_episodes is not None)

return self._run(dataset, n_steps, n_episodes, render, quiet, record, initial_states)

Expand Down
34 changes: 25 additions & 9 deletions mushroom_rl/core/dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import math

from collections import defaultdict

Expand Down Expand Up @@ -75,7 +76,7 @@ def create_replay_memory_info(mdp_info, agent_info, device=None):


class Dataset(Serializable):
def __init__(self, dataset_info, n_steps=None, n_episodes=None):
def __init__(self, dataset_info, n_steps=None, n_episodes=None, core_counts_episodes=False):
assert (n_steps is not None and n_episodes is None) or (n_steps is None and n_episodes is not None)

self._array_backend = ArrayBackend.get_array_backend(dataset_info.backend)
Expand All @@ -91,8 +92,16 @@ def __init__(self, dataset_info, n_steps=None, n_episodes=None):
if dataset_info.n_envs == 1:
base_shape = (n_samples,)
mask_shape = None
elif n_episodes:
horizon = dataset_info.horizon
x = math.ceil(n_episodes / dataset_info.n_envs)
base_shape = (x * horizon, min(n_episodes, dataset_info.n_envs))
mask_shape = base_shape
elif core_counts_episodes:
base_shape = (math.ceil(n_samples / dataset_info.n_envs) + 1 + dataset_info.horizon, dataset_info.n_envs)
mask_shape = base_shape
else:
base_shape = (n_samples, dataset_info.n_envs)
base_shape = (math.ceil(n_samples / dataset_info.n_envs) + 1, dataset_info.n_envs)
mask_shape = base_shape

state_shape = base_shape + dataset_info.state_shape
Expand All @@ -104,8 +113,8 @@ def __init__(self, dataset_info, n_steps=None, n_episodes=None):
else:
policy_state_shape = None

self._info = ExtraInfo(dataset_info.n_envs, dataset_info.backend, dataset_info.device)
self._episode_info = ExtraInfo(dataset_info.n_envs, dataset_info.backend, dataset_info.device)
self._info = ExtraInfo(min(n_episodes, dataset_info.n_envs) if n_episodes else dataset_info.n_envs, dataset_info.backend, dataset_info.device)
self._episode_info = ExtraInfo(min(n_episodes, dataset_info.n_envs) if n_episodes else dataset_info.n_envs, dataset_info.backend, dataset_info.device)
self._theta_list = list()

if dataset_info.backend == 'numpy':
Expand All @@ -127,10 +136,10 @@ def __init__(self, dataset_info, n_steps=None, n_episodes=None):
self._add_all_save_attr()

@classmethod
def generate(cls, mdp_info, agent_info, n_steps=None, n_episodes=None, n_envs=1, core_counts_episodes=False):
    """
    Build a dataset sized for the given run configuration.

    Args:
        mdp_info: environment information used to derive the buffer shapes;
        agent_info: agent information (backend, policy-state shape, ...);
        n_steps (None): number of steps to allocate, mutually exclusive
            with ``n_episodes``;
        n_episodes (None): number of episodes to allocate;
        n_envs (int, 1): number of parallel environments;
        core_counts_episodes (bool, False): whether the core counts progress
            in episodes, which requires extra per-environment slack in the
            buffer shape.

    Returns:
        A new dataset instance of type ``cls``.
    """
    dataset_info = DatasetInfo.create_dataset_info(mdp_info, agent_info, n_envs)

    return cls(dataset_info, n_steps, n_episodes, core_counts_episodes)

@classmethod
def create_raw_instance(cls, dataset=None):
Expand Down Expand Up @@ -545,8 +554,8 @@ def _merge_info(info, other_info):


class VectorizedDataset(Dataset):
def __init__(self, dataset_info, n_steps=None, n_episodes=None, core_counts_episodes=False):
    """
    Constructor.

    Args:
        dataset_info: description of the dataset layout (backend, shapes,
            number of environments);
        n_steps (None): number of steps to allocate, mutually exclusive
            with ``n_episodes``;
        n_episodes (None): number of episodes to allocate;
        core_counts_episodes (bool, False): whether the core counts progress
            in episodes; forwarded to the base class for buffer sizing.
    """
    super().__init__(dataset_info, n_steps, n_episodes, core_counts_episodes)

    # One theta entry per parallel environment.
    self._initialize_theta_list(self._dataset_info.n_envs)

Expand All @@ -564,6 +573,7 @@ def append_theta_vectorized(self, theta, mask):

def clear(self, n_steps_per_fit=None):
n_envs = len(self._theta_list)
n_carry_forward_steps = 0

residual_data = None
if n_steps_per_fit is not None:
Expand All @@ -576,12 +586,16 @@ def clear(self, n_steps_per_fit=None):
residual_data = self._data.get_view(view_size, copy=True)
mask = residual_data.mask
original_shape = mask.shape
mask.flatten()[n_extra_steps:] = False
mask = mask.flatten()
true_indices = self._array_backend.where(mask)[0]
mask[true_indices[n_extra_steps:]] = False
residual_data.mask = mask.reshape(original_shape)

residual_info = self._info.get_view(view_size, copy=True)
residual_episode_info = self._episode_info.get_view(view_size, copy=True)

n_carry_forward_steps = mask.sum()

super().clear()
self._initialize_theta_list(n_envs)

Expand All @@ -590,6 +604,8 @@ def clear(self, n_steps_per_fit=None):
self._info = residual_info
self._episode_info = residual_episode_info

return n_carry_forward_steps

def flatten(self, n_steps_per_fit=None):
if len(self) == 0:
return None
Expand Down
7 changes: 6 additions & 1 deletion mushroom_rl/core/extra_info.py
Original file line number Diff line number Diff line change
def _append_dict_to_output(self, output, step_data, index, to):
    """
    Store one step of dict-shaped extra info into ``output`` at ``index``,
    converting each value to the target backend.

    Args:
        output (dict): destination mapping from key to pre-allocated array;
        step_data (dict): raw step info, navigated via the stored key paths;
        index: row to write in each output array;
        to (str): name of the target array backend.
    """
    for key, key_path in self._key_mapping.items():
        value = self._find_element_by_key_path(step_data, key_path)
        value = self._convert(value, to)
        # The backend "none" placeholder has no environment axis to slice,
        # and a single-env dataset stores the value as-is; otherwise keep
        # only the environments this info container tracks.
        if value is ArrayBackend.get_array_backend(to).none() or self._n_envs == 1:
            output[key][index] = value
        else:
            output[key][index] = value[:self._n_envs]

def _append_list_to_output(self, output, step_data, index, to):
"""
Expand All @@ -318,6 +321,8 @@ def _append_list_to_output(self, output, step_data, index, to):
assert(self._n_envs > 1)
for key, key_path in self._key_mapping.items():
for i, env_data in enumerate(step_data):
if i >= self._n_envs:
break
value = self._find_element_by_key_path(env_data, key_path)
value = self._convert(value, to)
output[key][index][i] = value
Expand Down
8 changes: 4 additions & 4 deletions mushroom_rl/core/vectorized_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def learn(self, n_steps=None, n_episodes=None, n_steps_per_fit=None, n_episodes_
self._core_logic.initialize_learn(n_steps_per_fit, n_episodes_per_fit)

dataset = VectorizedDataset.generate(self.env.info, self.agent.info,
n_steps_per_fit, n_episodes_per_fit, self.env.number)
n_steps_per_fit, n_episodes_per_fit, self.env.number, n_episodes is not None)

self._run(dataset, n_steps, n_episodes, render, quiet, record)

Expand Down Expand Up @@ -93,7 +93,7 @@ def evaluate(self, initial_states=None, n_steps=None, n_episodes=None, render=Fa

n_episodes_dataset = len(initial_states) if initial_states is not None else n_episodes
dataset = VectorizedDataset.generate(self.env.info, self.agent.info,
n_steps, n_episodes_dataset, self.env.number)
n_steps, n_episodes_dataset, self.env.number, n_episodes is not None)

return self._run(dataset, n_steps, n_episodes, render, quiet, record, initial_states)

Expand Down Expand Up @@ -123,12 +123,12 @@ def _run(self, dataset, n_steps, n_episodes, render, quiet, record, initial_stat
if self._core_logic.fit_required():
fit_dataset = dataset.flatten(self._core_logic.n_steps_per_fit)
self.agent.fit(fit_dataset)
last = self._core_logic.after_fit_vectorized(last)

for c in self.callbacks_fit:
c(dataset)

dataset.clear(self._core_logic.n_steps_per_fit)
n_carry_forward_steps = dataset.clear(self._core_logic.n_steps_per_fit)
last = self._core_logic.after_fit_vectorized(last, n_carry_forward_steps)

self.agent.stop()
self.env.stop()
Expand Down

0 comments on commit edd0bd5

Please sign in to comment.