From 8e7b1caf520c074cf7f3c970073c89106ed51741 Mon Sep 17 00:00:00 2001 From: Bjarne-55 <73470930+Bjarne-55@users.noreply.github.com> Date: Mon, 9 Dec 2024 18:04:23 +0100 Subject: [PATCH] Adapt for using n_episodes --- mushroom_rl/core/core.py | 4 ++-- mushroom_rl/core/dataset.py | 23 ++++++++++++++--------- mushroom_rl/core/vectorized_core.py | 4 ++-- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/mushroom_rl/core/core.py b/mushroom_rl/core/core.py index 3478029c..d5749367 100644 --- a/mushroom_rl/core/core.py +++ b/mushroom_rl/core/core.py @@ -63,7 +63,7 @@ def learn(self, n_steps=None, n_episodes=None, n_steps_per_fit=None, n_episodes_ assert (render and record) or (not record), "To record, the render flag must be set to true" self._core_logic.initialize_learn(n_steps_per_fit, n_episodes_per_fit) - dataset = Dataset.generate(self.env.info, self.agent.info, n_steps_per_fit, n_episodes_per_fit) + dataset = Dataset.generate(self.env.info, self.agent.info, n_steps_per_fit, n_episodes_per_fit, core_counts_episodes=n_episodes is not None) self._run(dataset, n_steps, n_episodes, render, quiet, record) @@ -91,7 +91,7 @@ def evaluate(self, initial_states=None, n_steps=None, n_episodes=None, render=Fa self._core_logic.initialize_evaluate() n_episodes_dataset = len(initial_states) if initial_states is not None else n_episodes - dataset = Dataset.generate(self.env.info, self.agent.info, n_steps, n_episodes_dataset) + dataset = Dataset.generate(self.env.info, self.agent.info, n_steps, n_episodes_dataset, core_counts_episodes=n_episodes is not None) return self._run(dataset, n_steps, n_episodes, render, quiet, record, initial_states) diff --git a/mushroom_rl/core/dataset.py b/mushroom_rl/core/dataset.py index 0263abba..367f7383 100644 --- a/mushroom_rl/core/dataset.py +++ b/mushroom_rl/core/dataset.py @@ -76,7 +76,7 @@ def create_replay_memory_info(mdp_info, agent_info, device=None): class Dataset(Serializable): - def __init__(self, dataset_info, n_steps=None, n_episodes=None): + def __init__(self, dataset_info, n_steps=None, n_episodes=None, core_counts_episodes=False): assert (n_steps is not None and n_episodes is None) or (n_steps is None and n_episodes is not None) self._array_backend = ArrayBackend.get_array_backend(dataset_info.backend) @@ -89,7 +89,7 @@ def __init__(self, dataset_info, n_steps=None, n_episodes=None): n_samples = horizon * n_episodes - if dataset_info.n_envs == 1:#TODO here is an error for evaluation + if dataset_info.n_envs == 1: base_shape = (n_samples,) mask_shape = None elif n_episodes: @@ -97,6 +97,9 @@ def __init__(self, dataset_info, n_steps=None, n_episodes=None): x = math.ceil(n_episodes / dataset_info.n_envs) base_shape = (x * horizon, min(n_episodes, dataset_info.n_envs)) mask_shape = base_shape + elif core_counts_episodes: + base_shape = (math.ceil(n_samples / dataset_info.n_envs) + 1 + dataset_info.horizon, dataset_info.n_envs) + mask_shape = base_shape else: base_shape = (math.ceil(n_samples / dataset_info.n_envs) + 1, dataset_info.n_envs) mask_shape = base_shape @@ -133,10 +136,10 @@ def __init__(self, dataset_info, n_steps=None, n_episodes=None): self._add_all_save_attr() @classmethod - def generate(cls, mdp_info, agent_info, n_steps=None, n_episodes=None, n_envs=1): + def generate(cls, mdp_info, agent_info, n_steps=None, n_episodes=None, n_envs=1, core_counts_episodes=False): dataset_info = DatasetInfo.create_dataset_info(mdp_info, agent_info, n_envs) - return cls(dataset_info, n_steps, n_episodes) + return cls(dataset_info, n_steps, n_episodes, core_counts_episodes) @classmethod def create_raw_instance(cls, dataset=None): @@ -551,8 +554,8 @@ def _merge_info(info, other_info): class VectorizedDataset(Dataset): - def __init__(self, dataset_info, n_steps=None, n_episodes=None): - super().__init__(dataset_info, n_steps, n_episodes) + def __init__(self, dataset_info, n_steps=None, n_episodes=None, core_counts_episodes=False): + super().__init__(dataset_info, n_steps, n_episodes, core_counts_episodes) self._initialize_theta_list(self._dataset_info.n_envs) @@ -568,7 +571,7 @@ def append_theta_vectorized(self, theta, mask): if mask[i]: self._theta_list[i].append(theta[i]) - def clear(self, n_steps_per_fit=None):#TODO Problem 2 datasets exists at the same time, why even copy? + def clear(self, n_steps_per_fit=None): n_envs = len(self._theta_list) n_carry_forward_steps = 0 @@ -583,13 +586,15 @@ def clear(self, n_steps_per_fit=None):#TODO Problem 2 datasets exists at the sam residual_data = self._data.get_view(view_size, copy=True) mask = residual_data.mask original_shape = mask.shape - mask.flatten()[n_extra_steps:] = False + mask = mask.flatten() + true_indices = self._array_backend.where(mask)[0] + mask[true_indices[n_extra_steps:]] = False residual_data.mask = mask.reshape(original_shape) residual_info = self._info.get_view(view_size, copy=True) residual_episode_info = self._episode_info.get_view(view_size, copy=True) - n_carry_forward_steps = n_extra_steps + n_carry_forward_steps = mask.sum() super().clear() self._initialize_theta_list(n_envs) diff --git a/mushroom_rl/core/vectorized_core.py b/mushroom_rl/core/vectorized_core.py index 02a81ee9..b0e0c8be 100644 --- a/mushroom_rl/core/vectorized_core.py +++ b/mushroom_rl/core/vectorized_core.py @@ -64,7 +64,7 @@ def learn(self, n_steps=None, n_episodes=None, n_steps_per_fit=None, n_episodes_ self._core_logic.initialize_learn(n_steps_per_fit, n_episodes_per_fit) dataset = VectorizedDataset.generate(self.env.info, self.agent.info, - n_steps_per_fit, n_episodes_per_fit, self.env.number) + n_steps_per_fit, n_episodes_per_fit, self.env.number, n_episodes is not None) self._run(dataset, n_steps, n_episodes, render, quiet, record) @@ -93,7 +93,7 @@ def evaluate(self, initial_states=None, n_steps=None, n_episodes=None, render=Fa n_episodes_dataset = len(initial_states) if initial_states is not None else n_episodes dataset = VectorizedDataset.generate(self.env.info, self.agent.info, - n_steps, n_episodes_dataset, self.env.number) + n_steps, n_episodes_dataset, self.env.number, n_episodes is not None) return self._run(dataset, n_steps, n_episodes, render, quiet, record, initial_states)