Commit
- Added vectorized policies for black box optimization
- Fixed bug in datasets backends
- Added black box optimization test
- Work in progress on vectorized dataset; many issues still need to be solved
1 parent 42ab32d · commit 8d0be98

Showing 15 changed files with 257 additions and 10 deletions.
Empty file.
File renamed without changes.
@@ -0,0 +1,67 @@

```python
import numpy as np

from mushroom_rl.core import VectorCore, Logger, MultiprocessEnvironment
from mushroom_rl.environments.segway import Segway
from mushroom_rl.algorithms.policy_search import *
from mushroom_rl.policy import DeterministicPolicy
from mushroom_rl.distributions import GaussianDiagonalDistribution
from mushroom_rl.approximators import Regressor
from mushroom_rl.approximators.parametric import LinearApproximator
from mushroom_rl.utils.callbacks import CollectDataset
from mushroom_rl.rl_utils.optimizers import AdaptiveOptimizer

from tqdm import tqdm, trange
tqdm.monitor_interval = 0


def experiment(alg, params, n_epochs, n_episodes, n_ep_per_fit):
    np.random.seed()

    logger = Logger(alg.__name__, results_dir=None)
    logger.strong_line()
    logger.info('Experiment Algorithm: ' + alg.__name__)

    # MDP
    mdp = MultiprocessEnvironment(Segway, n_envs=15)

    # Policy
    approximator = Regressor(LinearApproximator,
                             input_shape=mdp.info.observation_space.shape,
                             output_shape=mdp.info.action_space.shape)

    n_weights = approximator.weights_size
    mu = np.zeros(n_weights)
    sigma = 2e-0 * np.ones(n_weights)
    policy = DeterministicPolicy(approximator)
    dist = GaussianDiagonalDistribution(mu, sigma)

    agent = alg(mdp.info, dist, policy, **params)

    # Train
    dataset_callback = CollectDataset()
    core = VectorCore(agent, mdp, callbacks_fit=[dataset_callback])

    for i in trange(n_epochs, leave=False):
        core.learn(n_episodes=n_episodes,
                   n_episodes_per_fit=n_ep_per_fit, render=False)
        dataset = dataset_callback.get()
        J = np.mean(dataset.discounted_return)
        dataset_callback.clean()

        p = dist.get_parameters()

        logger.epoch_info(i+1, J=J, mu=p[:n_weights], sigma=p[n_weights:])

    logger.info('Press a button to visualize the segway...')
    input()
    core.evaluate(n_episodes=3, render=True)


if __name__ == '__main__':
    algs_params = [
        (REPS, {'eps': 0.05}),
        (RWR, {'beta': 0.01}),
        (PGPE, {'optimizer': AdaptiveOptimizer(eps=0.3)}),
    ]
    for alg, params in algs_params:
        experiment(alg, params, n_epochs=20, n_episodes=100, n_ep_per_fit=25)
```
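The script above treats policy search as episodic black-box optimization: the policy itself is deterministic, and learning happens by sampling its weight vector from the diagonal Gaussian search distribution, scoring each sample by the discounted return J of the episodes it generates, and updating the distribution parameters (REPS, RWR, and PGPE differ only in that update rule). Below is a minimal self-contained sketch of this loop, with a toy quadratic score standing in for the Segway returns and a simple elite-averaging update standing in for the library's algorithms:

```python
import numpy as np

# Toy black-box optimization loop (illustrative sketch, not MushroomRL
# internals): sample weight vectors from a Gaussian search distribution,
# score them, and move the distribution toward the best samples.
rng = np.random.default_rng(0)
mu, sigma = np.zeros(3), np.ones(3)        # search distribution parameters

def score(theta):
    # Stand-in for the discounted return J(theta); optimum at [1, 1, 1].
    return -np.sum((theta - 1.0) ** 2)

for _ in range(50):
    thetas = rng.normal(mu, sigma, size=(25, 3))   # 25 "episodes" per fit
    returns = np.array([score(t) for t in thetas])
    elite = thetas[np.argsort(returns)[-5:]]       # keep the best samples
    mu = elite.mean(axis=0)                        # elite-averaging update
    sigma = elite.std(axis=0) + 1e-3               # keep some exploration

print(mu)  # approaches [1, 1, 1]
```

Collecting the sampled episodes is the part that `MultiprocessEnvironment` and `VectorCore` parallelize in the example: the 25 episodes per fit are rolled out across 15 Segway copies instead of one at a time.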
@@ -0,0 +1,99 @@

```python
import numpy as np
from copy import deepcopy

from .policy import ParametricPolicy


class VectorPolicy(ParametricPolicy):
    def __init__(self, policy, n_envs):
        """
        Constructor.

        Args:
            policy (ParametricPolicy): base policy to copy;
            n_envs (int): number of parallel environments, i.e. the number
                of independent copies of the base policy.

        """
        super().__init__(policy_state_shape=policy.policy_state_shape)
        self._policy_vector = [deepcopy(policy) for _ in range(n_envs)]

        self._add_save_attr(_policy_vector='mushroom')

    def draw_action(self, state, policy_state):
        actions = list()
        policy_next_states = list()
        for i, policy in enumerate(self._policy_vector):
            s = state[i]
            ps = policy_state[i] if policy_state is not None else None
            action, policy_next_state = policy.draw_action(s, policy_state=ps)

            actions.append(action)

            if policy_next_state is not None:
                policy_next_states.append(policy_next_state)

        return np.array(actions), \
            None if len(policy_next_states) == 0 else np.array(policy_next_states)

    def set_n(self, n_envs):
        if len(self) > n_envs:
            # Drop the policies in excess.
            self._policy_vector = self._policy_vector[:n_envs]
        if len(self) < n_envs:
            # Fill the gap with independent copies of the first policy.
            n_missing = n_envs - len(self)
            self._policy_vector += [deepcopy(self._policy_vector[0])
                                    for _ in range(n_missing)]

    def get_flat_policy(self):
        return self._policy_vector[0]

    def set_weights(self, weights):
        """
        Setter.

        Args:
            weights (np.ndarray): the new weights to be used by the policy,
                one weight vector per policy in the vector.

        """
        for i, policy in enumerate(self._policy_vector):
            policy.set_weights(weights[i])

    def get_weights(self):
        """
        Getter.

        Returns:
            The list of the current weights of each policy in the vector.

        """
        weight_list = list()
        for policy in self._policy_vector:
            weight_list.append(policy.get_weights())

        return weight_list

    @property
    def weights_size(self):
        """
        Property.

        Returns:
            The size of the policy weights, as a tuple
            ``(n_envs, weights_per_policy)``.

        """
        return len(self), self._policy_vector[0].weights_size

    def reset(self):
        policy_states = list()
        for policy in self._policy_vector:
            policy_state = policy.reset()

            if policy_state is not None:
                policy_states.append(policy_state)

        return None if len(policy_states) == 0 else np.array(policy_states)

    def __len__(self):
        return len(self._policy_vector)
```
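`VectorPolicy` keeps one independent copy of a base `ParametricPolicy` per parallel environment, so each copy can act with its own weights. A hypothetical usage sketch follows; the import path for `VectorPolicy`, the `draw_action(state, policy_state)` interface used above, and the approximator shapes are assumptions for illustration, not confirmed by this commit:

```python
import numpy as np

from mushroom_rl.policy import DeterministicPolicy
from mushroom_rl.policy.vector_policy import VectorPolicy  # hypothetical path
from mushroom_rl.approximators import Regressor
from mushroom_rl.approximators.parametric import LinearApproximator

# One deterministic linear policy, replicated over 4 parallel environments.
approximator = Regressor(LinearApproximator, input_shape=(3,), output_shape=(1,))
base_policy = DeterministicPolicy(approximator)
vector_policy = VectorPolicy(base_policy, n_envs=4)

# Each copy gets its own weight vector, e.g. one black-box sample per env.
n_envs, n_weights = vector_policy.weights_size
vector_policy.set_weights(np.random.randn(n_envs, n_weights))

# Draw one action per environment from a batch of states.
states = np.random.randn(n_envs, 3)
actions, _ = vector_policy.draw_action(states, policy_state=None)
print(actions.shape)  # expected: (4, 1)
```

The two-level `weights_size` mirrors `set_weights`, which expects one weight row per environment.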