Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fix bug: wrong dataset shapes #157

Merged
Merged 2 commits on Dec 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mushroom_rl/core/_impl/core_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ def after_step(self, last):
self._current_episodes_counter += 1
self._episodes_progress_bar.update(1)

def after_fit(self):
def after_fit(self, n_carry_forward_steps=0):
self._current_episodes_counter = 0
self._current_steps_counter = 0
self._current_steps_counter = n_carry_forward_steps

def terminate_run(self):
self._steps_progress_bar.close()
Expand Down
39 changes: 27 additions & 12 deletions mushroom_rl/core/_impl/numpy_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,21 +114,36 @@ def append(self, state, action, reward, next_state, absorbing, last, policy_stat
mask=None):
i = self._len

self._states[i] = state
self._actions[i] = action
self._rewards[i] = reward
self._next_states[i] = next_state
self._absorbing[i] = absorbing
self._last[i] = last
if mask is None:
self._states[i] = state
self._actions[i] = action
self._rewards[i] = reward
self._next_states[i] = next_state
self._absorbing[i] = absorbing
self._last[i] = last

if self.is_stateful:
self._policy_states[i] = policy_state
self._policy_next_states[i] = policy_next_state
if self.is_stateful:
self._policy_states[i] = policy_state
self._policy_next_states[i] = policy_next_state
else:
assert (policy_state is None) and (policy_next_state is None)
else:
assert (policy_state is None) and (policy_next_state is None)
n_active_envs = self._states.shape[1]

self._states[i] = state[:n_active_envs]
self._actions[i] = action[:n_active_envs]
self._rewards[i] = reward[:n_active_envs]
self._next_states[i] = next_state[:n_active_envs]
self._absorbing[i] = absorbing[:n_active_envs]
self._last[i] = last[:n_active_envs]

if self.is_stateful:
self._policy_states[i] = policy_state[:n_active_envs]
self._policy_next_states[i] = policy_next_state[:n_active_envs]
else:
assert (policy_state is None) and (policy_next_state is None)

if mask is not None:
self._mask[i] = mask
self._mask[i] = mask[:n_active_envs]

self._len += 1

Expand Down
35 changes: 24 additions & 11 deletions mushroom_rl/core/_impl/torch_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,19 +116,32 @@ def append(self, state, action, reward, next_state, absorbing, last, policy_stat
mask=None):
i = self._len # todo: handle index out of bounds?

self._states[i] = state
self._actions[i] = action
self._rewards[i] = reward
self._next_states[i] = next_state
self._absorbing[i] = absorbing
self._last[i] = last
if mask is None:
self._states[i] = state
self._actions[i] = action
self._rewards[i] = reward
self._next_states[i] = next_state
self._absorbing[i] = absorbing
self._last[i] = last

if self.is_stateful:
self._policy_states[i] = policy_state
self._policy_next_states[i] = policy_next_state
if self.is_stateful:
self._policy_states[i] = policy_state
self._policy_next_states[i] = policy_next_state
else:
n_active_envs = self._states.shape[1]

self._states[i] = state[:n_active_envs]
self._actions[i] = action[:n_active_envs]
self._rewards[i] = reward[:n_active_envs]
self._next_states[i] = next_state[:n_active_envs]
self._absorbing[i] = absorbing[:n_active_envs]
self._last[i] = last[:n_active_envs]

if self.is_stateful:
self._policy_states[i] = policy_state[:n_active_envs]
self._policy_next_states[i] = policy_next_state[:n_active_envs]

if mask is not None:
self._mask[i] = mask
self._mask[i] = mask[:n_active_envs]

self._len += 1

Expand Down
4 changes: 2 additions & 2 deletions mushroom_rl/core/_impl/vectorized_core_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def after_step(self, last):
self._current_episodes_counter += completed
self._episodes_progress_bar.update(completed)

