Fixes for multiprocess environments and vector core
- fixed multiprocess environment class, now it should work
- fixed minor bug in dataset
- many fixes in vectorized core
- small update in utils/torch
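The new example added in this commit (reproduced in full below) exercises both fixed components: the environment is wrapped in a MultiprocessEnvironment and the agent is driven through a VectorCore. As a minimal sketch of the wrapper construction alone, using the same Pendulum-v1 settings as the example and assuming a local Gym installation:

from mushroom_rl.core import MultiprocessEnvironment
from mushroom_rl.environments import Gym

# Run 15 copies of Pendulum-v1, each in its own worker process
# (same constructor call as in the example below).
mdp = MultiprocessEnvironment(Gym, 'Pendulum-v1', 200, .99, n_envs=15)

# The wrapper exposes a single mdp.info, so agents and cores can be built
# against it exactly as with a single environment.
print(mdp.info.observation_space.shape, mdp.info.action_space.shape)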
1 parent 5e74492 · commit 2ba0eb0
Showing 9 changed files with 253 additions and 77 deletions.
@@ -0,0 +1,128 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np
from tqdm import trange

from mushroom_rl.core import VectorCore, Logger, MultiprocessEnvironment
from mushroom_rl.environments import Gym
from mushroom_rl.algorithms.actor_critic import PPO

from mushroom_rl.policy import GaussianTorchPolicy


class Network(nn.Module):
    def __init__(self, input_shape, output_shape, n_features, **kwargs):
        super(Network, self).__init__()

        n_input = input_shape[-1]
        n_output = output_shape[0]

        self._h1 = nn.Linear(n_input, n_features)
        self._h2 = nn.Linear(n_features, n_features)
        self._h3 = nn.Linear(n_features, n_output)

        nn.init.xavier_uniform_(self._h1.weight,
                                gain=nn.init.calculate_gain('relu'))
        nn.init.xavier_uniform_(self._h2.weight,
                                gain=nn.init.calculate_gain('relu'))
        nn.init.xavier_uniform_(self._h3.weight,
                                gain=nn.init.calculate_gain('linear'))

    def forward(self, state, **kwargs):
        features1 = F.relu(self._h1(torch.squeeze(state, 1).float()))
        features2 = F.relu(self._h2(features1))
        a = self._h3(features2)

        return a


def experiment(alg, env_id, horizon, gamma, n_epochs, n_steps, n_steps_per_fit, n_episodes_test,
               alg_params, policy_params):

    logger = Logger(alg.__name__, results_dir=None)
    logger.strong_line()
    logger.info('Experiment Algorithm: ' + alg.__name__)

    mdp = MultiprocessEnvironment(Gym, env_id, horizon, gamma, n_envs=15)

    critic_params = dict(network=Network,
                         optimizer={'class': optim.Adam,
                                    'params': {'lr': 3e-4}},
                         loss=F.mse_loss,
                         n_features=32,
                         batch_size=64,
                         input_shape=mdp.info.observation_space.shape,
                         output_shape=(1,))

    policy = GaussianTorchPolicy(Network,
                                 mdp.info.observation_space.shape,
                                 mdp.info.action_space.shape,
                                 **policy_params)

    alg_params['critic_params'] = critic_params

    agent = alg(mdp.info, policy, **alg_params)
    #agent.set_logger(logger)

    core = VectorCore(agent, mdp)

    dataset = core.evaluate(n_episodes=n_episodes_test, render=False)

    J = np.mean(dataset.discounted_return)
    R = np.mean(dataset.undiscounted_return)
    E = agent.policy.entropy()

    logger.epoch_info(0, J=J, R=R, entropy=E)

    for it in trange(n_epochs, leave=False):
        core.learn(n_steps=n_steps, n_steps_per_fit=n_steps_per_fit)
        dataset = core.evaluate(n_episodes=n_episodes_test, render=False)

        J = np.mean(dataset.discounted_return)
        R = np.mean(dataset.undiscounted_return)
        E = agent.policy.entropy()

        logger.epoch_info(it+1, J=J, R=R, entropy=E)

    logger.info('Press a button to visualize')
    input()
    core.evaluate(n_episodes=5, render=True)


if __name__ == '__main__':
    max_kl = .015

    policy_params = dict(
        std_0=1.,
        n_features=32,
        use_cuda=False
    )

    ppo_params = dict(actor_optimizer={'class': optim.Adam,
                                       'params': {'lr': 3e-4}},
                      n_epochs_policy=4,
                      batch_size=64,
                      eps_ppo=.2,
                      lam=.95)

    trpo_params = dict(ent_coeff=0.0,
                       max_kl=.01,
                       lam=.95,
                       n_epochs_line_search=10,
                       n_epochs_cg=100,
                       cg_damping=1e-2,
                       cg_residual_tol=1e-10)

    algs_params = [
        (PPO, 'ppo', ppo_params)
    ]

    for alg, alg_name, alg_params in algs_params:
        experiment(alg=alg, env_id='Pendulum-v1', horizon=200, gamma=.99,
                   n_epochs=40, n_steps=30000, n_steps_per_fit=3000,
                   n_episodes_test=25, alg_params=alg_params,
                   policy_params=policy_params)
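A side note on the example: trpo_params is defined, but only PPO is registered in algs_params. If TRPO were also to be run through the same harness, one possible (untested) extension, assuming mushroom_rl.algorithms.actor_critic exposes TRPO with these constructor arguments, would be:

from mushroom_rl.algorithms.actor_critic import TRPO  # assumed to sit next to PPO

algs_params = [
    (PPO, 'ppo', ppo_params),
    (TRPO, 'trpo', trpo_params)  # reuses the trpo_params dict defined above
]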
@@ -1 +1 @@
-__version__ = '1.10.0'
+__version__ = '2.0.0-rc1'