From edd0bd53c84b3e9912a41c96dbc14ed8643e2113 Mon Sep 17 00:00:00 2001 From: Bjarne-55 <73470930+Bjarne-55@users.noreply.github.com> Date: Tue, 24 Dec 2024 16:55:37 +0100 Subject: [PATCH] Fix bug wrong dataset shapes (#157) * Fix bug wrong dataset shapes * Adapt for using n_episodes --- mushroom_rl/core/_impl/core_logic.py | 4 +- mushroom_rl/core/_impl/numpy_dataset.py | 39 +++++++++++++------ mushroom_rl/core/_impl/torch_dataset.py | 35 +++++++++++------ .../core/_impl/vectorized_core_logic.py | 4 +- mushroom_rl/core/core.py | 4 +- mushroom_rl/core/dataset.py | 34 +++++++++++----- mushroom_rl/core/extra_info.py | 7 +++- mushroom_rl/core/vectorized_core.py | 8 ++-- 8 files changed, 92 insertions(+), 43 deletions(-) diff --git a/mushroom_rl/core/_impl/core_logic.py b/mushroom_rl/core/_impl/core_logic.py index db187e38..0e64f1d5 100644 --- a/mushroom_rl/core/_impl/core_logic.py +++ b/mushroom_rl/core/_impl/core_logic.py @@ -71,9 +71,9 @@ def after_step(self, last): self._current_episodes_counter += 1 self._episodes_progress_bar.update(1) - def after_fit(self): + def after_fit(self, n_carry_forward_steps=0): self._current_episodes_counter = 0 - self._current_steps_counter = 0 + self._current_steps_counter = n_carry_forward_steps def terminate_run(self): self._steps_progress_bar.close() diff --git a/mushroom_rl/core/_impl/numpy_dataset.py b/mushroom_rl/core/_impl/numpy_dataset.py index 4e07b2f3..05bc6652 100644 --- a/mushroom_rl/core/_impl/numpy_dataset.py +++ b/mushroom_rl/core/_impl/numpy_dataset.py @@ -114,21 +114,36 @@ def append(self, state, action, reward, next_state, absorbing, last, policy_stat mask=None): i = self._len - self._states[i] = state - self._actions[i] = action - self._rewards[i] = reward - self._next_states[i] = next_state - self._absorbing[i] = absorbing - self._last[i] = last + if mask is None: + self._states[i] = state + self._actions[i] = action + self._rewards[i] = reward + self._next_states[i] = next_state + self._absorbing[i] = absorbing + self._last[i] = last - if self.is_stateful: - self._policy_states[i] = policy_state - self._policy_next_states[i] = policy_next_state + if self.is_stateful: + self._policy_states[i] = policy_state + self._policy_next_states[i] = policy_next_state + else: + assert (policy_state is None) and (policy_next_state is None) else: - assert (policy_state is None) and (policy_next_state is None) + n_active_envs = self._states.shape[1] + + self._states[i] = state[:n_active_envs] + self._actions[i] = action[:n_active_envs] + self._rewards[i] = reward[:n_active_envs] + self._next_states[i] = next_state[:n_active_envs] + self._absorbing[i] = absorbing[:n_active_envs] + self._last[i] = last[:n_active_envs] + + if self.is_stateful: + self._policy_states[i] = policy_state[:n_active_envs] + self._policy_next_states[i] = policy_next_state[:n_active_envs] + else: + assert (policy_state is None) and (policy_next_state is None) - if mask is not None: - self._mask[i] = mask + self._mask[i] = mask[:n_active_envs] self._len += 1 diff --git a/mushroom_rl/core/_impl/torch_dataset.py b/mushroom_rl/core/_impl/torch_dataset.py index 3bc9066e..54a18394 100644 --- a/mushroom_rl/core/_impl/torch_dataset.py +++ b/mushroom_rl/core/_impl/torch_dataset.py @@ -116,19 +116,32 @@ def append(self, state, action, reward, next_state, absorbing, last, policy_stat mask=None): i = self._len # todo: handle index out of bounds? - self._states[i] = state - self._actions[i] = action - self._rewards[i] = reward - self._next_states[i] = next_state - self._absorbing[i] = absorbing - self._last[i] = last + if mask is None: + self._states[i] = state + self._actions[i] = action + self._rewards[i] = reward + self._next_states[i] = next_state + self._absorbing[i] = absorbing + self._last[i] = last - if self.is_stateful: - self._policy_states[i] = policy_state - self._policy_next_states[i] = policy_next_state + if self.is_stateful: + self._policy_states[i] = policy_state + self._policy_next_states[i] = policy_next_state + else: + n_active_envs = self._states.shape[1] + + self._states[i] = state[:n_active_envs] + self._actions[i] = action[:n_active_envs] + self._rewards[i] = reward[:n_active_envs] + self._next_states[i] = next_state[:n_active_envs] + self._absorbing[i] = absorbing[:n_active_envs] + self._last[i] = last[:n_active_envs] + + if self.is_stateful: + self._policy_states[i] = policy_state[:n_active_envs] + self._policy_next_states[i] = policy_next_state[:n_active_envs] - if mask is not None: - self._mask[i] = mask + self._mask[i] = mask[:n_active_envs] self._len += 1 diff --git a/mushroom_rl/core/_impl/vectorized_core_logic.py b/mushroom_rl/core/_impl/vectorized_core_logic.py index 24ec612d..c9d0af69 100644 --- a/mushroom_rl/core/_impl/vectorized_core_logic.py +++ b/mushroom_rl/core/_impl/vectorized_core_logic.py @@ -63,8 +63,8 @@ def after_step(self, last): self._current_episodes_counter += completed self._episodes_progress_bar.update(completed) - def after_fit_vectorized(self, last): - super().after_fit() + def after_fit_vectorized(self, last, n_carry_forward_steps): + super().after_fit(n_carry_forward_steps) if self._n_episodes_per_fit is not None: self._running_envs = self._array_backend.zeros(self._n_envs, dtype=bool) return self._array_backend.ones(self._n_envs, dtype=bool) diff --git a/mushroom_rl/core/core.py b/mushroom_rl/core/core.py index 3478029c..d5749367 100644 --- a/mushroom_rl/core/core.py +++ b/mushroom_rl/core/core.py @@ -63,7 +63,7 @@ def learn(self, n_steps=None, n_episodes=None, n_steps_per_fit=None, n_episodes_ assert (render and record) or (not record), "To record, the render flag must be set to true" self._core_logic.initialize_learn(n_steps_per_fit, n_episodes_per_fit) - dataset = Dataset.generate(self.env.info, self.agent.info, n_steps_per_fit, n_episodes_per_fit) + dataset = Dataset.generate(self.env.info, self.agent.info, n_steps_per_fit, n_episodes_per_fit, core_counts_episodes=n_episodes is not None) self._run(dataset, n_steps, n_episodes, render, quiet, record) @@ -91,7 +91,7 @@ def evaluate(self, initial_states=None, n_steps=None, n_episodes=None, render=Fa self._core_logic.initialize_evaluate() n_episodes_dataset = len(initial_states) if initial_states is not None else n_episodes - dataset = Dataset.generate(self.env.info, self.agent.info, n_steps, n_episodes_dataset) + dataset = Dataset.generate(self.env.info, self.agent.info, n_steps, n_episodes_dataset, core_counts_episodes=n_episodes is not None) return self._run(dataset, n_steps, n_episodes, render, quiet, record, initial_states) diff --git a/mushroom_rl/core/dataset.py b/mushroom_rl/core/dataset.py index e0016736..367f7383 100644 --- a/mushroom_rl/core/dataset.py +++ b/mushroom_rl/core/dataset.py @@ -1,4 +1,5 @@ import numpy as np +import math from collections import defaultdict @@ -75,7 +76,7 @@ def create_replay_memory_info(mdp_info, agent_info, device=None): class Dataset(Serializable): - def __init__(self, dataset_info, n_steps=None, n_episodes=None): + def __init__(self, dataset_info, n_steps=None, n_episodes=None, core_counts_episodes=False): assert (n_steps is not None and n_episodes is None) or (n_steps is None and n_episodes is not None) self._array_backend = ArrayBackend.get_array_backend(dataset_info.backend) @@ -91,8 +92,16 @@ def __init__(self, dataset_info, n_steps=None, n_episodes=None): if dataset_info.n_envs == 1: base_shape = (n_samples,) mask_shape = None + elif n_episodes: + horizon = dataset_info.horizon + x = math.ceil(n_episodes / dataset_info.n_envs) + base_shape = (x * horizon, min(n_episodes, dataset_info.n_envs)) + mask_shape = base_shape + elif core_counts_episodes: + base_shape = (math.ceil(n_samples / dataset_info.n_envs) + 1 + dataset_info.horizon, dataset_info.n_envs) + mask_shape = base_shape else: - base_shape = (n_samples, dataset_info.n_envs) + base_shape = (math.ceil(n_samples / dataset_info.n_envs) + 1, dataset_info.n_envs) mask_shape = base_shape state_shape = base_shape + dataset_info.state_shape @@ -104,8 +113,8 @@ def __init__(self, dataset_info, n_steps=None, n_episodes=None): else: policy_state_shape = None - self._info = ExtraInfo(dataset_info.n_envs, dataset_info.backend, dataset_info.device) - self._episode_info = ExtraInfo(dataset_info.n_envs, dataset_info.backend, dataset_info.device) + self._info = ExtraInfo(min(n_episodes, dataset_info.n_envs) if n_episodes else dataset_info.n_envs, dataset_info.backend, dataset_info.device) + self._episode_info = ExtraInfo(min(n_episodes, dataset_info.n_envs) if n_episodes else dataset_info.n_envs, dataset_info.backend, dataset_info.device) self._theta_list = list() if dataset_info.backend == 'numpy': @@ -127,10 +136,10 @@ def __init__(self, dataset_info, n_steps=None, n_episodes=None): self._add_all_save_attr() @classmethod - def generate(cls, mdp_info, agent_info, n_steps=None, n_episodes=None, n_envs=1): + def generate(cls, mdp_info, agent_info, n_steps=None, n_episodes=None, n_envs=1, core_counts_episodes=False): dataset_info = DatasetInfo.create_dataset_info(mdp_info, agent_info, n_envs) - return cls(dataset_info, n_steps, n_episodes) + return cls(dataset_info, n_steps, n_episodes, core_counts_episodes) @classmethod def create_raw_instance(cls, dataset=None): @@ -545,8 +554,8 @@ def _merge_info(info, other_info): class VectorizedDataset(Dataset): - def __init__(self, dataset_info, n_steps=None, n_episodes=None): - super().__init__(dataset_info, n_steps, n_episodes) + def __init__(self, dataset_info, n_steps=None, n_episodes=None, core_counts_episodes=False): + super().__init__(dataset_info, n_steps, n_episodes, core_counts_episodes) self._initialize_theta_list(self._dataset_info.n_envs) @@ -564,6 +573,7 @@ def append_theta_vectorized(self, theta, mask): def clear(self, n_steps_per_fit=None): n_envs = len(self._theta_list) + n_carry_forward_steps = 0 residual_data = None if n_steps_per_fit is not None: @@ -576,12 +586,16 @@ def clear(self, n_steps_per_fit=None): residual_data = self._data.get_view(view_size, copy=True) mask = residual_data.mask original_shape = mask.shape - mask.flatten()[n_extra_steps:] = False + mask = mask.flatten() + true_indices = self._array_backend.where(mask)[0] + mask[true_indices[n_extra_steps:]] = False residual_data.mask = mask.reshape(original_shape) residual_info = self._info.get_view(view_size, copy=True) residual_episode_info = self._episode_info.get_view(view_size, copy=True) + n_carry_forward_steps = mask.sum() + super().clear() self._initialize_theta_list(n_envs) @@ -590,6 +604,8 @@ def clear(self, n_steps_per_fit=None): self._info = residual_info self._episode_info = residual_episode_info + return n_carry_forward_steps + def flatten(self, n_steps_per_fit=None): if len(self) == 0: return None diff --git a/mushroom_rl/core/extra_info.py b/mushroom_rl/core/extra_info.py index 1b1faf75..0dc3fb88 100644 --- a/mushroom_rl/core/extra_info.py +++ b/mushroom_rl/core/extra_info.py @@ -303,7 +303,10 @@ def _append_dict_to_output(self, output, step_data, index, to): for key, key_path in self._key_mapping.items(): value = self._find_element_by_key_path(step_data, key_path) value = self._convert(value, to) - output[key][index] = value + if value is ArrayBackend.get_array_backend(to).none() or self._n_envs == 1: + output[key][index] = value + else: + output[key][index] = value[:self._n_envs] def _append_list_to_output(self, output, step_data, index, to): """ @@ -318,6 +321,8 @@ def _append_list_to_output(self, output, step_data, index, to): assert(self._n_envs > 1) for key, key_path in self._key_mapping.items(): for i, env_data in enumerate(step_data): + if i >= self._n_envs: + break value = self._find_element_by_key_path(env_data, key_path) value = self._convert(value, to) output[key][index][i] = value diff --git a/mushroom_rl/core/vectorized_core.py b/mushroom_rl/core/vectorized_core.py index 58fdf5d7..b0e0c8be 100644 --- a/mushroom_rl/core/vectorized_core.py +++ b/mushroom_rl/core/vectorized_core.py @@ -64,7 +64,7 @@ def learn(self, n_steps=None, n_episodes=None, n_steps_per_fit=None, n_episodes_ self._core_logic.initialize_learn(n_steps_per_fit, n_episodes_per_fit) dataset = VectorizedDataset.generate(self.env.info, self.agent.info, - n_steps_per_fit, n_episodes_per_fit, self.env.number) + n_steps_per_fit, n_episodes_per_fit, self.env.number, n_episodes is not None) self._run(dataset, n_steps, n_episodes, render, quiet, record) @@ -93,7 +93,7 @@ def evaluate(self, initial_states=None, n_steps=None, n_episodes=None, render=Fa n_episodes_dataset = len(initial_states) if initial_states is not None else n_episodes dataset = VectorizedDataset.generate(self.env.info, self.agent.info, - n_steps, n_episodes_dataset, self.env.number) + n_steps, n_episodes_dataset, self.env.number, n_episodes is not None) return self._run(dataset, n_steps, n_episodes, render, quiet, record, initial_states) @@ -123,12 +123,12 @@ def _run(self, dataset, n_steps, n_episodes, render, quiet, record, initial_stat if self._core_logic.fit_required(): fit_dataset = dataset.flatten(self._core_logic.n_steps_per_fit) self.agent.fit(fit_dataset) - last = self._core_logic.after_fit_vectorized(last) for c in self.callbacks_fit: c(dataset) - dataset.clear(self._core_logic.n_steps_per_fit) + n_carry_forward_steps = dataset.clear(self._core_logic.n_steps_per_fit) + last = self._core_logic.after_fit_vectorized(last, n_carry_forward_steps) self.agent.stop() self.env.stop()