Skip to content

Commit

Permalink
Adapt for using n_episodes
Browse the repository at this point in the history
  • Loading branch information
Bjarne-55 committed Dec 9, 2024
1 parent a21a2c3 commit 8e7b1ca
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 13 deletions.
4 changes: 2 additions & 2 deletions mushroom_rl/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def learn(self, n_steps=None, n_episodes=None, n_steps_per_fit=None, n_episodes_
assert (render and record) or (not record), "To record, the render flag must be set to true"
self._core_logic.initialize_learn(n_steps_per_fit, n_episodes_per_fit)

dataset = Dataset.generate(self.env.info, self.agent.info, n_steps_per_fit, n_episodes_per_fit)
dataset = Dataset.generate(self.env.info, self.agent.info, n_steps_per_fit, n_episodes_per_fit, core_counts_episodes=n_episodes is not None)

self._run(dataset, n_steps, n_episodes, render, quiet, record)

Expand Down Expand Up @@ -91,7 +91,7 @@ def evaluate(self, initial_states=None, n_steps=None, n_episodes=None, render=Fa
self._core_logic.initialize_evaluate()

n_episodes_dataset = len(initial_states) if initial_states is not None else n_episodes
dataset = Dataset.generate(self.env.info, self.agent.info, n_steps, n_episodes_dataset)
dataset = Dataset.generate(self.env.info, self.agent.info, n_steps, n_episodes_dataset, core_counts_episodes=n_episodes is not None)

return self._run(dataset, n_steps, n_episodes, render, quiet, record, initial_states)

Expand Down
23 changes: 14 additions & 9 deletions mushroom_rl/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def create_replay_memory_info(mdp_info, agent_info, device=None):


class Dataset(Serializable):
def __init__(self, dataset_info, n_steps=None, n_episodes=None):
def __init__(self, dataset_info, n_steps=None, n_episodes=None, core_counts_episodes=False):
assert (n_steps is not None and n_episodes is None) or (n_steps is None and n_episodes is not None)

self._array_backend = ArrayBackend.get_array_backend(dataset_info.backend)
Expand All @@ -89,14 +89,17 @@ def __init__(self, dataset_info, n_steps=None, n_episodes=None):

n_samples = horizon * n_episodes

if dataset_info.n_envs == 1:#TODO here is an error for evaluation
if dataset_info.n_envs == 1:
base_shape = (n_samples,)
mask_shape = None
elif n_episodes:
horizon = dataset_info.horizon
x = math.ceil(n_episodes / dataset_info.n_envs)
base_shape = (x * horizon, min(n_episodes, dataset_info.n_envs))
mask_shape = base_shape
elif core_counts_episodes:
base_shape = (math.ceil(n_samples / dataset_info.n_envs) + 1 + dataset_info.horizon, dataset_info.n_envs)
mask_shape = base_shape
else:
base_shape = (math.ceil(n_samples / dataset_info.n_envs) + 1, dataset_info.n_envs)
mask_shape = base_shape
Expand Down Expand Up @@ -133,10 +136,10 @@ def __init__(self, dataset_info, n_steps=None, n_episodes=None):
self._add_all_save_attr()

@classmethod
def generate(cls, mdp_info, agent_info, n_steps=None, n_episodes=None, n_envs=1):
def generate(cls, mdp_info, agent_info, n_steps=None, n_episodes=None, n_envs=1, core_counts_episodes=False):
dataset_info = DatasetInfo.create_dataset_info(mdp_info, agent_info, n_envs)

return cls(dataset_info, n_steps, n_episodes)
return cls(dataset_info, n_steps, n_episodes, core_counts_episodes)

@classmethod
def create_raw_instance(cls, dataset=None):
Expand Down Expand Up @@ -551,8 +554,8 @@ def _merge_info(info, other_info):


class VectorizedDataset(Dataset):
def __init__(self, dataset_info, n_steps=None, n_episodes=None):
super().__init__(dataset_info, n_steps, n_episodes)
def __init__(self, dataset_info, n_steps=None, n_episodes=None, core_counts_episodes=False):
super().__init__(dataset_info, n_steps, n_episodes, core_counts_episodes)

self._initialize_theta_list(self._dataset_info.n_envs)

Expand All @@ -568,7 +571,7 @@ def append_theta_vectorized(self, theta, mask):
if mask[i]:
self._theta_list[i].append(theta[i])

def clear(self, n_steps_per_fit=None):#TODO Problem 2 datasets exists at the same time, why even copy?
def clear(self, n_steps_per_fit=None):
n_envs = len(self._theta_list)
n_carry_forward_steps = 0

Expand All @@ -583,13 +586,15 @@ def clear(self, n_steps_per_fit=None):#TODO Problem 2 datasets exists at the sam
residual_data = self._data.get_view(view_size, copy=True)
mask = residual_data.mask
original_shape = mask.shape
mask.flatten()[n_extra_steps:] = False
mask = mask.flatten()
true_indices = self._array_backend.where(mask)[0]
mask[true_indices[n_extra_steps:]] = False
residual_data.mask = mask.reshape(original_shape)

residual_info = self._info.get_view(view_size, copy=True)
residual_episode_info = self._episode_info.get_view(view_size, copy=True)

n_carry_forward_steps = n_extra_steps
n_carry_forward_steps = mask.sum()

super().clear()
self._initialize_theta_list(n_envs)
Expand Down
4 changes: 2 additions & 2 deletions mushroom_rl/core/vectorized_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def learn(self, n_steps=None, n_episodes=None, n_steps_per_fit=None, n_episodes_
self._core_logic.initialize_learn(n_steps_per_fit, n_episodes_per_fit)

dataset = VectorizedDataset.generate(self.env.info, self.agent.info,
n_steps_per_fit, n_episodes_per_fit, self.env.number)
n_steps_per_fit, n_episodes_per_fit, self.env.number, n_episodes is not None)

self._run(dataset, n_steps, n_episodes, render, quiet, record)

Expand Down Expand Up @@ -93,7 +93,7 @@ def evaluate(self, initial_states=None, n_steps=None, n_episodes=None, render=Fa

n_episodes_dataset = len(initial_states) if initial_states is not None else n_episodes
dataset = VectorizedDataset.generate(self.env.info, self.agent.info,
n_steps, n_episodes_dataset, self.env.number)
n_steps, n_episodes_dataset, self.env.number, n_episodes is not None)

return self._run(dataset, n_steps, n_episodes, render, quiet, record, initial_states)

Expand Down

0 comments on commit 8e7b1ca

Please sign in to comment.