def after_fit_vectorized(self, last):
super().after_fit()
def after_fit_vectorized(self, last, n_carry_forward_steps):
super().after_fit(n_carry_forward_steps)
if self._n_episodes_per_fit is not None:
self._running_envs = self._array_backend.zeros(self._n_envs, dtype=bool)
return self._array_backend.ones(self._n_envs, dtype=bool)
Expand Down
4 changes: 2 additions & 2 deletions mushroom_rl/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def learn(self, n_steps=None, n_episodes=None, n_steps_per_fit=None, n_episodes_
assert (render and record) or (not record), "To record, the render flag must be set to true"
self._core_logic.initialize_learn(n_steps_per_fit, n_episodes_per_fit)

dataset = Dataset.generate(self.env.info, self.agent.info, n_steps_per_fit, n_episodes_per_fit)
dataset = Dataset.generate(self.env.info, self.agent.info, n_steps_per_fit, n_episodes_per_fit, core_counts_episodes=n_episodes is not None)

self._run(dataset, n_steps, n_episodes, render, quiet, record)

Expand Down Expand Up @@ -91,7 +91,7 @@ def evaluate(self, initial_states=None, n_steps=None, n_episodes=None, render=Fa
self._core_logic.initialize_evaluate()

n_episodes_dataset = len(initial_states) if initial_states is not None else n_episodes
dataset = Dataset.generate(self.env.info, self.agent.info, n_steps, n_episodes_dataset)
dataset = Dataset.generate(self.env.info, self.agent.info, n_steps, n_episodes_dataset, core_counts_episodes=n_episodes is not None)

return self._run(dataset, n_steps, n_episodes, render, quiet, record, initial_states)

Expand Down
34 changes: 25 additions & 9 deletions mushroom_rl/core/dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import math

from collections import defaultdict

Expand Down Expand Up @@ -75,7 +76,7 @@ def create_replay_memory_info(mdp_info, agent_info, device=None):


class Dataset(Serializable):
def __init__(self, dataset_info, n_steps=None, n_episodes=None):
def __init__(self, dataset_info, n_steps=None, n_episodes=None, core_counts_episodes=False):
assert (n_steps is not None and n_episodes is None) or (n_steps is None and n_episodes is not None)

self._array_backend = ArrayBackend.get_array_backend(dataset_info.backend)
Expand All @@ -91,8 +92,16 @@ def __init__(self, dataset_info, n_steps=None, n_episodes=None):
if dataset_info.n_envs == 1:
base_shape = (n_samples,)
mask_shape = None
elif n_episodes:
horizon = dataset_info.horizon
x = math.ceil(n_episodes / dataset_info.n_envs)
base_shape = (x * horizon, min(n_episodes, dataset_info.n_envs))
mask_shape = base_shape
elif core_counts_episodes:
base_shape = (math.ceil(n_samples / dataset_info.n_envs) + 1 + dataset_info.horizon, dataset_info.n_envs)
mask_shape = base_shape
else:
base_shape = (n_samples, dataset_info.n_envs)
base_shape = (math.ceil(n_samples / dataset_info.n_envs) + 1, dataset_info.n_envs)
mask_shape = base_shape

state_shape = base_shape + dataset_info.state_shape
Expand All @@ -104,8 +113,8 @@ def __init__(self, dataset_info, n_steps=None, n_episodes=None):
else:
policy_state_shape = None

self._info = ExtraInfo(dataset_info.n_envs, dataset_info.backend, dataset_info.device)
self._episode_info = ExtraInfo(dataset_info.n_envs, dataset_info.backend, dataset_info.device)
self._info = ExtraInfo(min(n_episodes, dataset_info.n_envs) if n_episodes else dataset_info.n_envs, dataset_info.backend, dataset_info.device)
self._episode_info = ExtraInfo(min(n_episodes, dataset_info.n_envs) if n_episodes else dataset_info.n_envs, dataset_info.backend, dataset_info.device)
self._theta_list = list()

if dataset_info.backend == 'numpy':
Expand All @@ -127,10 +136,10 @@ def __init__(self, dataset_info, n_steps=None, n_episodes=None):
self._add_all_save_attr()

@classmethod
def generate(cls, mdp_info, agent_info, n_steps=None, n_episodes=None, n_envs=1):
def generate(cls, mdp_info, agent_info, n_steps=None, n_episodes=None, n_envs=1, core_counts_episodes=False):
dataset_info = DatasetInfo.create_dataset_info(mdp_info, agent_info, n_envs)

return cls(dataset_info, n_steps, n_episodes)
return cls(dataset_info, n_steps, n_episodes, core_counts_episodes)

@classmethod
def create_raw_instance(cls, dataset=None):
Expand Down Expand Up @@ -545,8 +554,8 @@ def _merge_info(info, other_info):


class VectorizedDataset(Dataset):
def __init__(self, dataset_info, n_steps=None, n_episodes=None):
super().__init__(dataset_info, n_steps, n_episodes)
def __init__(self, dataset_info, n_steps=None, n_episodes=None, core_counts_episodes=False):
super().__init__(dataset_info, n_steps, n_episodes, core_counts_episodes)

self._initialize_theta_list(self._dataset_info.n_envs)

Expand All @@ -564,6 +573,7 @@ def append_theta_vectorized(self, theta, mask):

def clear(self, n_steps_per_fit=None):
n_envs = len(self._theta_list)
n_carry_forward_steps = 0

residual_data = None
if n_steps_per_fit is not None:
Expand All @@ -576,12 +586,16 @@ def clear(self, n_steps_per_fit=None):
residual_data = self._data.get_view(view_size, copy=True)
mask = residual_data.mask
original_shape = mask.shape
mask.flatten()[n_extra_steps:] = False
mask = mask.flatten()
true_indices = self._array_backend.where(mask)[0]
mask[true_indices[n_extra_steps:]] = False
residual_data.mask = mask.reshape(original_shape)

residual_info = self._info.get_view(view_size, copy=True)
residual_episode_info = self._episode_info.get_view(view_size, copy=True)

n_carry_forward_steps = mask.sum()

super().clear()
self._initialize_theta_list(n_envs)

Expand All @@ -590,6 +604,8 @@ def clear(self, n_steps_per_fit=None):
self._info = residual_info
self._episode_info = residual_episode_info

return n_carry_forward_steps

def flatten(self, n_steps_per_fit=None):
if len(self) == 0:
return None
Expand Down
7 changes: 6 additions & 1 deletion mushroom_rl/core/extra_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,10 @@ def _append_dict_to_output(self, output, step_data, index, to):
for key, key_path in self._key_mapping.items():
value = self._find_element_by_key_path(step_data, key_path)
value = self._convert(value, to)
output[key][index] = value
if value is ArrayBackend.get_array_backend(to).none() or self._n_envs == 1:
output[key][index] = value
else:
output[key][index] = value[:self._n_envs]

def _append_list_to_output(self, output, step_data, index, to):
"""
Expand All @@ -318,6 +321,8 @@ def _append_list_to_output(self, output, step_data, index, to):
assert(self._n_envs > 1)
for key, key_path in self._key_mapping.items():
for i, env_data in enumerate(step_data):
if i >= self._n_envs:
break
value = self._find_element_by_key_path(env_data, key_path)
value = self._convert(value, to)
output[key][index][i] = value
Expand Down
8 changes: 4 additions & 4 deletions mushroom_rl/core/vectorized_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def learn(self, n_steps=None, n_episodes=None, n_steps_per_fit=None, n_episodes_
self._core_logic.initialize_learn(n_steps_per_fit, n_episodes_per_fit)

dataset = VectorizedDataset.generate(self.env.info, self.agent.info,
n_steps_per_fit, n_episodes_per_fit, self.env.number)
n_steps_per_fit, n_episodes_per_fit, self.env.number, n_episodes is not None)

self._run(dataset, n_steps, n_episodes, render, quiet, record)

Expand Down Expand Up @@ -93,7 +93,7 @@ def evaluate(self, initial_states=None, n_steps=None, n_episodes=None, render=Fa

n_episodes_dataset = len(initial_states) if initial_states is not None else n_episodes
dataset = VectorizedDataset.generate(self.env.info, self.agent.info,
n_steps, n_episodes_dataset, self.env.number)
n_steps, n_episodes_dataset, self.env.number, n_episodes is not None)

return self._run(dataset, n_steps, n_episodes, render, quiet, record, initial_states)

Expand Down Expand Up @@ -123,12 +123,12 @@ def _run(self, dataset, n_steps, n_episodes, render, quiet, record, initial_stat
if self._core_logic.fit_required():
fit_dataset = dataset.flatten(self._core_logic.n_steps_per_fit)
self.agent.fit(fit_dataset)
last = self._core_logic.after_fit_vectorized(last)

for c in self.callbacks_fit:
c(dataset)

dataset.clear(self._core_logic.n_steps_per_fit)
n_carry_forward_steps = dataset.clear(self._core_logic.n_steps_per_fit)
last = self._core_logic.after_fit_vectorized(last, n_carry_forward_steps)

self.agent.stop()
self.env.stop()
Expand Down
Loading