Episode start vectorized fixes (#136)
* Fixed the wrong is_episodic flag setting

* Fixed the handling of policy states and theta in the vectorized environments case

* Moved policy reset after theta initialization
pkicki authored Dec 20, 2023
1 parent 86fc83c commit 5a2a26d
Showing 6 changed files with 46 additions and 27 deletions.
@@ -69,7 +69,7 @@ def __init__(self, mdp_info, policy, actor_optimizer, critic_params,
             _dim_env_state='primitive'
         )

-        super().__init__(mdp_info, policy, None)
+        super().__init__(mdp_info, policy, is_episodic=False)

         # add the standardization preprocessor
         self._preprocessors.append(StandardizationPreprocessor(mdp_info))
@@ -37,19 +37,22 @@ def episode_start(self, initial_state, episode_info):

         return policy_state, theta

-    def episode_start_vectorized(self, initial_states, episode_info, n_envs):
+    def episode_start_vectorized(self, initial_states, episode_info, start_mask):
+        n_envs = len(start_mask)
         if not isinstance(self.policy, VectorPolicy):
             self.policy = VectorPolicy(self.policy, n_envs)
         elif len(self.policy) != n_envs:
             self.policy.set_n(n_envs)

-        theta = [self.distribution.sample() for _ in range(n_envs)]
+        theta = self.policy.get_weights()
+        if np.any(start_mask):
+            theta[start_mask] = np.array([self.distribution.sample() for _ in range(np.sum(start_mask))])
+            self.policy.set_weights(theta)

-        self.policy.set_weights(theta)
+        policy_states = self.policy.reset()

-        policy_state, _ = super().episode_start(initial_states, episode_info)
-
-        return policy_state, theta
+        return policy_states, theta

     def fit(self, dataset):
         Jep = np.array(dataset.discounted_return)
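To make the masked sampling concrete, here is a small self-contained NumPy sketch (illustrative shapes and values only, not repository code): only environments whose start_mask entry is True receive freshly sampled parameters, while the others keep the theta they are still executing.

import numpy as np

rng = np.random.default_rng(0)

n_envs, n_params = 4, 3
theta = np.zeros((n_envs, n_params))               # current per-environment parameters
start_mask = np.array([True, False, True, False])  # envs 0 and 2 begin a new episode

if np.any(start_mask):
    # resample parameters only for the restarting environments
    theta[start_mask] = rng.normal(size=(np.sum(start_mask), n_params))

print(theta)  # rows 1 and 3 stay all zeros; rows 0 and 2 were resampled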
6 changes: 3 additions & 3 deletions mushroom_rl/core/agent.py
@@ -112,17 +112,17 @@ def episode_start(self, initial_state, episode_info):
         """
         return self.policy.reset(), None

-    def episode_start_vectorized(self, initial_states, episode_info, n_envs):
+    def episode_start_vectorized(self, initial_states, episode_info, start_mask):
         """
         Called by the VectorCore when a new episode starts.

         Args:
             initial_states (array): the initial states of the environment.
             episode_info (dict): a dictionary containing the information at reset, such as context;
-            n_envs (int): number of environments in parallel to run.
+            start_mask (array): boolean mask to select the environments that are starting a new episode

         Returns:
-            A tuple containing the policy initial state and, optionally, the policy parameters
+            A tuple containing the policy initial states and, optionally, the policy parameters

         """
         return self.episode_start(initial_states, episode_info)
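For orientation, a minimal sketch of the new contract with toy classes (assumed names, not part of the commit): the VectorCore passes one boolean per parallel environment, and an agent that does not override episode_start_vectorized simply falls back to episode_start and ignores the mask.

import numpy as np

# Toy stand-ins (assumptions, not repository classes) mirroring the default shown above.
class StatelessPolicy:
    def reset(self):
        return None          # no internal policy state

class ToyAgent:
    def __init__(self, policy):
        self.policy = policy

    def episode_start(self, initial_state, episode_info):
        return self.policy.reset(), None

    def episode_start_vectorized(self, initial_states, episode_info, start_mask):
        # default behaviour: ignore the per-environment mask and delegate
        return self.episode_start(initial_states, episode_info)

agent = ToyAgent(StatelessPolicy())
start_mask = np.array([True, False, False, True])   # 4 parallel envs, 2 restarting
policy_state, theta = agent.episode_start_vectorized(None, {}, start_mask)
assert policy_state is None and theta is None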
18 changes: 11 additions & 7 deletions mushroom_rl/core/vectorized_core.py
@@ -31,7 +31,6 @@ def __init__(self, agent, env, callbacks_fit=None, callback_step=None, record_di

         self._state = None
         self._policy_state = None
-        self._current_theta = None
         self._episode_steps = None

         self._core_logic = VectorizedCoreLogic(self.env.info.backend, self.env.number)
@@ -106,10 +105,10 @@ def _run(self, dataset, n_steps, n_episodes, render, quiet, record, initial_stat
         while self._core_logic.move_required():
             if last.any():
                 mask = self._core_logic.get_mask(last)
-                reset_mask = self._reset(initial_states, last, mask)
+                current_theta, reset_mask = self._reset(initial_states, last, mask)

                 if self.agent.info.is_episodic and reset_mask.any():
-                    dataset.append_theta_vectorized(self._current_theta, reset_mask)
+                    dataset.append_theta_vectorized(current_theta, reset_mask)

                 samples, step_infos = self._step(render, record, mask)

@@ -180,11 +179,17 @@ def _reset(self, initial_states, last, mask):
         """
         reset_mask = last & mask

         initial_state = self._core_logic.get_initial_state(initial_states)

         state, episode_info = self._preprocess(self.env.reset_all(reset_mask, initial_state))
-        self._policy_state, self._current_theta = self.agent.episode_start_vectorized(episode_info, state,
-                                                                                      self.env.number)
+
+        policy_state, current_theta = self.agent.episode_start_vectorized(state, episode_info, reset_mask)
+        if self._policy_state is None or policy_state is None:
+            self._policy_state = policy_state
+        elif reset_mask.any():
+            self._policy_state[reset_mask] = policy_state[reset_mask]

         self._state = self._preprocess(state)
         self.agent.next_action = None

@@ -193,12 +198,11 @@ def _reset(self, initial_states, last, mask):
         else:
             self._episode_steps[last] = 0

-        return reset_mask
+        return current_theta, reset_mask

     def _end(self, record):
         self._state = None
         self._policy_state = None
-        self._current_theta = None
         self._episode_steps = None

         if record:
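The masked merge of policy states in _reset can be pictured with a short NumPy sketch (made-up shapes, illustration only): the stored per-environment state is overwritten only for environments that actually restarted, matching the self._policy_state[reset_mask] = policy_state[reset_mask] line above.

import numpy as np

stored = np.arange(8, dtype=float).reshape(4, 2)     # per-env policy states, one row per env
fresh = np.full((4, 2), -1.0)                        # states returned at this reset
reset_mask = np.array([False, True, False, True])    # only envs 1 and 3 restarted

if stored is None or fresh is None:
    stored = fresh
elif reset_mask.any():
    stored[reset_mask] = fresh[reset_mask]           # running envs keep their old state

print(stored)  # rows 0 and 2 unchanged, rows 1 and 3 replaced by -1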
31 changes: 21 additions & 10 deletions mushroom_rl/policy/vector_policy.py
@@ -70,7 +70,7 @@ def get_weights(self):
             weights_i = policy.get_weights()
             weight_list.append(weights_i)

-        return weight_list
+        return np.array(weight_list)

     @property
     def weights_size(self):
@@ -83,15 +83,26 @@ def weights_size(self):
         """
         return len(self), self._policy_vector[0].weights_size

-    def reset(self):
-        policy_states = list()
-        for i, policy in enumerate(self._policy_vector):
-            policy_state = policy.reset()
-
-            if policy_state is not None:
-                policy_states.append(policy_state)
-
-        return None if len(policy_states) == 0 else np.array(policy_states)
+    def reset(self, mask=None):
+        policy_states = None
+        if self.policy_state_shape is None:
+            if mask is None:
+                for policy in self._policy_vector:
+                    policy.reset()
+            else:
+                for masked, policy in zip(mask, self._policy_vector):
+                    if masked:
+                        policy.reset()
+        else:
+            policy_states = np.empty((len(self._policy_vector),) + self.policy_state_shape)
+            if mask is None:
+                for i, policy in enumerate(self._policy_vector):
+                    policy_states[i] = policy.reset()
+            else:
+                for i, (masked, policy) in enumerate(zip(mask, self._policy_vector)):
+                    if masked:
+                        policy_states[i] = policy.reset()
+        return policy_states

     def __len__(self):
         return len(self._policy_vector)
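A hedged usage sketch of the new reset(mask) behaviour with a toy stateful policy (assumed helper classes, not repository code). Note that the returned array is preallocated with np.empty, so rows for unmasked environments are uninitialised, and callers are expected to combine the result with the mask, as the VectorCore does.

import numpy as np

# Toy stateful policy (assumption: real policies expose policy_state_shape and
# reset() returning their initial internal state).
class ToyPolicy:
    policy_state_shape = (2,)

    def __init__(self):
        self.resets = 0

    def reset(self):
        self.resets += 1
        return np.zeros(self.policy_state_shape)

# Simplified mirror of the masked reset logic above (not the MushroomRL class itself).
class ToyVectorPolicy:
    def __init__(self, n):
        self._policy_vector = [ToyPolicy() for _ in range(n)]
        self.policy_state_shape = self._policy_vector[0].policy_state_shape

    def reset(self, mask=None):
        policy_states = np.empty((len(self._policy_vector),) + self.policy_state_shape)
        if mask is None:
            for i, policy in enumerate(self._policy_vector):
                policy_states[i] = policy.reset()
        else:
            for i, (masked, policy) in enumerate(zip(mask, self._policy_vector)):
                if masked:
                    policy_states[i] = policy.reset()
        return policy_states

vp = ToyVectorPolicy(3)
states = vp.reset(mask=np.array([True, False, True]))
print([p.resets for p in vp._policy_vector])   # [1, 0, 1]; row 1 of `states` is uninitialised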
3 changes: 2 additions & 1 deletion tests/core/test_vectorized_core.py
@@ -42,7 +42,8 @@ def __init__(self, mdp_info, backend):
     def fit(self, dataset):
         assert len(dataset.theta_list) == 5

-    def episode_start_vectorized(self, initial_states, episode_info, n_envs):
+    def episode_start_vectorized(self, initial_states, episode_info, start_mask):
+        n_envs = len(start_mask)
         current_count = self._counter
         self._counter += 1
         if self._backend == 'torch':
