Fixes for multiprocess environments and vector core
- fixed the multiprocess environment class; it should now work (see the usage sketch below)
- fixed a minor bug in the dataset
- many fixes in the vectorized core
- small update in utils/torch
boris-il-forte committed Oct 30, 2023
1 parent 5e74492 commit 2ba0eb0
Showing 9 changed files with 253 additions and 77 deletions.
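
For orientation, the sketch below condenses what the new multiprocess machinery is for: wrapping a regular environment class in MultiprocessEnvironment and stepping all workers in lockstep through the reset_all/step_all interface shown in the diff further down. It is an illustrative snippet based on this commit's new example file and the updated multiprocess_environment.py, not code taken from the repository; the [-2, 2] action range used for the random actions is an assumption for Pendulum-v1, and in practice you would drive the environment through VectorCore as the example does.

import numpy as np

from mushroom_rl.core import MultiprocessEnvironment
from mushroom_rl.environments import Gym

if __name__ == '__main__':
    # Spawn 4 worker processes, each owning its own Pendulum-v1 instance.
    # With n_envs=-1 the class now falls back to cpu_count(), as added in this commit.
    envs = MultiprocessEnvironment(Gym, 'Pendulum-v1', 200, 0.99, n_envs=4)

    env_mask = np.ones(4, dtype=bool)                 # address all workers
    states, episode_infos = envs.reset_all(env_mask)

    # One synchronous step across all workers; actions are sampled uniformly
    # in [-2, 2] (assumed Pendulum-v1 torque range).
    actions = np.random.uniform(-2., 2., size=(4,) + envs.info.action_space.shape)
    states, rewards, absorbings, step_infos = envs.step_all(env_mask, actions)

    envs.stop()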
128 changes: 128 additions & 0 deletions examples/multiprocess_pendulum.py
@@ -0,0 +1,128 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np
from tqdm import trange

from mushroom_rl.core import VectorCore, Logger, MultiprocessEnvironment
from mushroom_rl.environments import Gym
from mushroom_rl.algorithms.actor_critic import PPO

from mushroom_rl.policy import GaussianTorchPolicy


class Network(nn.Module):
    def __init__(self, input_shape, output_shape, n_features, **kwargs):
        super(Network, self).__init__()

        n_input = input_shape[-1]
        n_output = output_shape[0]

        self._h1 = nn.Linear(n_input, n_features)
        self._h2 = nn.Linear(n_features, n_features)
        self._h3 = nn.Linear(n_features, n_output)

        nn.init.xavier_uniform_(self._h1.weight,
                                gain=nn.init.calculate_gain('relu'))
        nn.init.xavier_uniform_(self._h2.weight,
                                gain=nn.init.calculate_gain('relu'))
        nn.init.xavier_uniform_(self._h3.weight,
                                gain=nn.init.calculate_gain('linear'))

    def forward(self, state, **kwargs):
        features1 = F.relu(self._h1(torch.squeeze(state, 1).float()))
        features2 = F.relu(self._h2(features1))
        a = self._h3(features2)

        return a


def experiment(alg, env_id, horizon, gamma, n_epochs, n_steps, n_steps_per_fit, n_episodes_test,
               alg_params, policy_params):

    logger = Logger(alg.__name__, results_dir=None)
    logger.strong_line()
    logger.info('Experiment Algorithm: ' + alg.__name__)

    mdp = MultiprocessEnvironment(Gym, env_id, horizon, gamma, n_envs=15)

    critic_params = dict(network=Network,
                         optimizer={'class': optim.Adam,
                                    'params': {'lr': 3e-4}},
                         loss=F.mse_loss,
                         n_features=32,
                         batch_size=64,
                         input_shape=mdp.info.observation_space.shape,
                         output_shape=(1,))

    policy = GaussianTorchPolicy(Network,
                                 mdp.info.observation_space.shape,
                                 mdp.info.action_space.shape,
                                 **policy_params)

    alg_params['critic_params'] = critic_params

    agent = alg(mdp.info, policy, **alg_params)
    #agent.set_logger(logger)

    core = VectorCore(agent, mdp)

    dataset = core.evaluate(n_episodes=n_episodes_test, render=False)

    J = np.mean(dataset.discounted_return)
    R = np.mean(dataset.undiscounted_return)
    E = agent.policy.entropy()

    logger.epoch_info(0, J=J, R=R, entropy=E)

    for it in trange(n_epochs, leave=False):
        core.learn(n_steps=n_steps, n_steps_per_fit=n_steps_per_fit)
        dataset = core.evaluate(n_episodes=n_episodes_test, render=False)

        J = np.mean(dataset.discounted_return)
        R = np.mean(dataset.undiscounted_return)
        E = agent.policy.entropy()

        logger.epoch_info(it+1, J=J, R=R, entropy=E)

    logger.info('Press a button to visualize')
    input()
    core.evaluate(n_episodes=5, render=True)


if __name__ == '__main__':
    max_kl = .015

    policy_params = dict(
        std_0=1.,
        n_features=32,
        use_cuda=False

    )

    ppo_params = dict(actor_optimizer={'class': optim.Adam,
                                       'params': {'lr': 3e-4}},
                      n_epochs_policy=4,
                      batch_size=64,
                      eps_ppo=.2,
                      lam=.95)

    trpo_params = dict(ent_coeff=0.0,
                       max_kl=.01,
                       lam=.95,
                       n_epochs_line_search=10,
                       n_epochs_cg=100,
                       cg_damping=1e-2,
                       cg_residual_tol=1e-10)

    algs_params = [
        (PPO, 'ppo', ppo_params)
    ]

    for alg, alg_name, alg_params in algs_params:
        experiment(alg=alg, env_id='Pendulum-v1', horizon=200, gamma=.99,
                   n_epochs=40, n_steps=30000, n_steps_per_fit=3000,
                   n_episodes_test=25, alg_params=alg_params,
                   policy_params=policy_params)
15 changes: 7 additions & 8 deletions examples/pendulum_trust_region.py
@@ -11,7 +11,6 @@
from mushroom_rl.algorithms.actor_critic import TRPO, PPO

from mushroom_rl.policy import GaussianTorchPolicy
from mushroom_rl.utils.dataset import compute_J


class Network(nn.Module):
@@ -66,14 +65,14 @@ def experiment(alg, env_id, horizon, gamma, n_epochs, n_steps, n_steps_per_fit,
alg_params['critic_params'] = critic_params

agent = alg(mdp.info, policy, **alg_params)
agent.set_logger(logger)
#agent.set_logger(logger)

core = Core(agent, mdp)

dataset = core.evaluate(n_episodes=n_episodes_test, render=False)

J = np.mean(compute_J(dataset, mdp.info.gamma))
R = np.mean(compute_J(dataset))
J = np.mean(dataset.discounted_return)
R = np.mean(dataset.undiscounted_return)
E = agent.policy.entropy()

logger.epoch_info(0, J=J, R=R, entropy=E)
@@ -82,8 +81,8 @@ def experiment(alg, env_id, horizon, gamma, n_epochs, n_steps, n_steps_per_fit,
core.learn(n_steps=n_steps, n_steps_per_fit=n_steps_per_fit)
dataset = core.evaluate(n_episodes=n_episodes_test, render=False)

J = np.mean(compute_J(dataset, mdp.info.gamma))
R = np.mean(compute_J(dataset))
J = np.mean(dataset.discounted_return)
R = np.mean(dataset.undiscounted_return)
E = agent.policy.entropy()

logger.epoch_info(it+1, J=J, R=R, entropy=E)
@@ -99,7 +98,7 @@ def experiment(alg, env_id, horizon, gamma, n_epochs, n_steps, n_steps_per_fit,
policy_params = dict(
std_0=1.,
n_features=32,
use_cuda=torch.cuda.is_available()
use_cuda=False

)

@@ -119,7 +118,7 @@ def experiment(alg, env_id, horizon, gamma, n_epochs, n_steps, n_steps_per_fit,
cg_residual_tol=1e-10)

algs_params = [
(TRPO, 'trpo', trpo_params),
#(TRPO, 'trpo', trpo_params),
(PPO, 'ppo', ppo_params)
]

2 changes: 1 addition & 1 deletion mushroom_rl/__init__.py
@@ -1 +1 @@
__version__ = '1.10.0'
__version__ = '2.0.0-rc1'
33 changes: 16 additions & 17 deletions mushroom_rl/algorithms/actor_critic/deep_actor_critic/ppo.py
@@ -109,23 +109,22 @@ def _update_policy(self, obs, act, adv, old_log_p):

def _log_info(self, dataset, x, v_target, old_pol_dist):
if self._logger:
logging_verr = []
torch_v_targets = torch.tensor(v_target, dtype=torch.float)
for idx in range(len(self._V)):
v_pred = torch.tensor(self._V(x, idx=idx), dtype=torch.float)
v_err = F.mse_loss(v_pred, torch_v_targets)
logging_verr.append(v_err.item())

logging_ent = self.policy.entropy(x)
new_pol_dist = self.policy.distribution(x)
logging_kl = torch.mean(torch.distributions.kl.kl_divergence(
new_pol_dist, old_pol_dist))
avg_rwd = np.mean(dataset.undiscounted_return)
msg = "Iteration {}:\n\t\t\t\trewards {} vf_loss {}\n\t\t\t\tentropy {} kl {}".format(
self._iter, avg_rwd, logging_verr, logging_ent, logging_kl)

self._logger.info(msg)
self._logger.weak_line()
with torch.no_grad():
logging_verr = []
for idx in range(len(self._V)):
v_pred = self._V(x, idx=idx, output_tensor=True)
v_err = F.mse_loss(v_pred, v_target)
logging_verr.append(v_err.item())

logging_ent = self.policy.entropy(x)
new_pol_dist = self.policy.distribution(x)
logging_kl = torch.mean(torch.distributions.kl.kl_divergence(new_pol_dist, old_pol_dist))
avg_rwd = np.mean(dataset.undiscounted_return)
msg = "Iteration {}:\n\t\t\t\trewards {} vf_loss {}\n\t\t\t\tentropy {} kl {}".format(
self._iter, avg_rwd, logging_verr, logging_ent, logging_kl)

self._logger.info(msg)
self._logger.weak_line()

def _post_load(self):
if self._optimizer is not None:
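
For context on the _log_info rewrite in this file: the new version computes all logging metrics inside a torch.no_grad() block and asks the critic for tensors directly (output_tensor=True) instead of converting predictions back from numpy. The snippet below is a generic PyTorch illustration of the no_grad pattern, not mushroom_rl code; the toy tensors stand in for critic predictions and value targets.

import torch
import torch.nn.functional as F

# Toy tensors standing in for critic predictions and value targets.
v_pred = torch.randn(32, 1, requires_grad=True)
v_target = torch.randn(32, 1)

# Metrics computed purely for logging should not record an autograd graph.
with torch.no_grad():
    v_err = F.mse_loss(v_pred, v_target)

print(v_err.item())          # scalar loss value for logging
print(v_err.requires_grad)   # False: no graph was built inside no_grad()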
7 changes: 6 additions & 1 deletion mushroom_rl/core/__init__.py
@@ -5,6 +5,11 @@
from .serialization import Serializable
from .logger import Logger

from .vectorized_core import VectorCore
from .vectorized_env import VectorizedEnvironment
from .multiprocess_environment import MultiprocessEnvironment

import mushroom_rl.environments

__all__ = ['Core', 'Dataset', 'Environment', 'MDPInfo', 'Agent', 'AgentInfo', 'Serializable', 'Logger']
__all__ = ['Core', 'Dataset', 'Environment', 'MDPInfo', 'Agent', 'AgentInfo', 'Serializable', 'Logger',
'VectorCore', 'VectorizedEnvironment', 'MultiprocessEnvironment']
2 changes: 1 addition & 1 deletion mushroom_rl/core/dataset.py
@@ -180,7 +180,7 @@ def __add__(self, other):

result._info = new_info
result._episode_info = new_episode_info
result.theta_list = result._theta_list + other._theta_list
result._theta_list = result._theta_list + other._theta_list
result._data = self._data + other._data

return result
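
The one-line dataset fix above replaces an assignment to result.theta_list with the private result._theta_list. Assuming theta_list is exposed as a read-only property over _theta_list (which the underscore naming suggests, though the real class is not shown here), the old assignment could not update the merged dataset correctly. The toy class below only illustrates that failure mode; it is not the mushroom_rl Dataset.

class ToyDataset:
    """Minimal stand-in used only to illustrate the bug, not the real Dataset."""

    def __init__(self, thetas):
        self._theta_list = list(thetas)

    @property
    def theta_list(self):
        return self._theta_list

    def __add__(self, other):
        result = ToyDataset(self._theta_list)
        # The old code did the equivalent of:
        #     result.theta_list = self._theta_list + other._theta_list
        # which raises AttributeError here because the property has no setter.
        result._theta_list = self._theta_list + other._theta_list
        return result


print((ToyDataset([1]) + ToyDataset([2])).theta_list)   # [1, 2]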
86 changes: 60 additions & 26 deletions mushroom_rl/core/multiprocess_environment.py
@@ -1,5 +1,4 @@
from multiprocessing import Pipe
from multiprocessing import Process
from multiprocessing import Pipe, Process, cpu_count

import numpy as np

@@ -16,14 +15,25 @@ def _env_worker(remote, env_class, use_generator, args, kwargs):
try:
while True:
cmd, data = remote.recv()

# if data is None:
# print(f'Executed command {cmd} with None data')

if cmd == 'step':
action = data[0]
action = data
res = env.step(action)
remote.send(res)
elif cmd == 'reset':
init_states = data[0]
if data is not None:
init_states = data[0]
else:
init_states = None
res = env.reset(init_states)
remote.send(res)
elif cmd == 'render':
record = data
res = env.render(record=record)
remote.send(res)
elif cmd in 'stop':
env.stop()
remote.send(None)
@@ -32,7 +42,10 @@ def _env_worker(remote, env_class, use_generator, args, kwargs):
elif cmd == 'seed':
env.seed(int(data))
remote.send(None)
elif cmd == 'close':
break
else:
print(f'cmd {cmd}')
raise NotImplementedError()
finally:
remote.close()
@@ -56,7 +69,10 @@ def __init__(self, env_class, *args, n_envs=-1, use_generator=False, **kwargs):
**kwargs: keyword arguments to set to the constructor or to the generator;
"""
assert n_envs > 1
assert n_envs > 1 or n_envs == -1

if n_envs == -1:
n_envs = cpu_count()

self._remotes, self._work_remotes = zip(*[Pipe() for _ in range(n_envs)])
self._processes = list()
@@ -73,36 +89,51 @@ def __init__(self, env_class, *args, n_envs=-1, use_generator=False, **kwargs):

super().__init__(mdp_info, n_envs)

def step_all(self, env_mask, action):
for i, remote in enumerate(self._remotes):
if env_mask[i]:
remote.send(('step', action[i, :]))

states = list()
step_infos = list()
for i, remote in enumerate(self._remotes):
if env_mask[i]:
state, step_info = remote.recv()
states.append(remote.recv())
step_infos.append(step_info)

return np.array(states), step_infos
self._state_shape = (n_envs,) + self.info.observation_space.shape
self._reward_shape = (n_envs,)
self._absorbing_shape = (n_envs,)

def reset_all(self, env_mask, state=None):
for i, remote in enumerate(self._remotes):
if env_mask[i]:
state_i = state[i, :] if state is not None else None
remote.send(('reset', state_i))

states = list()
states = np.empty(self._state_shape)
episode_infos = list()
for i, remote in enumerate(self._remotes):
if env_mask[i]:
state, episode_info = remote.recv()
states.append(state)

states[i] = state
episode_infos.append(episode_info)
else:
episode_infos.append({})

return np.array(states), episode_infos
return states, episode_infos

def step_all(self, env_mask, action):
for i, remote in enumerate(self._remotes):
if env_mask[i]:
remote.send(('step', action[i, :]))

states = np.empty(self._state_shape)
rewards = np.empty(self._reward_shape)
absorbings = np.empty(self._absorbing_shape, dtype=bool)
step_infos = list()

for i, remote in enumerate(self._remotes):
if env_mask[i]:
state, reward, absorbing, step_info = remote.recv()

states[i] = state
rewards[i] = reward
absorbings[i] = absorbing
step_infos.append(step_info)
else:
step_infos.append({})

return states, rewards, absorbings, step_infos

def render_all(self, env_mask, record=False):
for i, remote in enumerate(self._remotes):
@@ -128,12 +159,15 @@ def seed(self, seed):
def stop(self):
for remote in self._remotes:
remote.send(('stop', None))
remote.recv()

def __del__(self):
for remote in self._remotes:
remote.send(('close', None))
for p in self._processes:
p.join()
if hasattr(self, '_remotes'):
for remote in self._remotes:
remote.send(('close', None))
if hasattr(self, '_processes'):
for p in self._processes:
p.join()

@staticmethod
def generate(env, *args, n_envs=-1, **kwargs):
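
The worker-side changes in this file all revolve around the same Pipe-based command loop: the parent sends (cmd, data) tuples and the child replies over the same connection, until a 'close' command ends the loop. The snippet below is a stripped-down, self-contained illustration of that pattern, not the library's actual _env_worker.

from multiprocessing import Pipe, Process


def worker(remote):
    # Minimal command loop: receive (cmd, data), act, reply, until 'close'.
    try:
        while True:
            cmd, data = remote.recv()
            if cmd == 'echo':
                remote.send(data)
            elif cmd == 'close':
                break
            else:
                raise NotImplementedError(cmd)
    finally:
        remote.close()


if __name__ == '__main__':
    parent, child = Pipe()
    p = Process(target=worker, args=(child,))
    p.start()

    parent.send(('echo', 42))
    print(parent.recv())        # 42

    parent.send(('close', None))
    p.join()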