diff --git a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/ppo_bptt.py b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/ppo_bptt.py
index 7d853137..7e0243dd 100644
--- a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/ppo_bptt.py
+++ b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/ppo_bptt.py
@@ -69,7 +69,7 @@ def __init__(self, mdp_info, policy, actor_optimizer, critic_params,
             _dim_env_state='primitive'
         )
 
-        super().__init__(mdp_info, policy, None)
+        super().__init__(mdp_info, policy, is_episodic=False)
 
         # add the standardization preprocessor
         self._preprocessors.append(StandardizationPreprocessor(mdp_info))
diff --git a/mushroom_rl/algorithms/policy_search/black_box_optimization/black_box_optimization.py b/mushroom_rl/algorithms/policy_search/black_box_optimization/black_box_optimization.py
index 3576c5c7..d12faffe 100644
--- a/mushroom_rl/algorithms/policy_search/black_box_optimization/black_box_optimization.py
+++ b/mushroom_rl/algorithms/policy_search/black_box_optimization/black_box_optimization.py
@@ -37,19 +37,22 @@ def episode_start(self, initial_state, episode_info):
 
         return policy_state, theta
 
-    def episode_start_vectorized(self, initial_states, episode_info, n_envs):
+    def episode_start_vectorized(self, initial_states, episode_info, start_mask):
+        n_envs = len(start_mask)
         if not isinstance(self.policy, VectorPolicy):
             self.policy = VectorPolicy(self.policy, n_envs)
         elif len(self.policy) != n_envs:
             self.policy.set_n(n_envs)
 
-        theta = [self.distribution.sample() for _ in range(n_envs)]
+        theta = self.policy.get_weights()
+        if np.any(start_mask):
+            theta[start_mask] = np.array([self.distribution.sample() for _ in range(np.sum(start_mask))])
+            self.policy.set_weights(theta)
 
-        self.policy.set_weights(theta)
+        policy_states = self.policy.reset()
 
-        policy_state, _ = super().episode_start(initial_states, episode_info)
+        return policy_states, theta
 
-        return policy_state, theta
 
     def fit(self, dataset):
         Jep = np.array(dataset.discounted_return)
diff --git a/mushroom_rl/core/agent.py b/mushroom_rl/core/agent.py
index a8a55391..330b897e 100644
--- a/mushroom_rl/core/agent.py
+++ b/mushroom_rl/core/agent.py
@@ -112,17 +112,17 @@ def episode_start(self, initial_state, episode_info):
 
         """
         return self.policy.reset(), None
 
-    def episode_start_vectorized(self, initial_states, episode_info, n_envs):
+    def episode_start_vectorized(self, initial_states, episode_info, start_mask):
        """
         Called by the VectorCore when a new episode starts.
 
         Args:
            initial_states (array): the initial states of the environment.
            episode_info (dict): a dictionary containing the information at reset, such as context;
-            n_envs (int): number of environments in parallel to run.
+            start_mask (array): boolean mask to select the environments that are starting a new episode
 
         Returns:
-            A tuple containing the policy initial state and, optionally, the policy parameters
+            A tuple containing the policy initial states and, optionally, the policy parameters
 
         """
         return self.episode_start(initial_states, episode_info)
diff --git a/mushroom_rl/core/vectorized_core.py b/mushroom_rl/core/vectorized_core.py
index d817c518..69eaf5ff 100644
--- a/mushroom_rl/core/vectorized_core.py
+++ b/mushroom_rl/core/vectorized_core.py
@@ -31,7 +31,6 @@ def __init__(self, agent, env, callbacks_fit=None, callback_step=None, record_di
 
         self._state = None
         self._policy_state = None
-        self._current_theta = None
         self._episode_steps = None
 
         self._core_logic = VectorizedCoreLogic(self.env.info.backend, self.env.number)
@@ -106,10 +105,10 @@ def _run(self, dataset, n_steps, n_episodes, render, quiet, record, initial_stat
         while self._core_logic.move_required():
             if last.any():
                 mask = self._core_logic.get_mask(last)
-                reset_mask = self._reset(initial_states, last, mask)
+                current_theta, reset_mask = self._reset(initial_states, last, mask)
 
                 if self.agent.info.is_episodic and reset_mask.any():
-                    dataset.append_theta_vectorized(self._current_theta, reset_mask)
+                    dataset.append_theta_vectorized(current_theta, reset_mask)
 
             samples, step_infos = self._step(render, record, mask)
@@ -180,11 +179,17 @@ def _reset(self, initial_states, last, mask):
 
         """
         reset_mask = last & mask
+        initial_state = self._core_logic.get_initial_state(initial_states)
         state, episode_info = self._preprocess(self.env.reset_all(reset_mask, initial_state))
-        self._policy_state, self._current_theta = self.agent.episode_start_vectorized(episode_info, state,
-                                                                                      self.env.number)
+
+        policy_state, current_theta = self.agent.episode_start_vectorized(state, episode_info, reset_mask)
+        if self._policy_state is None or policy_state is None:
+            self._policy_state = policy_state
+        elif reset_mask.any():
+            self._policy_state[reset_mask] = policy_state[reset_mask]
+
         self._state = self._preprocess(state)
         self.agent.next_action = None
@@ -193,12 +198,11 @@ def _reset(self, initial_states, last, mask):
         else:
             self._episode_steps[last] = 0
 
-        return reset_mask
+        return current_theta, reset_mask
 
     def _end(self, record):
         self._state = None
         self._policy_state = None
-        self._current_theta = None
         self._episode_steps = None
 
         if record:
diff --git a/mushroom_rl/policy/vector_policy.py b/mushroom_rl/policy/vector_policy.py
index f05aa0c6..ac2ec1dd 100644
--- a/mushroom_rl/policy/vector_policy.py
+++ b/mushroom_rl/policy/vector_policy.py
@@ -70,7 +70,7 @@ def get_weights(self):
             weights_i = policy.get_weights()
             weight_list.append(weights_i)
 
-        return weight_list
+        return np.array(weight_list)
 
     @property
     def weights_size(self):
@@ -83,15 +83,26 @@ def weights_size(self):
 
         """
         return len(self), self._policy_vector[0].weights_size
 
-    def reset(self):
-        policy_states = list()
-        for i, policy in enumerate(self._policy_vector):
-            policy_state = policy.reset()
-
-            if policy_state is not None:
-                policy_states.append(policy_state)
-
-        return None if len(policy_states) == 0 else np.array(policy_states)
+    def reset(self, mask=None):
+        policy_states = None
+        if self.policy_state_shape is None:
+            if mask is None:
+                for policy in self._policy_vector:
+                    policy.reset()
+            else:
+                for masked, policy in zip(mask, self._policy_vector):
+                    if masked:
+                        policy.reset()
+        else:
+            policy_states = np.empty((len(self._policy_vector),) + self.policy_state_shape)
+            if mask is None:
+                for i, policy in enumerate(self._policy_vector):
+                    policy_states[i] = policy.reset()
+            else:
+                for i, (masked, policy) in enumerate(zip(mask, self._policy_vector)):
+                    if masked:
+                        policy_states[i] = policy.reset()
+        return policy_states
 
     def __len__(self):
         return len(self._policy_vector)
diff --git a/tests/core/test_vectorized_core.py b/tests/core/test_vectorized_core.py
index 238805e5..6e166631 100644
--- a/tests/core/test_vectorized_core.py
+++ b/tests/core/test_vectorized_core.py
@@ -42,7 +42,8 @@ def __init__(self, mdp_info, backend):
     def fit(self, dataset):
         assert len(dataset.theta_list) == 5
 
-    def episode_start_vectorized(self, initial_states, episode_info, n_envs):
+    def episode_start_vectorized(self, initial_states, episode_info, start_mask):
+        n_envs = len(start_mask)
         current_count = self._counter
         self._counter += 1
         if self._backend == 'torch':
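
Note for reviewers: the snippet below is a minimal, self-contained sketch (not part of the patch) of the masked-reset semantics that episode_start_vectorized now follows: only the environments flagged in start_mask receive freshly sampled parameters, while the remaining entries keep their current weights. The helper name masked_resample and the toy shapes are hypothetical; the logic mirrors what the diff adds to BlackBoxOptimization.episode_start_vectorized.

import numpy as np

def masked_resample(current_theta, start_mask, sample_theta):
    """Resample parameters only for environments that start a new episode.

    current_theta: (n_envs, n_params) array of per-environment parameters.
    start_mask:    (n_envs,) boolean array, True where an episode restarts.
    sample_theta:  callable returning a single (n_params,) parameter vector.
    """
    theta = current_theta.copy()
    if np.any(start_mask):
        # Draw one new parameter vector per restarting environment and
        # overwrite only the masked rows; the other rows keep their weights.
        theta[start_mask] = np.array([sample_theta() for _ in range(np.sum(start_mask))])
    return theta

# Example: 4 parallel environments, environments 0 and 2 restart.
rng = np.random.default_rng(0)
theta = np.zeros((4, 3))
start_mask = np.array([True, False, True, False])
new_theta = masked_resample(theta, start_mask, lambda: rng.normal(size=3))
print(new_theta)  # rows 1 and 3 stay zero, rows 0 and 2 are resampled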