diff --git a/examples/car_on_hill_fqi.py b/examples/car_on_hill_fqi.py index cef509223..534d716d3 100644 --- a/examples/car_on_hill_fqi.py +++ b/examples/car_on_hill_fqi.py @@ -6,7 +6,6 @@ from mushroom_rl.core import Core, Logger from mushroom_rl.environments import * from mushroom_rl.policy import EpsGreedy -from mushroom_rl.utils.dataset import compute_J from mushroom_rl.utils.parameters import Parameter """ @@ -65,7 +64,7 @@ def experiment(): # Render core.evaluate(n_episodes=3, render=True) - return np.mean(compute_J(dataset, mdp.info.gamma)) + return np.mean(dataset.discounted_return) if __name__ == '__main__': @@ -75,5 +74,5 @@ def experiment(): logger.strong_line() logger.info('Experiment Algorithm: ' + FQI.__name__) - Js = Parallel(n_jobs=-1)(delayed(experiment)() for _ in range(n_experiment)) + Js = Parallel(n_jobs=None)(delayed(experiment)() for _ in range(n_experiment)) logger.info((np.mean(Js))) diff --git a/examples/gym_recurrent_ppo.py b/examples/gym_recurrent_ppo.py new file mode 100644 index 000000000..fcde5a209 --- /dev/null +++ b/examples/gym_recurrent_ppo.py @@ -0,0 +1,282 @@ +import os +import numpy as np +import torch +from experiment_launcher.decorators import single_experiment +from experiment_launcher import run_experiment +import torch.optim as optim + +from mushroom_rl.core import Logger, Core +from mushroom_rl.environments import Gym + +from mushroom_rl.algorithms.actor_critic import PPO_BPTT +from mushroom_rl.policy import RecurrentGaussianTorchPolicy + +from tqdm import trange + + +def get_recurrent_network(rnn_type): + if rnn_type == "vanilla": + return torch.nn.RNN + elif rnn_type == "gru": + return torch.nn.GRU + else: + raise ValueError("Unknown RNN type %s." % rnn_type) + + +class PPOCriticBPTTNetwork(torch.nn.Module): + + def __init__(self, input_shape, output_shape, dim_env_state, dim_action, rnn_type, + n_hidden_features=128, n_features=128, num_hidden_layers=1, + hidden_state_treatment="zero_initial", **kwargs): + super().__init__() + + assert hidden_state_treatment in ["zero_initial", "use_policy_hidden_state"] + + self._input_shape = input_shape + self._output_shape = output_shape + self._dim_env_state = dim_env_state + self._dim_action = dim_action + self._use_policy_hidden_states = True if hidden_state_treatment == "use_policy_hidden_state" else False + + rnn = get_recurrent_network(rnn_type) + + # embedder + self._h1_o = torch.nn.Linear(dim_env_state, n_features) + self._h1_o_post_rnn = torch.nn.Linear(dim_env_state, n_features) + + # rnn + self._rnn = rnn(input_size=n_features, + hidden_size=n_hidden_features, + num_layers=num_hidden_layers, + # nonlinearity=hidden_activation, # todo: this is turned off for now to allow for rnn and gru + batch_first=True) + + # post-rnn layer + self._hq_1 = torch.nn.Linear(n_hidden_features+n_features, n_features) + self._hq_2 = torch.nn.Linear(n_features, 1) + self._act_func = torch.nn.ReLU() + + torch.nn.init.xavier_uniform_(self._h1_o.weight, gain=torch.nn.init.calculate_gain("relu")) + torch.nn.init.xavier_uniform_(self._h1_o_post_rnn.weight, gain=torch.nn.init.calculate_gain("relu")) + torch.nn.init.xavier_uniform_(self._hq_1.weight, gain=torch.nn.init.calculate_gain("relu")) + torch.nn.init.xavier_uniform_(self._hq_2.weight, gain=torch.nn.init.calculate_gain("relu")) + + def forward(self, state, policy_state, lengths): + # pre-rnn embedder + input_rnn = self._act_func(self._h1_o(state)) + + # --- forward rnn --- + # the inputs are padded. 
Based on that and the length, we created a packed sequence + packed_seq = torch.nn.utils.rnn.pack_padded_sequence(input_rnn, lengths, enforce_sorted=False, + batch_first=True) + if self._use_policy_hidden_states: + # hidden state has to have shape (N_layers, N_batch, DIM_hidden), + # so we need to reshape and swap the first two axes. + policy_state_reshaped = policy_state.view(-1, self._num_hidden_layers, self._n_hidden_features) + policy_state_reshaped = torch.swapaxes(policy_state_reshaped, 0, 1) + out_rnn, _ = self._rnn(packed_seq, policy_state_reshaped) + else: + out_rnn, _ = self._rnn(packed_seq) # use zero initial states + + # we only need the last entry in each sequence + features_rnn, _ = torch.nn.utils.rnn.pad_packed_sequence(out_rnn, batch_first=True) + rel_indices = lengths.view(-1, 1, 1) - 1 + features_rnn = torch.squeeze(torch.take_along_dim(features_rnn, rel_indices, dim=1), dim=1) + + # post-rnn embedder. Here we again only need the last state + last_state = torch.squeeze(torch.take_along_dim(state, rel_indices, dim=1), dim=1) + feature_s = self._act_func(self._h1_o_post_rnn(last_state)) + + # last layer + input_last_layer = torch.concat([feature_s, features_rnn], dim=1) + q = self._hq_2(self._act_func(self._hq_1(input_last_layer))) + + return torch.squeeze(q) + + +class PPOActorBPTTNetwork(torch.nn.Module): + + def __init__(self, input_shape, output_shape, n_features, dim_env_state, rnn_type, + n_hidden_features, num_hidden_layers=1, **kwargs): + super().__init__() + + dim_state = input_shape[0] + dim_action = output_shape[0] + self._dim_env_state = dim_env_state + self._num_hidden_layers = num_hidden_layers + self._n_hidden_features = n_hidden_features + + rnn = get_recurrent_network(rnn_type) + + # embedder + self._h1_o = torch.nn.Linear(dim_env_state, n_features) + self._h1_o_post_rnn = torch.nn.Linear(dim_env_state, n_features) + + # rnn + self._rnn = rnn(input_size=n_features, + hidden_size=n_hidden_features, + num_layers=num_hidden_layers, + # nonlinearity=hidden_activation, # todo: this is turned off for now to allow for rnn and gru + batch_first=True) + + # post-rnn layer + self._h3 = torch.nn.Linear(n_hidden_features+n_features, dim_action) + self._act_func = torch.nn.ReLU() + self._tanh = torch.nn.Tanh() + + torch.nn.init.xavier_uniform_(self._h1_o.weight, gain=torch.nn.init.calculate_gain("relu")*0.05) + torch.nn.init.xavier_uniform_(self._h1_o_post_rnn.weight, gain=torch.nn.init.calculate_gain("relu")*0.05) + torch.nn.init.xavier_uniform_(self._h3.weight, gain=torch.nn.init.calculate_gain("relu")*0.05) + + def forward(self, state, policy_state, lengths): + # pre-rnn embedder + input_rnn = self._act_func(self._h1_o(state)) + + # forward rnn + # the inputs are padded. Based on that and the length, we created a packed sequence + packed_seq = torch.nn.utils.rnn.pack_padded_sequence(input_rnn, lengths, enforce_sorted=False, + batch_first=True) + + # hidden state has to have shape (N_layers, N_batch, DIM_hidden), + # so we need to reshape and swap the first two axes. 
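+        # note (descriptive comment, assumption from the view/swapaxes below): the policy hidden
+        # state arrives flattened per sample, so it is reshaped to (N_batch, N_layers, DIM_hidden)
+        # before swapping to the (N_layers, N_batch, DIM_hidden) layout expected by torch RNNs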
+ policy_state_reshaped = policy_state.view(-1, self._num_hidden_layers, self._n_hidden_features) + policy_state_reshaped = torch.swapaxes(policy_state_reshaped, 0, 1) + + out_rnn, next_hidden = self._rnn(packed_seq, policy_state_reshaped) + + # we only need the last entry in each sequence + features_rnn, _ = torch.nn.utils.rnn.pad_packed_sequence(out_rnn, batch_first=True) + rel_indices = lengths.view(-1, 1, 1) - 1 + features_rnn = torch.squeeze(torch.take_along_dim(features_rnn, rel_indices, dim=1), dim=1) + + # post-rnn embedder. Here we again only need the last state + last_state = torch.squeeze(torch.take_along_dim(state, rel_indices, dim=1), dim=1) + feature_sa = self._act_func(self._h1_o_post_rnn(last_state)) + + # last layer + input_last_layer = torch.concat([feature_sa, features_rnn], dim=1) + a = self._h3(input_last_layer) + + return a, torch.swapaxes(next_hidden, 0, 1) + + +def get_POMDP_params(pomdp_type): + if pomdp_type == "no_velocities": + return dict(obs_to_hide=("velocities",), random_force_com=False) + elif pomdp_type == "no_positions": + return dict(obs_to_hide=("positions",), random_force_com=False) + elif pomdp_type == "windy": + return dict(obs_to_hide=tuple(), random_force_com=True) + + +@single_experiment +def experiment( + env: str = 'HalfCheetah-v4', + horizon: int = 1000, + gamma: float = 0.99, + n_epochs: int = 300, + n_steps_per_epoch: int = 50000, + n_steps_per_fit: int = 2000, + n_episode_eval: int = 10, + lr_actor: float = 0.001, + lr_critic: float = 0.001, + batch_size_actor: int = 32, + batch_size_critic: int = 32, + n_epochs_policy: int = 10, + clip_eps_ppo: float = 0.05, + gae_lambda: float = 0.95, + seed: int = 0, # This argument is mandatory + results_dir: str = './logs', # This argument is mandatory + use_cuda: bool = False, + std_0: float = 0.5, + rnn_type: str ="gru", + n_hidden_features: int = 128, + num_hidden_layers: int = 1, + truncation_length: int = 5 +): + np.random.seed(seed) + torch.manual_seed(seed) + + # prepare logging + results_dir = os.path.join(results_dir, str(seed)) + logger = Logger(results_dir=results_dir, log_name="stochastic_logging", seed=seed) + + # MDP + mdp = Gym(env, horizon=horizon, gamma=gamma) + + # create the policy + dim_env_state = mdp.info.observation_space.shape[0] + dim_action = mdp.info.action_space.shape[0] + + policy = RecurrentGaussianTorchPolicy(network=PPOActorBPTTNetwork, + policy_state_shape=(n_hidden_features,), + input_shape=(dim_env_state, ), + output_shape=(dim_action,), + n_features=128, + rnn_type=rnn_type, + n_hidden_features=n_hidden_features, + num_hidden_layers=num_hidden_layers, + dim_hidden_state=n_hidden_features, + dim_env_state=dim_env_state, + dim_action=dim_action, + std_0=std_0) + + # setup critic + input_shape_critic = (mdp.info.observation_space.shape[0]+2*n_hidden_features,) + critic_params = dict(network=PPOCriticBPTTNetwork, + optimizer={'class': optim.Adam, + 'params': {'lr': lr_critic, + 'weight_decay': 0.0}}, + loss=torch.nn.MSELoss(), + batch_size=batch_size_critic, + input_shape=input_shape_critic, + output_shape=(1,), + n_features=128, + n_hidden_features=n_hidden_features, + rnn_type=rnn_type, + num_hidden_layers=num_hidden_layers, + dim_env_state=mdp.info.observation_space.shape[0], + dim_hidden_state=n_hidden_features, + dim_action=dim_action, + use_cuda=use_cuda, + ) + + alg_params = dict(actor_optimizer={'class': optim.Adam, + 'params': {'lr': lr_actor, + 'weight_decay': 0.0}}, + n_epochs_policy=n_epochs_policy, + batch_size=batch_size_actor, + dim_env_state=dim_env_state, 
+ eps_ppo=clip_eps_ppo, + lam=gae_lambda, + truncation_length=truncation_length + ) + + # Create the agent + agent = PPO_BPTT(mdp_info=mdp.info, policy=policy, critic_params=critic_params, **alg_params) + + # Create Core + core = Core(agent, mdp) + + # Evaluation + dataset = core.evaluate(n_episodes=5) + J = np.mean(dataset.discounted_return) + R = np.mean(dataset.undiscounted_return) + L = np.mean(dataset.episodes_length) + logger.log_numpy(R=R, J=J, L=L) + logger.epoch_info(0, R=R, J=J, L=L) + + for i in trange(1, n_epochs+1, 1, leave=False): + core.learn(n_steps=n_steps_per_epoch, n_steps_per_fit=n_steps_per_fit) + + # Evaluation + dataset = core.evaluate(n_episodes=n_episode_eval) + J = np.mean(dataset.discounted_return) + R = np.mean(dataset.undiscounted_return) + L = np.mean(dataset.episodes_length) + logger.log_numpy(R=R, J=J, L=L) + logger.epoch_info(i, R=R, J=J, L=L) + + +if __name__ == '__main__': + run_experiment(experiment) diff --git a/examples/pendulum_a2c.py b/examples/pendulum_a2c.py index 58cf76fb3..e37582cef 100644 --- a/examples/pendulum_a2c.py +++ b/examples/pendulum_a2c.py @@ -11,7 +11,6 @@ from mushroom_rl.algorithms.actor_critic import A2C from mushroom_rl.policy import GaussianTorchPolicy -from mushroom_rl.utils.dataset import compute_J class Network(nn.Module): @@ -72,8 +71,8 @@ def experiment(alg, env_id, horizon, gamma, n_epochs, n_steps, n_steps_per_fit, dataset = core.evaluate(n_steps=n_step_test, render=False) - J = np.mean(compute_J(dataset, mdp.info.gamma)) - R = np.mean(compute_J(dataset)) + J = np.mean(dataset.discounted_return) + R = np.mean(dataset.undiscounted_return) E = agent.policy.entropy() logger.epoch_info(0, J=J, R=R, entropy=E) @@ -82,8 +81,8 @@ def experiment(alg, env_id, horizon, gamma, n_epochs, n_steps, n_steps_per_fit, core.learn(n_steps=n_steps, n_steps_per_fit=n_steps_per_fit) dataset = core.evaluate(n_steps=n_step_test, render=False) - J = np.mean(compute_J(dataset, mdp.info.gamma)) - R = np.mean(compute_J(dataset)) + J = np.mean(dataset.discounted_return) + R = np.mean(dataset.undiscounted_return) E = agent.policy.entropy() logger.epoch_info(it+1, J=J, R=R, entropy=E) diff --git a/mushroom_rl/algorithms/actor_critic/__init__.py b/mushroom_rl/algorithms/actor_critic/__init__.py index aea310acf..ab3748f44 100644 --- a/mushroom_rl/algorithms/actor_critic/__init__.py +++ b/mushroom_rl/algorithms/actor_critic/__init__.py @@ -1,5 +1,5 @@ from .classic_actor_critic import StochasticAC, StochasticAC_AVG, COPDAC_Q -from .deep_actor_critic import DeepAC, A2C, DDPG, TD3, SAC, TRPO, PPO +from .deep_actor_critic import DeepAC, A2C, DDPG, TD3, SAC, TRPO, PPO, PPO_BPTT __all__ = ['COPDAC_Q', 'StochasticAC', 'StochasticAC_AVG', - 'DeepAC', 'A2C', 'DDPG', 'TD3', 'SAC', 'TRPO', 'PPO'] + 'DeepAC', 'A2C', 'DDPG', 'TD3', 'SAC', 'TRPO', 'PPO', 'PPO_BPTT'] diff --git a/mushroom_rl/algorithms/actor_critic/classic_actor_critic/copdac_q.py b/mushroom_rl/algorithms/actor_critic/classic_actor_critic/copdac_q.py index 7787e785d..edb0e9739 100644 --- a/mushroom_rl/algorithms/actor_critic/classic_actor_critic/copdac_q.py +++ b/mushroom_rl/algorithms/actor_critic/classic_actor_critic/copdac_q.py @@ -14,8 +14,7 @@ class COPDAC_Q(Agent): Silver D. et al.. 2014. """ - def __init__(self, mdp_info, policy, mu, alpha_theta, alpha_omega, alpha_v, - value_function_features=None, policy_features=None): + def __init__(self, mdp_info, policy, mu, alpha_theta, alpha_omega, alpha_v, value_function_features=None): """ Constructor. 
@@ -27,7 +26,6 @@ def __init__(self, mdp_info, policy, mu, alpha_theta, alpha_omega, alpha_v, alpha_v ([float, Parameter]): learning rate for the value function; value_function_features (Features, None): features used by the value function approximator; - policy_features (Features, None): features used by the policy. """ self._mu = mu @@ -59,19 +57,18 @@ def __init__(self, mdp_info, policy, mu, alpha_theta, alpha_omega, alpha_v, _A='mushroom' ) - super().__init__(mdp_info, policy, policy_features) + super().__init__(mdp_info, policy) - def fit(self, dataset, **info): + def fit(self, dataset): for step in dataset: s, a, r, ss, absorbing, _ = step - s_phi = self.phi(s) if self.phi is not None else s s_psi = self._psi(s) if self._psi is not None else s ss_psi = self._psi(ss) if self._psi is not None else ss q_next = self._V(ss_psi).item() if not absorbing else 0 - grad_mu_s = np.atleast_2d(self._mu.diff(s_phi)) + grad_mu_s = np.atleast_2d(self._mu.diff(s)) omega = self._A.get_weights() delta = r + self.mdp_info.gamma * q_next - self._Q(s, a) @@ -96,8 +93,7 @@ def _Q(self, state, action): action)).item() def _nu(self, state, action): - state_phi = self.phi(state) if self.phi is not None else state - grad_mu = np.atleast_2d(self._mu.diff(state_phi)) - delta = action - self._mu(state_phi) + grad_mu = np.atleast_2d(self._mu.diff(state)) + delta = action - self._mu(state) return delta.dot(grad_mu) diff --git a/mushroom_rl/algorithms/actor_critic/classic_actor_critic/stochastic_ac.py b/mushroom_rl/algorithms/actor_critic/classic_actor_critic/stochastic_ac.py index d74b1d41d..2c49123b9 100644 --- a/mushroom_rl/algorithms/actor_critic/classic_actor_critic/stochastic_ac.py +++ b/mushroom_rl/algorithms/actor_critic/classic_actor_critic/stochastic_ac.py @@ -14,8 +14,7 @@ class StochasticAC(Agent): Degris T. et al.. 2012. """ - def __init__(self, mdp_info, policy, alpha_theta, alpha_v, lambda_par=.9, - value_function_features=None, policy_features=None): + def __init__(self, mdp_info, policy, alpha_theta, alpha_v, lambda_par=.9, value_function_features=None): """ Constructor. @@ -23,9 +22,7 @@ def __init__(self, mdp_info, policy, alpha_theta, alpha_v, lambda_par=.9, alpha_theta ([float, Parameter]): learning rate for policy update; alpha_v ([float, Parameter]): learning rate for the value function; lambda_par ([float, Parameter], .9): trace decay parameter; - value_function_features (Features, None): features used by the - value function approximator; - policy_features (Features, None): features used by the policy. + value_function_features (Features, None): features used by the value function approximator. 
""" self._psi = value_function_features @@ -35,15 +32,14 @@ def __init__(self, mdp_info, policy, alpha_theta, alpha_v, lambda_par=.9, self._lambda = to_parameter(lambda_par) - super().__init__(mdp_info, policy, policy_features) + super().__init__(mdp_info, policy) if self._psi is not None: input_shape = (self._psi.size,) else: input_shape = mdp_info.observation_space.shape - self._V = Regressor(LinearApproximator, input_shape=input_shape, - output_shape=(1,)) + self._V = Regressor(LinearApproximator, input_shape=input_shape, output_shape=(1,)) self._e_v = np.zeros(self._V.weights_size) self._e_theta = np.zeros(self.policy.weights_size) @@ -58,23 +54,22 @@ def __init__(self, mdp_info, policy, alpha_theta, alpha_v, lambda_par=.9, _e_theta='numpy' ) - def episode_start(self): + def episode_start(self, episode_info): self._e_v = np.zeros(self._V.weights_size) self._e_theta = np.zeros(self.policy.weights_size) - super().episode_start() + return super().episode_start(episode_info) - def fit(self, dataset, **info): + def fit(self, dataset): for step in dataset: s, a, r, ss, absorbing, _ = step - s_phi = self.phi(s) if self.phi is not None else s s_psi = self._psi(s) if self._psi is not None else s ss_psi = self._psi(ss) if self._psi is not None else ss v_next = self._V(ss_psi) if not absorbing else 0 - delta = self._compute_td_n_traces(a, r, v_next, s_psi, s_phi) + delta = self._compute_td_n_traces(s, a, r, v_next, s_psi) # Update value function delta_v = self._alpha_v(s, a) * delta * self._e_v @@ -86,14 +81,13 @@ def fit(self, dataset, **info): theta_new = self.policy.get_weights() + delta_theta self.policy.set_weights(theta_new) - def _compute_td_n_traces(self, a, r, v_next, s_psi, s_phi): + def _compute_td_n_traces(self, s, a, r, v_next, s_psi): # Compute TD error delta = r + self.mdp_info.gamma * v_next - self._V(s_psi) # Update traces self._e_v = self.mdp_info.gamma * self._lambda() * self._e_v + s_psi - self._e_theta = self.mdp_info.gamma * self._lambda() * \ - self._e_theta + self.policy.diff_log(s_phi, a) + self._e_theta = self.mdp_info.gamma * self._lambda() * self._e_theta + self.policy.diff_log(s, a) return delta @@ -105,9 +99,7 @@ class StochasticAC_AVG(StochasticAC): Degris T. et al.. 2012. """ - def __init__(self, mdp_info, policy, alpha_theta, alpha_v, alpha_r, - lambda_par=.9, value_function_features=None, - policy_features=None): + def __init__(self, mdp_info, policy, alpha_theta, alpha_v, alpha_r, lambda_par=.9, value_function_features=None): """ Constructor. @@ -115,21 +107,20 @@ def __init__(self, mdp_info, policy, alpha_theta, alpha_v, alpha_r, alpha_r (Parameter): learning rate for the reward trace. 
""" - super().__init__(mdp_info, policy, alpha_theta, alpha_v, lambda_par, - value_function_features, policy_features) + super().__init__(mdp_info, policy, alpha_theta, alpha_v, lambda_par, value_function_features) self._alpha_r = to_parameter(alpha_r) self._r_bar = 0 self._add_save_attr(_alpha_r='mushroom', _r_bar='primitive') - def _compute_td_n_traces(self, a, r, v_next, s_psi, s_phi): + def _compute_td_n_traces(self, s, a, r, v_next, s_psi): # Compute TD error delta = r - self._r_bar + v_next - self._V(s_psi) # Update traces self._r_bar += self._alpha_r() * delta self._e_v = self._lambda() * self._e_v + s_psi - self._e_theta = self._lambda() * self._e_theta + self.policy.diff_log(s_phi, a) + self._e_theta = self._lambda() * self._e_theta + self.policy.diff_log(s, a) return delta diff --git a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/__init__.py b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/__init__.py index a2fbf5bb6..9740f682b 100644 --- a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/__init__.py +++ b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/__init__.py @@ -5,5 +5,6 @@ from .sac import SAC from .trpo import TRPO from .ppo import PPO +from .ppo_bptt import PPO_BPTT -__all__ = ['DeepAC', 'A2C', 'DDPG', 'TD3', 'SAC', 'TRPO', 'PPO'] \ No newline at end of file +__all__ = ['DeepAC', 'A2C', 'DDPG', 'TD3', 'SAC', 'TRPO', 'PPO', 'PPO_BPTT'] \ No newline at end of file diff --git a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/a2c.py b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/a2c.py index a1c2dbc84..51b127439 100644 --- a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/a2c.py +++ b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/a2c.py @@ -4,9 +4,7 @@ from mushroom_rl.approximators import Regressor from mushroom_rl.approximators.parametric import TorchApproximator from mushroom_rl.utils.value_functions import compute_advantage_montecarlo -from mushroom_rl.utils.dataset import parse_dataset from mushroom_rl.utils.parameters import to_parameter -from mushroom_rl.utils.torch import to_float_tensor from copy import deepcopy @@ -58,8 +56,8 @@ def __init__(self, mdp_info, policy, actor_optimizer, critic_params, super().__init__(mdp_info, policy, actor_optimizer, policy.parameters()) - def fit(self, dataset, **info): - state, action, reward, next_state, absorbing, _ = parse_dataset(dataset) + def fit(self, dataset): + state, action, reward, next_state, absorbing, _ = dataset.parse(to='torch') v, adv = compute_advantage_montecarlo(self._V, state, next_state, reward, absorbing, @@ -70,15 +68,8 @@ def fit(self, dataset, **info): self._optimize_actor_parameters(loss) def _loss(self, state, action, adv): - use_cuda = self.policy.use_cuda - - s = to_float_tensor(state, use_cuda) - a = to_float_tensor(action, use_cuda) - - adv_t = to_float_tensor(adv, use_cuda) - - gradient_loss = -torch.mean(self.policy.log_prob_t(s, a)*adv_t) - entropy_loss = -self.policy.entropy_t(s) + gradient_loss = -torch.mean(self.policy.log_prob_t(state, action)*adv) + entropy_loss = -self.policy.entropy_t(state) return gradient_loss + self._entropy_coeff() * entropy_loss diff --git a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/ddpg.py b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/ddpg.py index 54c6abb73..bf3380a6a 100644 --- a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/ddpg.py +++ b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/ddpg.py @@ -97,7 +97,7 @@ def __init__(self, mdp_info, policy_class, policy_params, 
super().__init__(mdp_info, policy, actor_optimizer, policy_parameters) - def fit(self, dataset, **info): + def fit(self, dataset): self._replay_memory.add(dataset) if self._replay_memory.initialized: state, action, reward, next_state, absorbing, _ =\ diff --git a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/deep_actor_critic.py b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/deep_actor_critic.py index d007b1ccd..f97b1d4d2 100644 --- a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/deep_actor_critic.py +++ b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/deep_actor_critic.py @@ -42,7 +42,7 @@ def __init__(self, mdp_info, policy, actor_optimizer, parameters): super().__init__(mdp_info, policy) - def fit(self, dataset, **info): + def fit(self, dataset): """ Fit step. diff --git a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/ppo.py b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/ppo.py index bef5ddd7c..748ed39f8 100644 --- a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/ppo.py +++ b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/ppo.py @@ -8,7 +8,6 @@ from mushroom_rl.approximators.parametric import TorchApproximator from mushroom_rl.utils.torch import to_float_tensor, update_optimizer_parameters from mushroom_rl.utils.minibatches import minibatch_generator -from mushroom_rl.utils.dataset import parse_dataset, compute_J from mushroom_rl.utils.value_functions import compute_gae from mushroom_rl.utils.parameters import to_parameter @@ -69,30 +68,27 @@ def __init__(self, mdp_info, policy, actor_optimizer, critic_params, _iter='primitive' ) - super().__init__(mdp_info, policy, None) + super().__init__(mdp_info, policy) - def fit(self, dataset, **info): - x, u, r, xn, absorbing, last = parse_dataset(dataset) - x = x.astype(np.float32) - u = u.astype(np.float32) - r = r.astype(np.float32) - xn = xn.astype(np.float32) + def fit(self, dataset): + state, action, reward, next_state, absorbing, last = dataset.parse(to='torch') - obs = to_float_tensor(x, self.policy.use_cuda) - act = to_float_tensor(u, self.policy.use_cuda) - v_target, np_adv = compute_gae(self._V, x, xn, r, absorbing, last, self.mdp_info.gamma, self._lambda()) - np_adv = (np_adv - np.mean(np_adv)) / (np.std(np_adv) + 1e-8) - adv = to_float_tensor(np_adv, self.policy.use_cuda) + v_target, adv = compute_gae(self._V, state, next_state, reward, absorbing, last, + self.mdp_info.gamma, self._lambda()) + adv = (adv - torch.mean(adv)) / (torch.std(adv) + 1e-8) - old_pol_dist = self.policy.distribution_t(obs) - old_log_p = old_pol_dist.log_prob(act)[:, None].detach() + adv = adv.detach() + v_target = v_target.detach() - self._V.fit(x, v_target, **self._critic_fit_params) + old_pol_dist = self.policy.distribution_t(state) + old_log_p = old_pol_dist.log_prob(action)[:, None].detach() - self._update_policy(obs, act, adv, old_log_p) + self._V.fit(state, v_target, **self._critic_fit_params) + + self._update_policy(state, action, adv, old_log_p) # Print fit information - self._log_info(dataset, x, v_target, old_pol_dist) + self._log_info(dataset, state, v_target, old_pol_dist) self._iter += 1 def _update_policy(self, obs, act, adv, old_log_p): @@ -124,7 +120,7 @@ def _log_info(self, dataset, x, v_target, old_pol_dist): new_pol_dist = self.policy.distribution(x) logging_kl = torch.mean(torch.distributions.kl.kl_divergence( new_pol_dist, old_pol_dist)) - avg_rwd = np.mean(compute_J(dataset)) + avg_rwd = np.mean(dataset.undiscounted_return) msg = "Iteration {}:\n\t\t\t\trewards {} vf_loss 
{}\n\t\t\t\tentropy {} kl {}".format( self._iter, avg_rwd, logging_verr, logging_ent, logging_kl) diff --git a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/ppo_bptt.py b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/ppo_bptt.py new file mode 100644 index 000000000..5efa65448 --- /dev/null +++ b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/ppo_bptt.py @@ -0,0 +1,213 @@ +import torch + +from mushroom_rl.core import Agent +from mushroom_rl.approximators import Regressor +from mushroom_rl.approximators.parametric import TorchApproximator +from mushroom_rl.utils.torch import update_optimizer_parameters +from mushroom_rl.utils.minibatches import minibatch_generator +from mushroom_rl.utils.parameters import to_parameter +from mushroom_rl.utils.preprocessors import StandardizationPreprocessor + + +class PPO_BPTT(Agent): + """ + Proximal Policy Optimization algorithm. + "Proximal Policy Optimization Algorithms". + Schulman J. et al.. 2017. + + """ + def __init__(self, mdp_info, policy, actor_optimizer, critic_params, + n_epochs_policy, batch_size, eps_ppo, lam, dim_env_state, ent_coeff=0.0, + critic_fit_params=None, truncation_length=5): + """ + Constructor. + + Args: + policy (TorchPolicy): torch policy to be learned by the algorithm + actor_optimizer (dict): parameters to specify the actor optimizer + algorithm; + critic_params (dict): parameters of the critic approximator to + build; + n_epochs_policy ([int, Parameter]): number of policy updates for every dataset; + batch_size ([int, Parameter]): size of minibatches for every optimization step + eps_ppo ([float, Parameter]): value for probability ratio clipping; + lam ([float, Parameter], 1.): lambda coefficient used by generalized + advantage estimation; + ent_coeff ([float, Parameter], 1.): coefficient for the entropy regularization term; + critic_fit_params (dict, None): parameters of the fitting algorithm + of the critic approximator. + + """ + self._critic_fit_params = dict(n_epochs=10) if critic_fit_params is None else critic_fit_params + + self._n_epochs_policy = to_parameter(n_epochs_policy) + self._batch_size = to_parameter(batch_size) + self._eps_ppo = to_parameter(eps_ppo) + + self._optimizer = actor_optimizer['class'](policy.parameters(), **actor_optimizer['params']) + + self._lambda = to_parameter(lam) + self._ent_coeff = to_parameter(ent_coeff) + + self._V = Regressor(TorchApproximator, **critic_params) + + self._truncation_length = truncation_length + self._dim_env_state = dim_env_state + + self._iter = 1 + + self._add_save_attr( + _critic_fit_params='pickle', + _n_epochs_policy='mushroom', + _batch_size='mushroom', + _eps_ppo='mushroom', + _ent_coeff='mushroom', + _optimizer='torch', + _lambda='mushroom', + _V='mushroom', + _iter='primitive', + _dim_env_state='primitive' + ) + + super().__init__(mdp_info, policy, None) + + # add the standardization preprocessor + self._preprocessors.append(StandardizationPreprocessor(mdp_info)) + + def divide_state_to_env_hidden_batch(self, states): + assert len(states.shape) > 1, "This function only divides batches of states." 
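+        # the first self._dim_env_state entries of each row are the environment observation;
+        # the remaining entries hold the (flattened) policy hidden-state part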
+ return states[:, 0:self._dim_env_state], states[:, self._dim_env_state:] + + def fit(self, dataset): + obs, act, r, obs_next, absorbing, last = dataset.parse(to='torch') + policy_state, policy_next_state = dataset.parse_policy_state(to='torch') + obs_seq, policy_state_seq, act_seq, obs_next_seq, policy_next_state_seq, lengths = \ + self.transform_to_sequences(obs, policy_state, act, obs_next, policy_next_state, last, absorbing) + + v_target, adv = self.compute_gae(self._V, obs_seq, policy_state_seq, obs_next_seq, policy_next_state_seq, + lengths, r, absorbing, last, self.mdp_info.gamma, self._lambda()) + adv = (adv - torch.mean(adv)) / (torch.std(adv) + 1e-8) + + old_pol_dist = self.policy.distribution_t(obs_seq, policy_state_seq, lengths) + old_log_p = old_pol_dist.log_prob(act)[:, None].detach() + + self._V.fit(obs_seq, policy_state_seq, lengths, v_target, **self._critic_fit_params) + + self._update_policy(obs_seq, policy_state_seq, act, lengths, adv, old_log_p) + + # Print fit information + self._log_info(dataset, obs_seq, policy_state_seq, lengths, v_target, old_pol_dist) + self._iter += 1 + + def transform_to_sequences(self, states, policy_states, actions, next_states, policy_next_states, last, absorbing): + + s = torch.empty(len(states), self._truncation_length, states.shape[-1]) + ps = torch.empty(len(states), policy_states.shape[-1]) + a = torch.empty(len(actions), self._truncation_length, actions.shape[-1]) + ss = torch.empty(len(states), self._truncation_length, states.shape[-1]) + pss = torch.empty(len(states), policy_states.shape[-1]) + lengths = torch.empty(len(states), dtype=torch.long) + + for i in range(len(states)): + # determine the begin of a sequence + begin_seq = max(i - self._truncation_length + 1, 0) + end_seq = i + 1 + + # maybe the sequence contains more than one trajectory, so we need to cut it so that it contains only one + lasts_absorbing = last[begin_seq - 1: i].int() + absorbing[begin_seq - 1: i].int() + begin_traj = torch.where(lasts_absorbing > 0) + sequence_is_shorter_than_requested = len(*begin_traj) > 0 + if sequence_is_shorter_than_requested: + begin_seq = begin_seq + begin_traj[0][-1] + + # get the sequences + states_seq = states[begin_seq:end_seq] + actions_seq = actions[begin_seq:end_seq] + next_states_seq = next_states[begin_seq:end_seq] + + # apply padding + length_seq = len(states_seq) + padded_states = torch.concatenate([states_seq, + torch.zeros((self._truncation_length - states_seq.shape[0], + states_seq.shape[1]))]) + padded_next_states = torch.concatenate([next_states_seq, + torch.zeros((self._truncation_length - next_states_seq.shape[0], + next_states_seq.shape[1]))]) + padded_action_seq = torch.concatenate([actions_seq, + torch.zeros((self._truncation_length - actions_seq.shape[0], + actions_seq.shape[1]))]) + + s[i] = padded_states + ps[i] = policy_states[begin_seq] + a[i] = padded_action_seq + ss[i] = padded_next_states + pss[i] = policy_next_states[begin_seq] + + lengths[i] = length_seq + + return s.detach(), ps.detach(), a.detach(), ss.detach(), pss.detach(), lengths.detach() + + def _update_policy(self, obs, pi_h, act, lengths, adv, old_log_p): + for epoch in range(self._n_epochs_policy()): + for obs_i, pi_h_i, act_i, length_i, adv_i, old_log_p_i in minibatch_generator( + self._batch_size(), obs, pi_h, act, lengths, adv, old_log_p): + self._optimizer.zero_grad() + prob_ratio = torch.exp( + self.policy.log_prob_t(obs_i, act_i, pi_h_i, length_i) - old_log_p_i + ) + clipped_ratio = torch.clamp(prob_ratio, 1 - self._eps_ppo(), 1 + 
self._eps_ppo.get_value()) + loss = -torch.mean(torch.min(prob_ratio * adv_i, clipped_ratio * adv_i)) + loss -= self._ent_coeff()*self.policy.entropy_t(obs_i) + loss.backward() + self._optimizer.step() + + def _log_info(self, dataset, x, pi_h, lengths, v_target, old_pol_dist): + pass + + def _post_load(self): + if self._optimizer is not None: + update_optimizer_parameters(self._optimizer, list(self.policy.parameters())) + + @staticmethod + def compute_gae(V, s, pi_h, ss, pi_hn, lengths, r, absorbing, last, gamma, lam): + """ + Function to compute Generalized Advantage Estimation (GAE) + and new value function target over a dataset. + + "High-Dimensional Continuous Control Using Generalized + Advantage Estimation". + Schulman J. et al.. 2016. + + Args: + V (Regressor): the current value function regressor; + s (numpy.ndarray): the set of states in which we want + to evaluate the advantage; + ss (numpy.ndarray): the set of next states in which we want + to evaluate the advantage; + r (numpy.ndarray): the reward obtained in each transition + from state s to state ss; + absorbing (numpy.ndarray): an array of boolean flags indicating + if the reached state is absorbing; + last (numpy.ndarray): an array of boolean flags indicating + if the reached state is the last of the trajectory; + gamma (float): the discount factor of the considered problem; + lam (float): the value for the lamba coefficient used by GEA + algorithm. + Returns: + The new estimate for the value function of the next state + and the estimated generalized advantage. + """ + with torch.no_grad(): + v = V(s, pi_h, lengths, output_tensor=True) + v_next = V(ss, pi_hn, lengths, output_tensor=True) + gen_adv = torch.empty_like(v) + for rev_k in range(len(v)): + k = len(v) - rev_k - 1 + if last[k] or rev_k == 0: + gen_adv[k] = r[k] - v[k] + if not absorbing[k]: + gen_adv[k] += gamma * v_next[k] + else: + gen_adv[k] = r[k] + gamma * v_next[k] - v[k] + gamma * lam * gen_adv[k + 1] + + return gen_adv + v, gen_adv diff --git a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/sac.py b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/sac.py index b4236f2b3..95adbf5e3 100644 --- a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/sac.py +++ b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/sac.py @@ -22,7 +22,8 @@ class SACPolicy(Policy): compute_action_and_log_prob_t methods, that are fundamental for the internals calculations of the SAC algorithm. """ - def __init__(self, mu_approximator, sigma_approximator, min_a, max_a, log_std_min, log_std_max): + def __init__(self, mu_approximator, sigma_approximator, min_a, max_a, log_std_min, log_std_max, + policy_state_shape=None): """ Constructor. @@ -35,6 +36,8 @@ def __init__(self, mu_approximator, sigma_approximator, min_a, max_a, log_std_mi log_std_max ([float, Parameter]): max value for the policy log std. 
""" + super().__init__(policy_state_shape) + self._mu_approximator = mu_approximator self._sigma_approximator = sigma_approximator @@ -62,12 +65,11 @@ def __init__(self, mu_approximator, sigma_approximator, min_a, max_a, log_std_mi _eps_log_prob='primitive' ) - def __call__(self, state, action): + def __call__(self, state, action, internal_state=None): raise NotImplementedError - def draw_action(self, state): - return self.compute_action_and_log_prob_t( - state, compute_log_prob=False).detach().cpu().numpy() + def draw_action(self, state, internal_state=None): + return self.compute_action_and_log_prob_t(state, compute_log_prob=False).detach().cpu().numpy(), None def compute_action_and_log_prob(self, state): """ @@ -278,7 +280,7 @@ def __init__(self, mdp_info, actor_mu_params, actor_sigma_params, actor_optimize super().__init__(mdp_info, policy, actor_optimizer, policy_parameters) - def fit(self, dataset, **info): + def fit(self, dataset): self._replay_memory.add(dataset) if self._replay_memory.initialized: state, action, reward, next_state, absorbing, _ = self._replay_memory.get(self._batch_size()) diff --git a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/trpo.py b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/trpo.py index df48a9564..ded0856db 100644 --- a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/trpo.py +++ b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/trpo.py @@ -9,7 +9,6 @@ from mushroom_rl.approximators import Regressor from mushroom_rl.approximators.parametric import TorchApproximator from mushroom_rl.utils.torch import get_gradient, zero_grad, to_float_tensor -from mushroom_rl.utils.dataset import parse_dataset, compute_J from mushroom_rl.utils.value_functions import compute_gae from mushroom_rl.utils.parameters import to_parameter @@ -80,29 +79,25 @@ def __init__(self, mdp_info, policy, critic_params, ent_coeff=0., max_kl=.001, l _iter='primitive' ) - super().__init__(mdp_info, policy, None) + super().__init__(mdp_info, policy) - def fit(self, dataset, **info): - state, action, reward, next_state, absorbing, last = parse_dataset(dataset) - x = state.astype(np.float32) - u = action.astype(np.float32) - r = reward.astype(np.float32) - xn = next_state.astype(np.float32) + def fit(self, dataset): + state, action, reward, next_state, absorbing, last = dataset.parse(to='torch') - obs = to_float_tensor(x, self.policy.use_cuda) - act = to_float_tensor(u, self.policy.use_cuda) - v_target, np_adv = compute_gae(self._V, x, xn, r, absorbing, last, - self.mdp_info.gamma, self._lambda()) - np_adv = (np_adv - np.mean(np_adv)) / (np.std(np_adv) + 1e-8) - adv = to_float_tensor(np_adv, self.policy.use_cuda) + v_target, adv = compute_gae(self._V, state, next_state, reward, absorbing, last, + self.mdp_info.gamma, self._lambda()) + adv = (adv - torch.mean(adv)) / (torch.std(adv) + 1e-8) + + adv = adv.detach() + v_target = v_target.detach() # Policy update self._old_policy = deepcopy(self.policy) - old_pol_dist = self._old_policy.distribution_t(obs) - old_log_prob = self._old_policy.log_prob_t(obs, act).detach() + old_pol_dist = self._old_policy.distribution_t(state) + old_log_prob = self._old_policy.log_prob_t(state, action).detach() zero_grad(self.policy.parameters()) - loss = self._compute_loss(obs, act, adv, old_log_prob) + loss = self._compute_loss(state, action, adv, old_log_prob) prev_loss = loss.item() @@ -111,26 +106,26 @@ def fit(self, dataset, **info): g = get_gradient(self.policy.parameters()) # Compute direction through conjugate gradient - stepdir = 
self._conjugate_gradient(g, obs, old_pol_dist) + stepdir = self._conjugate_gradient(g, state, old_pol_dist) # Line search - self._line_search(obs, act, adv, old_log_prob, old_pol_dist, prev_loss, stepdir) + self._line_search(state, action, adv, old_log_prob, old_pol_dist, prev_loss, stepdir) # VF update - self._V.fit(x, v_target, **self._critic_fit_params) + self._V.fit(state, v_target, **self._critic_fit_params) # Print fit information - self._log_info(dataset, x, v_target, old_pol_dist) + self._log_info(dataset, state, v_target, old_pol_dist) self._iter += 1 - def _fisher_vector_product(self, p, obs, old_pol_dist): - p_tensor = torch.from_numpy(p) - if self.policy.use_cuda: - p_tensor = p_tensor.cuda() - - return self._fisher_vector_product_t(p_tensor, obs, old_pol_dist) + # def _fisher_vector_product(self, p, obs, old_pol_dist): + # p_tensor = torch.from_numpy(p) + # if self.policy.use_cuda: + # p_tensor = p_tensor.cuda() + # + # return self._fisher_vector_product_t(p_tensor, obs, old_pol_dist) - def _fisher_vector_product_t(self, p, obs, old_pol_dist): + def _fisher_vector_product(self, p, obs, old_pol_dist): kl = self._compute_kl(obs, old_pol_dist) grads = torch.autograd.grad(kl, self.policy.parameters(), create_graph=True) flat_grad_kl = torch.cat([grad.view(-1) for grad in grads]) @@ -142,13 +137,13 @@ def _fisher_vector_product_t(self, p, obs, old_pol_dist): return flat_grad_grad_kl + p * self._cg_damping() def _conjugate_gradient(self, b, obs, old_pol_dist): - p = b.detach().cpu().numpy() - r = b.detach().cpu().numpy() - x = np.zeros_like(p) + p = b.detach() + r = b.detach() + x = torch.zeros_like(p) r2 = r.dot(r) for i in range(self._n_epochs_cg()): - z = self._fisher_vector_product(p, obs, old_pol_dist).detach().cpu().numpy() + z = self._fisher_vector_product(p, obs, old_pol_dist).detach() v = r2 / p.dot(z) x += v * p r -= v * z @@ -163,10 +158,10 @@ def _conjugate_gradient(self, b, obs, old_pol_dist): def _line_search(self, obs, act, adv, old_log_prob, old_pol_dist, prev_loss, stepdir): # Compute optimal step size - direction = self._fisher_vector_product(stepdir, obs, old_pol_dist).detach().cpu().numpy() + direction = self._fisher_vector_product(stepdir, obs, old_pol_dist).detach() shs = .5 * stepdir.dot(direction) - lm = np.sqrt(shs / self._max_kl()) - full_step = stepdir / lm + lm = torch.sqrt(shs / self._max_kl()) + full_step = (stepdir / lm).detach().cpu().numpy() stepsize = 1. # Save old policy parameters @@ -214,7 +209,7 @@ def _log_info(self, dataset, x, v_target, old_pol_dist): logging_kl = torch.mean( torch.distributions.kl.kl_divergence(old_pol_dist, new_pol_dist) ) - avg_rwd = np.mean(compute_J(dataset)) + avg_rwd = np.mean(dataset.undiscounted_return) msg = "Iteration {}:\n\t\t\t\trewards {} vf_loss {}\n\t\t\t\tentropy {} kl {}".format( self._iter, avg_rwd, logging_verr, logging_ent, logging_kl) diff --git a/mushroom_rl/algorithms/policy_search/black_box_optimization/black_box_optimization.py b/mushroom_rl/algorithms/policy_search/black_box_optimization/black_box_optimization.py index 7e2e288c1..0ea85988a 100644 --- a/mushroom_rl/algorithms/policy_search/black_box_optimization/black_box_optimization.py +++ b/mushroom_rl/algorithms/policy_search/black_box_optimization/black_box_optimization.py @@ -1,7 +1,6 @@ import numpy as np from mushroom_rl.core import Agent -from mushroom_rl.utils.dataset import compute_J class BlackBoxOptimization(Agent): @@ -11,7 +10,7 @@ class BlackBoxOptimization(Agent): do not rely on stochastic and differentiable policies. 
""" - def __init__(self, mdp_info, distribution, policy, features=None): + def __init__(self, mdp_info, distribution, policy): """ Constructor. @@ -21,32 +20,25 @@ def __init__(self, mdp_info, distribution, policy, features=None): """ self.distribution = distribution - self._theta_list = list() - self._add_save_attr(distribution='mushroom', _theta_list='pickle') + self._add_save_attr(distribution='mushroom') - super().__init__(mdp_info, policy, features) + super().__init__(mdp_info, policy, is_episodic=True) - def episode_start(self): + def episode_start(self, episode_info): theta = self.distribution.sample() - self._theta_list.append(theta) self.policy.set_weights(theta) - super().episode_start() + policy_state, _ = super().episode_start(episode_info) - def fit(self, dataset, **info): - Jep = compute_J(dataset, self.mdp_info.gamma) + return policy_state, theta - Jep = np.array(Jep) - theta = np.array(self._theta_list) + def fit(self, dataset): + Jep = np.array(dataset.discounted_return) + theta = np.array(dataset.theta_list) self._update(Jep, theta) - self._theta_list = list() - - def stop(self): - self._theta_list = list() - def _update(self, Jep, theta): """ Function that implements the update routine of distribution parameters. diff --git a/mushroom_rl/algorithms/policy_search/black_box_optimization/constrained_reps.py b/mushroom_rl/algorithms/policy_search/black_box_optimization/constrained_reps.py index a3cb7cf2b..e15771546 100644 --- a/mushroom_rl/algorithms/policy_search/black_box_optimization/constrained_reps.py +++ b/mushroom_rl/algorithms/policy_search/black_box_optimization/constrained_reps.py @@ -9,7 +9,7 @@ class ConstrainedREPS(BlackBoxOptimization): Episodic Relative Entropy Policy Search algorithm with constrained policy update. """ - def __init__(self, mdp_info, distribution, policy, eps, kappa, features=None): + def __init__(self, mdp_info, distribution, policy, eps, kappa): """ Constructor. @@ -28,7 +28,7 @@ def __init__(self, mdp_info, distribution, policy, eps, kappa, features=None): self._add_save_attr(_eps='mushroom') self._add_save_attr(_kappa='mushroom') - super().__init__(mdp_info, distribution, policy, features) + super().__init__(mdp_info, distribution, policy) def _update(self, Jep, theta): eta_start = np.ones(1) diff --git a/mushroom_rl/algorithms/policy_search/black_box_optimization/more.py b/mushroom_rl/algorithms/policy_search/black_box_optimization/more.py index 50eaf8c5e..5c7517cb1 100644 --- a/mushroom_rl/algorithms/policy_search/black_box_optimization/more.py +++ b/mushroom_rl/algorithms/policy_search/black_box_optimization/more.py @@ -17,7 +17,7 @@ class MORE(BlackBoxOptimization): Peters, Jan R and Lau, Nuno and Pualo Reis, Luis and Neumann, Gerhard. 2015. """ - def __init__(self, mdp_info, distribution, policy, eps, h0=-75, kappa=0.99, features=None): + def __init__(self, mdp_info, distribution, policy, eps, h0=-75, kappa=0.99): """ Constructor. 
@@ -53,7 +53,7 @@ def __init__(self, mdp_info, distribution, policy, eps, h0=-75, kappa=0.99, feat self._add_save_attr(h0='primitive') self._add_save_attr(kappa='primitive') - super().__init__(mdp_info, distribution, policy, features) + super().__init__(mdp_info, distribution, policy) def _update(self, Jep, theta): diff --git a/mushroom_rl/algorithms/policy_search/black_box_optimization/pgpe.py b/mushroom_rl/algorithms/policy_search/black_box_optimization/pgpe.py index 9f504fca9..a5fde15e1 100644 --- a/mushroom_rl/algorithms/policy_search/black_box_optimization/pgpe.py +++ b/mushroom_rl/algorithms/policy_search/black_box_optimization/pgpe.py @@ -10,8 +10,7 @@ class PGPE(BlackBoxOptimization): Peters J.. 2013. """ - def __init__(self, mdp_info, distribution, policy, optimizer, - features=None): + def __init__(self, mdp_info, distribution, policy, optimizer): """ Constructor. @@ -23,7 +22,7 @@ def __init__(self, mdp_info, distribution, policy, optimizer, self._add_save_attr(optimizer='mushroom') - super().__init__(mdp_info, distribution, policy, features) + super().__init__(mdp_info, distribution, policy) def _update(self, Jep, theta): baseline_num_list = list() diff --git a/mushroom_rl/algorithms/policy_search/black_box_optimization/reps.py b/mushroom_rl/algorithms/policy_search/black_box_optimization/reps.py index 792a261c3..dd91093c3 100644 --- a/mushroom_rl/algorithms/policy_search/black_box_optimization/reps.py +++ b/mushroom_rl/algorithms/policy_search/black_box_optimization/reps.py @@ -13,7 +13,7 @@ class REPS(BlackBoxOptimization): Peters J.. 2013. """ - def __init__(self, mdp_info, distribution, policy, eps, features=None): + def __init__(self, mdp_info, distribution, policy, eps): """ Constructor. @@ -27,7 +27,7 @@ def __init__(self, mdp_info, distribution, policy, eps, features=None): self._add_save_attr(_eps='mushroom') - super().__init__(mdp_info, distribution, policy, features) + super().__init__(mdp_info, distribution, policy) def _update(self, Jep, theta): eta_start = np.ones(1) diff --git a/mushroom_rl/algorithms/policy_search/black_box_optimization/rwr.py b/mushroom_rl/algorithms/policy_search/black_box_optimization/rwr.py index e751d62a3..719105c9c 100644 --- a/mushroom_rl/algorithms/policy_search/black_box_optimization/rwr.py +++ b/mushroom_rl/algorithms/policy_search/black_box_optimization/rwr.py @@ -11,7 +11,7 @@ class RWR(BlackBoxOptimization): Peters J.. 2013. """ - def __init__(self, mdp_info, distribution, policy, beta, features=None): + def __init__(self, mdp_info, distribution, policy, beta): """ Constructor. @@ -24,7 +24,7 @@ def __init__(self, mdp_info, distribution, policy, beta, features=None): self._add_save_attr(_beta='mushroom') - super().__init__(mdp_info, distribution, policy, features) + super().__init__(mdp_info, distribution, policy) def _update(self, Jep, theta): Jep -= np.max(Jep) diff --git a/mushroom_rl/algorithms/policy_search/policy_gradient/enac.py b/mushroom_rl/algorithms/policy_search/policy_gradient/enac.py index 24d58573f..a7af25842 100644 --- a/mushroom_rl/algorithms/policy_search/policy_gradient/enac.py +++ b/mushroom_rl/algorithms/policy_search/policy_gradient/enac.py @@ -10,8 +10,7 @@ class eNAC(PolicyGradient): Peters J. 2013. """ - def __init__(self, mdp_info, policy, optimizer, features=None, - critic_features=None): + def __init__(self, mdp_info, policy, optimizer, critic_features=None): """ Constructor. 
@@ -19,7 +18,7 @@ def __init__(self, mdp_info, policy, optimizer, features=None, critic_features (Features, None): features used by the critic. """ - super().__init__(mdp_info, policy, optimizer, features) + super().__init__(mdp_info, policy, optimizer) self.phi_c = critic_features self.sum_grad_log = None diff --git a/mushroom_rl/algorithms/policy_search/policy_gradient/gpomdp.py b/mushroom_rl/algorithms/policy_search/policy_gradient/gpomdp.py index c0a68eb46..bf2acd6e6 100644 --- a/mushroom_rl/algorithms/policy_search/policy_gradient/gpomdp.py +++ b/mushroom_rl/algorithms/policy_search/policy_gradient/gpomdp.py @@ -10,8 +10,8 @@ class GPOMDP(PolicyGradient): 2001. """ - def __init__(self, mdp_info, policy, optimizer, features=None): - super().__init__(mdp_info, policy, optimizer, features) + def __init__(self, mdp_info, policy, optimizer): + super().__init__(mdp_info, policy, optimizer) self.sum_d_log_pi = None self.list_sum_d_log_pi = list() diff --git a/mushroom_rl/algorithms/policy_search/policy_gradient/policy_gradient.py b/mushroom_rl/algorithms/policy_search/policy_gradient/policy_gradient.py index e7b85d05c..6a4afef3f 100644 --- a/mushroom_rl/algorithms/policy_search/policy_gradient/policy_gradient.py +++ b/mushroom_rl/algorithms/policy_search/policy_gradient/policy_gradient.py @@ -11,7 +11,7 @@ class PolicyGradient(Agent): al.. 2011. """ - def __init__(self, mdp_info, policy, optimizer, features): + def __init__(self, mdp_info, policy, optimizer): """ Constructor. @@ -29,9 +29,9 @@ def __init__(self, mdp_info, policy, optimizer, features): J_episode='numpy' ) - super().__init__(mdp_info, policy, features) + super().__init__(mdp_info, policy) - def fit(self, dataset, **info): + def fit(self, dataset): J = list() self.df = 1. self.J_episode = 0. @@ -133,7 +133,4 @@ def _parse(self, sample): absorbing = sample[4] last = sample[5] - if self.phi is not None: - state = self.phi(state) - return state, action, reward, next_state, absorbing, last diff --git a/mushroom_rl/algorithms/policy_search/policy_gradient/reinforce.py b/mushroom_rl/algorithms/policy_search/policy_gradient/reinforce.py index 15d4db55f..f241f0be3 100644 --- a/mushroom_rl/algorithms/policy_search/policy_gradient/reinforce.py +++ b/mushroom_rl/algorithms/policy_search/policy_gradient/reinforce.py @@ -10,8 +10,8 @@ class REINFORCE(PolicyGradient): Reinforcement Learning", Williams R. J.. 1992. """ - def __init__(self, mdp_info, policy, optimizer, features=None): - super().__init__(mdp_info, policy, optimizer, features) + def __init__(self, mdp_info, policy, optimizer): + super().__init__(mdp_info, policy, optimizer) self.sum_d_log_pi = None self.list_sum_d_log_pi = list() self.baseline_num = list() diff --git a/mushroom_rl/algorithms/value/batch_td/batch_td.py b/mushroom_rl/algorithms/value/batch_td/batch_td.py index 34c313bbb..5178ddc89 100644 --- a/mushroom_rl/algorithms/value/batch_td/batch_td.py +++ b/mushroom_rl/algorithms/value/batch_td/batch_td.py @@ -7,8 +7,7 @@ class BatchTD(Agent): Abstract class to implement a generic Batch TD algorithm. """ - def __init__(self, mdp_info, policy, approximator, approximator_params=None, - fit_params=None, features=None): + def __init__(self, mdp_info, policy, approximator, approximator_params=None, fit_params=None): """ Constructor. 
@@ -33,7 +32,7 @@ def __init__(self, mdp_info, policy, approximator, approximator_params=None, _fit_params='pickle' ) - super().__init__(mdp_info, policy, features) + super().__init__(mdp_info, policy) def _post_load(self): self.policy.set_q(self.approximator) diff --git a/mushroom_rl/algorithms/value/batch_td/boosted_fqi.py b/mushroom_rl/algorithms/value/batch_td/boosted_fqi.py index ad12da032..9c75ff757 100644 --- a/mushroom_rl/algorithms/value/batch_td/boosted_fqi.py +++ b/mushroom_rl/algorithms/value/batch_td/boosted_fqi.py @@ -1,8 +1,6 @@ import numpy as np from tqdm import trange -from mushroom_rl.utils.dataset import parse_dataset - from .fqi import FQI @@ -31,8 +29,8 @@ def __init__(self, mdp_info, policy, approximator, n_iterations, super().__init__(mdp_info, policy, approximator, n_iterations, approximator_params, fit_params, quiet) - def fit(self, dataset, **info): - state, action, reward, next_state, absorbing, _ = parse_dataset(dataset) + def fit(self, dataset): + state, action, reward, next_state, absorbing, _ = dataset.parse() for _ in trange(self._n_iterations(), dynamic_ncols=True, disable=self._quiet, leave=False): if self._target is None: self._target = reward diff --git a/mushroom_rl/algorithms/value/batch_td/double_fqi.py b/mushroom_rl/algorithms/value/batch_td/double_fqi.py index f133b80be..ed0354988 100644 --- a/mushroom_rl/algorithms/value/batch_td/double_fqi.py +++ b/mushroom_rl/algorithms/value/batch_td/double_fqi.py @@ -1,8 +1,6 @@ import numpy as np from tqdm import trange -from mushroom_rl.utils.dataset import parse_dataset - from .fqi import FQI @@ -20,7 +18,7 @@ def __init__(self, mdp_info, policy, approximator, n_iterations, super().__init__(mdp_info, policy, approximator, n_iterations, approximator_params, fit_params, quiet) - def fit(self, dataset, **info): + def fit(self, dataset): for _ in trange(self._n_iterations(), dynamic_ncols=True, disable=self._quiet, leave=False): state = list() action = list() @@ -30,7 +28,7 @@ def fit(self, dataset, **info): half = len(dataset) // 2 for i in range(2): - s, a, r, ss, ab, _ = parse_dataset(dataset[i * half:(i + 1) * half]) + s, a, r, ss, ab, _ = dataset[i * half:(i + 1) * half].parse() state.append(s) action.append(a) reward.append(r) diff --git a/mushroom_rl/algorithms/value/batch_td/fqi.py b/mushroom_rl/algorithms/value/batch_td/fqi.py index 97fa2316f..d52307961 100644 --- a/mushroom_rl/algorithms/value/batch_td/fqi.py +++ b/mushroom_rl/algorithms/value/batch_td/fqi.py @@ -2,7 +2,6 @@ from tqdm import trange from mushroom_rl.algorithms.value.batch_td import BatchTD -from mushroom_rl.utils.dataset import parse_dataset from mushroom_rl.utils.parameters import to_parameter @@ -34,8 +33,8 @@ def __init__(self, mdp_info, policy, approximator, n_iterations, super().__init__(mdp_info, policy, approximator, approximator_params, fit_params) - def fit(self, dataset, **info): - state, action, reward, next_state, absorbing, _ = parse_dataset(dataset) + def fit(self, dataset): + state, action, reward, next_state, absorbing, _ = dataset.parse() for _ in trange(self._n_iterations(), dynamic_ncols=True, disable=self._quiet, leave=False): if self._target is None: self._target = reward diff --git a/mushroom_rl/algorithms/value/batch_td/lspi.py b/mushroom_rl/algorithms/value/batch_td/lspi.py index 7e28cb816..0e380d79e 100644 --- a/mushroom_rl/algorithms/value/batch_td/lspi.py +++ b/mushroom_rl/algorithms/value/batch_td/lspi.py @@ -3,7 +3,6 @@ from mushroom_rl.algorithms.value.batch_td import BatchTD from 
mushroom_rl.approximators.parametric import LinearApproximator from mushroom_rl.features import get_action_features -from mushroom_rl.utils.dataset import parse_dataset from mushroom_rl.utils.parameters import to_parameter @@ -13,8 +12,7 @@ class LSPI(BatchTD): "Least-Squares Policy Iteration". Lagoudakis M. G. and Parr R.. 2003. """ - def __init__(self, mdp_info, policy, approximator_params=None, - epsilon=1e-2, fit_params=None, features=None): + def __init__(self, mdp_info, policy, approximator_params=None, epsilon=1e-2, fit_params=None): """ Constructor. @@ -26,30 +24,26 @@ def __init__(self, mdp_info, policy, approximator_params=None, self._add_save_attr(_epsilon='mushroom') - super().__init__(mdp_info, policy, LinearApproximator, - approximator_params, fit_params, features) + super().__init__(mdp_info, policy, LinearApproximator, approximator_params, fit_params) - def fit(self, dataset, **info): - phi_state, action, reward, phi_next_state, absorbing, _ = parse_dataset( - dataset, self.phi) - phi_state_action = get_action_features(phi_state, action, - self.mdp_info.action_space.n) + def fit(self, dataset): + state, action, reward, next_state, absorbing, _ = dataset.parse() + + phi_state = self.approximator.model.phi(state) + phi_next_state = self.approximator.model.phi(next_state) + + phi_state_action = get_action_features(phi_state, action, self.mdp_info.action_space.n) norm = np.inf while norm > self._epsilon(): - q = self.approximator.predict(phi_next_state) + q = self.approximator.predict(next_state) if np.any(absorbing): q *= 1 - absorbing.reshape(-1, 1) next_action = np.argmax(q, axis=1).reshape(-1, 1) - phi_next_state_next_action = get_action_features( - phi_next_state, - next_action, - self.mdp_info.action_space.n - ) - - tmp = phi_state_action - self.mdp_info.gamma *\ - phi_next_state_next_action + phi_next_state_next_action = get_action_features(phi_next_state, next_action, self.mdp_info.action_space.n) + + tmp = phi_state_action - self.mdp_info.gamma * phi_next_state_next_action A = phi_state_action.T.dot(tmp) b = (phi_state_action.T.dot(reward)).reshape(-1, 1) diff --git a/mushroom_rl/algorithms/value/dqn/abstract_dqn.py b/mushroom_rl/algorithms/value/dqn/abstract_dqn.py index c38158cb0..411eac7d7 100644 --- a/mushroom_rl/algorithms/value/dqn/abstract_dqn.py +++ b/mushroom_rl/algorithms/value/dqn/abstract_dqn.py @@ -81,7 +81,7 @@ def __init__(self, mdp_info, policy, approximator, approximator_params, super().__init__(mdp_info, policy) - def fit(self, dataset, **info): + def fit(self, dataset): self._fit(dataset) self._n_updates += 1 @@ -121,16 +121,9 @@ def _fit_prioritized(self, dataset): self.approximator.fit(state, action, q, weights=is_weight, **self._fit_params) - def draw_action(self, state): - action = super().draw_action(np.array(state)) - - return action - - def _initialize_regressors(self, approximator, apprx_params_train, - apprx_params_target): + def _initialize_regressors(self, approximator, apprx_params_train, apprx_params_target): self.approximator = Regressor(approximator, **apprx_params_train) - self.target_approximator = Regressor(approximator, - **apprx_params_target) + self.target_approximator = Regressor(approximator, **apprx_params_target) self._update_target() def _update_target(self): diff --git a/mushroom_rl/algorithms/value/dqn/categorical_dqn.py b/mushroom_rl/algorithms/value/dqn/categorical_dqn.py index daaea344a..7dee67c1c 100644 --- a/mushroom_rl/algorithms/value/dqn/categorical_dqn.py +++ b/mushroom_rl/algorithms/value/dqn/categorical_dqn.py 
@@ -113,7 +113,7 @@ def __init__(self, mdp_info, policy, approximator_params, n_atoms, v_min, super().__init__(mdp_info, policy, TorchApproximator, **params) - def fit(self, dataset, **info): + def fit(self, dataset): self._replay_memory.add(dataset) if self._replay_memory.initialized: state, action, reward, next_state, absorbing, _ =\ diff --git a/mushroom_rl/algorithms/value/dqn/maxmin_dqn.py b/mushroom_rl/algorithms/value/dqn/maxmin_dqn.py index 144c64d91..8f202eef9 100644 --- a/mushroom_rl/algorithms/value/dqn/maxmin_dqn.py +++ b/mushroom_rl/algorithms/value/dqn/maxmin_dqn.py @@ -11,8 +11,7 @@ class MaxminDQN(DQN): Lan Q. et al.. 2020. """ - def __init__(self, mdp_info, policy, approximator, n_approximators, - **params): + def __init__(self, mdp_info, policy, approximator, n_approximators, **params): """ Constructor. @@ -26,17 +25,15 @@ def __init__(self, mdp_info, policy, approximator, n_approximators, super().__init__(mdp_info, policy, approximator, **params) - def fit(self, dataset, **info): + def fit(self, dataset): self._fit_params['idx'] = np.random.randint(self._n_approximators) - super().fit(dataset, **info) + super().fit(dataset) - def _initialize_regressors(self, approximator, apprx_params_train, - apprx_params_target): + def _initialize_regressors(self, approximator, apprx_params_train, apprx_params_target): self.approximator = Regressor(approximator, n_models=self._n_approximators, - prediction='min', - **apprx_params_train) + prediction='min', **apprx_params_train) self.target_approximator = Regressor(approximator, n_models=self._n_approximators, prediction='min', @@ -45,5 +42,4 @@ def _initialize_regressors(self, approximator, apprx_params_train, def _update_target(self): for i in range(len(self.target_approximator)): - self.target_approximator[i].set_weights( - self.approximator[i].get_weights()) + self.target_approximator[i].set_weights(self.approximator[i].get_weights()) diff --git a/mushroom_rl/algorithms/value/dqn/quantile_dqn.py b/mushroom_rl/algorithms/value/dqn/quantile_dqn.py index df60e9767..32a85b0c4 100644 --- a/mushroom_rl/algorithms/value/dqn/quantile_dqn.py +++ b/mushroom_rl/algorithms/value/dqn/quantile_dqn.py @@ -97,7 +97,7 @@ def __init__(self, mdp_info, policy, approximator_params, n_quantiles, **params) super().__init__(mdp_info, policy, TorchApproximator, **params) - def fit(self, dataset, **info): + def fit(self, dataset): self._replay_memory.add(dataset) if self._replay_memory.initialized: state, action, reward, next_state, absorbing, _ =\ diff --git a/mushroom_rl/algorithms/value/dqn/rainbow.py b/mushroom_rl/algorithms/value/dqn/rainbow.py index f167bae98..f058503fb 100644 --- a/mushroom_rl/algorithms/value/dqn/rainbow.py +++ b/mushroom_rl/algorithms/value/dqn/rainbow.py @@ -119,7 +119,7 @@ def __init__(self, mdp_info, policy, approximator_params, n_atoms, v_min, super().__init__(mdp_info, policy, TorchApproximator, **params) - def fit(self, dataset, **info): + def fit(self, dataset): self._replay_memory.add(dataset, np.ones(len(dataset)) * self._replay_memory.max_priority, n_steps_return=self._n_steps_return, gamma=self.mdp_info.gamma) if self._replay_memory.initialized: diff --git a/mushroom_rl/algorithms/value/td/q_lambda.py b/mushroom_rl/algorithms/value/td/q_lambda.py index 2181fb557..c8e6cd32e 100644 --- a/mushroom_rl/algorithms/value/td/q_lambda.py +++ b/mushroom_rl/algorithms/value/td/q_lambda.py @@ -44,7 +44,7 @@ def _update(self, state, action, reward, next_state, absorbing): self.Q.table += self._alpha(state, action) * delta * 
self.e.table self.e.table *= self.mdp_info.gamma * self._lambda() - def episode_start(self): + def episode_start(self, episode_info): self.e.reset() - super().episode_start() + return super().episode_start(episode_info) diff --git a/mushroom_rl/algorithms/value/td/rq_learning.py b/mushroom_rl/algorithms/value/td/rq_learning.py index 4a7da860a..0a986e7f2 100644 --- a/mushroom_rl/algorithms/value/td/rq_learning.py +++ b/mushroom_rl/algorithms/value/td/rq_learning.py @@ -57,16 +57,13 @@ def _update(self, state, action, reward, next_state, absorbing): q_next = self._next_q(next_state) if self.delta is not None: - beta = alpha * self.delta(state, action, target=q_next, - factor=alpha) + beta = alpha * self.delta(state, action, target=q_next, factor=alpha) else: beta = self.beta(state, action, target=q_next) - self.Q_tilde[state, action] += beta * (q_next - self.Q_tilde[ - state, action]) + self.Q_tilde[state, action] += beta * (q_next - self.Q_tilde[state, action]) - self.Q[state, action] = self.R_tilde[ - state, action] + self.mdp_info.gamma * self.Q_tilde[state, action] + self.Q[state, action] = self.R_tilde[state, action] + self.mdp_info.gamma * self.Q_tilde[state, action] def _next_q(self, next_state): """ @@ -81,6 +78,6 @@ def _next_q(self, next_state): if self.off_policy: return np.max(self.Q[next_state, :]) else: - self.next_action = self.draw_action(next_state) + self.next_action, _ = self.draw_action(next_state) return self.Q[next_state, self.next_action] diff --git a/mushroom_rl/algorithms/value/td/sarsa.py b/mushroom_rl/algorithms/value/td/sarsa.py index 253dac401..e376e0e63 100644 --- a/mushroom_rl/algorithms/value/td/sarsa.py +++ b/mushroom_rl/algorithms/value/td/sarsa.py @@ -15,7 +15,7 @@ def __init__(self, mdp_info, policy, learning_rate): def _update(self, state, action, reward, next_state, absorbing): q_current = self.Q[state, action] - self.next_action = self.draw_action(next_state) + self.next_action, _ = self.draw_action(next_state) q_next = self.Q[next_state, self.next_action] if not absorbing else 0. self.Q[state, action] = q_current + self._alpha(state, action) * ( diff --git a/mushroom_rl/algorithms/value/td/sarsa_lambda.py b/mushroom_rl/algorithms/value/td/sarsa_lambda.py index 0f5a4999c..7e82d29a0 100644 --- a/mushroom_rl/algorithms/value/td/sarsa_lambda.py +++ b/mushroom_rl/algorithms/value/td/sarsa_lambda.py @@ -9,8 +9,7 @@ class SARSALambda(TD): The SARSA(lambda) algorithm for finite MDPs. """ - def __init__(self, mdp_info, policy, learning_rate, lambda_coeff, - trace='replacing'): + def __init__(self, mdp_info, policy, learning_rate, lambda_coeff, trace='replacing'): """ Constructor. @@ -33,7 +32,7 @@ def __init__(self, mdp_info, policy, learning_rate, lambda_coeff, def _update(self, state, action, reward, next_state, absorbing): q_current = self.Q[state, action] - self.next_action = self.draw_action(next_state) + self.next_action, _ = self.draw_action(next_state) q_next = self.Q[next_state, self.next_action] if not absorbing else 0. 
delta = reward + self.mdp_info.gamma * q_next - q_current @@ -42,7 +41,7 @@ def _update(self, state, action, reward, next_state, absorbing): self.Q.table += self._alpha(state, action) * delta * self.e.table self.e.table *= self.mdp_info.gamma * self._lambda() - def episode_start(self): + def episode_start(self, episode_info): self.e.reset() - super().episode_start() + return super().episode_start(episode_info) diff --git a/mushroom_rl/algorithms/value/td/sarsa_lambda_continuous.py b/mushroom_rl/algorithms/value/td/sarsa_lambda_continuous.py index 1c434bce1..c6e7b7aa4 100644 --- a/mushroom_rl/algorithms/value/td/sarsa_lambda_continuous.py +++ b/mushroom_rl/algorithms/value/td/sarsa_lambda_continuous.py @@ -10,8 +10,7 @@ class SARSALambdaContinuous(TD): Continuous version of SARSA(lambda) algorithm. """ - def __init__(self, mdp_info, policy, approximator, learning_rate, - lambda_coeff, features, approximator_params=None): + def __init__(self, mdp_info, policy, approximator, learning_rate, lambda_coeff, approximator_params=None): """ Constructor. @@ -19,8 +18,7 @@ def __init__(self, mdp_info, policy, approximator, learning_rate, lambda_coeff ([float, Parameter]): eligibility trace coefficient. """ - approximator_params = dict() if approximator_params is None else \ - approximator_params + approximator_params = dict() if approximator_params is None else approximator_params Q = Regressor(approximator, **approximator_params) self.e = np.zeros(Q.weights_size) @@ -31,21 +29,17 @@ def __init__(self, mdp_info, policy, approximator, learning_rate, e='numpy' ) - super().__init__(mdp_info, policy, Q, learning_rate, features) + super().__init__(mdp_info, policy, Q, learning_rate) def _update(self, state, action, reward, next_state, absorbing): - phi_state = self.phi(state) - q_current = self.Q.predict(phi_state, action) + q_current = self.Q.predict(state, action) alpha = self._alpha(state, action) - self.e = self.mdp_info.gamma * self._lambda() * self.e + self.Q.diff( - phi_state, action) + self.e = self.mdp_info.gamma * self._lambda() * self.e + self.Q.diff(state, action) - self.next_action = self.draw_action(next_state) - phi_next_state = self.phi(next_state) - q_next = self.Q.predict(phi_next_state, - self.next_action) if not absorbing else 0. + self.next_action, _ = self.draw_action(next_state) + q_next = self.Q.predict(next_state, self.next_action) if not absorbing else 0. delta = reward + self.mdp_info.gamma * q_next - q_current @@ -53,7 +47,7 @@ def _update(self, state, action, reward, next_state, absorbing): theta += alpha * delta * self.e self.Q.set_weights(theta) - def episode_start(self): + def episode_start(self, episode_info): self.e = np.zeros(self.Q.weights_size) - super().episode_start() + return super().episode_start(episode_info) diff --git a/mushroom_rl/algorithms/value/td/td.py b/mushroom_rl/algorithms/value/td/td.py index 29dd92b48..bf87ae2b0 100644 --- a/mushroom_rl/algorithms/value/td/td.py +++ b/mushroom_rl/algorithms/value/td/td.py @@ -8,14 +8,12 @@ class TD(Agent): Implements functions to run TD algorithms. """ - def __init__(self, mdp_info, policy, approximator, learning_rate, - features=None): + def __init__(self, mdp_info, policy, approximator, learning_rate): """ Constructor. Args: - approximator (object): the approximator to use to fit the - Q-function; + approximator: the approximator to use to fit the Q-function; learning_rate (Parameter): the learning rate. 
""" @@ -26,36 +24,14 @@ def __init__(self, mdp_info, policy, approximator, learning_rate, self._add_save_attr(_alpha='mushroom', Q='mushroom') - super().__init__(mdp_info, policy, features) + super().__init__(mdp_info, policy) - def fit(self, dataset, **info): + def fit(self, dataset): assert len(dataset) == 1 - state, action, reward, next_state, absorbing = self._parse(dataset) + state, action, reward, next_state, absorbing, _ = dataset.item() self._update(state, action, reward, next_state, absorbing) - @staticmethod - def _parse(dataset): - """ - Utility to parse the dataset that is supposed to contain only a sample. - - Args: - dataset (list): the current episode step. - - Returns: - A tuple containing state, action, reward, next state, absorbing and - last flag. - - """ - sample = dataset[0] - state = sample[0] - action = sample[1] - reward = sample[2] - next_state = sample[3] - absorbing = sample[4] - - return state, action, reward, next_state, absorbing - def _update(self, state, action, reward, next_state, absorbing): """ Update the Q-table. diff --git a/mushroom_rl/algorithms/value/td/true_online_sarsa_lambda.py b/mushroom_rl/algorithms/value/td/true_online_sarsa_lambda.py index 8ea545573..196765c3e 100644 --- a/mushroom_rl/algorithms/value/td/true_online_sarsa_lambda.py +++ b/mushroom_rl/algorithms/value/td/true_online_sarsa_lambda.py @@ -13,8 +13,7 @@ class TrueOnlineSARSALambda(TD): "True Online TD(lambda)". Seijen H. V. et al.. 2014. """ - def __init__(self, mdp_info, policy, learning_rate, lambda_coeff, - features, approximator_params=None): + def __init__(self, mdp_info, policy, learning_rate, lambda_coeff, approximator_params=None): """ Constructor. @@ -22,8 +21,7 @@ def __init__(self, mdp_info, policy, learning_rate, lambda_coeff, lambda_coeff ([float, Parameter]): eligibility trace coefficient. """ - approximator_params = dict() if approximator_params is None else \ - approximator_params + approximator_params = dict() if approximator_params is None else approximator_params Q = Regressor(LinearApproximator, **approximator_params) self.e = np.zeros(Q.weights_size) @@ -36,13 +34,12 @@ def __init__(self, mdp_info, policy, learning_rate, lambda_coeff, e='numpy' ) - super().__init__(mdp_info, policy, Q, learning_rate, features) + super().__init__(mdp_info, policy, Q, learning_rate) def _update(self, state, action, reward, next_state, absorbing): - phi_state = self.phi(state) - phi_state_action = get_action_features(phi_state, action, - self.mdp_info.action_space.n) - q_current = self.Q.predict(phi_state, action) + phi_state = self.Q.model.phi(state) + phi_state_action = get_action_features(phi_state, action, self.mdp_info.action_space.n) + q_current = self.Q.predict(state, action) if self._q_old is None: self._q_old = q_current @@ -50,25 +47,22 @@ def _update(self, state, action, reward, next_state, absorbing): alpha = self._alpha(state, action) e_phi = self.e.dot(phi_state_action) - self.e = self.mdp_info.gamma * self._lambda() * self.e + alpha * ( - 1. - self.mdp_info.gamma * self._lambda.get_value() * e_phi) * phi_state_action + self.e = (self.mdp_info.gamma * self._lambda() * self.e + + alpha * (1. - self.mdp_info.gamma * self._lambda.get_value() * e_phi) * phi_state_action) - self.next_action = self.draw_action(next_state) - phi_next_state = self.phi(next_state) - q_next = self.Q.predict(phi_next_state, - self.next_action) if not absorbing else 0. 
+ self.next_action, _ = self.draw_action(next_state) + q_next = self.Q.predict(next_state, self.next_action) if not absorbing else 0. delta = reward + self.mdp_info.gamma * q_next - self._q_old theta = self.Q.get_weights() - theta += delta * self.e + alpha * ( - self._q_old - q_current) * phi_state_action + theta += delta * self.e + alpha * (self._q_old - q_current) * phi_state_action self.Q.set_weights(theta) self._q_old = q_next - def episode_start(self): + def episode_start(self, episode_info): self._q_old = None self.e = np.zeros(self.Q.weights_size) - super().episode_start() + return super().episode_start(episode_info) diff --git a/mushroom_rl/algorithms/value/td/weighted_q_learning.py b/mushroom_rl/algorithms/value/td/weighted_q_learning.py index b6cc15463..bae01dbc7 100644 --- a/mushroom_rl/algorithms/value/td/weighted_q_learning.py +++ b/mushroom_rl/algorithms/value/td/weighted_q_learning.py @@ -56,14 +56,11 @@ def _update(self, state, action, reward, next_state, absorbing): alpha = self._alpha(state, action) self.Q[state, action] = q_current + alpha * (target - q_current) - self._Q2[state, action] = q2_current + alpha * ( - target ** 2 - q2_current - ) + self._Q2[state, action] = q2_current + alpha * (target ** 2 - q2_current) self._n_updates[state, action] += 1 - self._w2[state, action] = (1 - alpha) ** 2 * self._w2[ - state, action] + alpha ** 2 + self._w2[state, action] = (1 - alpha) ** 2 * self._w2[state, action] + alpha ** 2 self._w1[state, action] = (1 - alpha) * self._w1[state, action] + alpha if self._n_updates[state, action] > 1: @@ -90,8 +87,7 @@ def _next_q(self, next_state): sigmas[a] = self._sigma[next_state, np.array([a])] if self._sampling: - samples = np.random.normal(np.repeat([means], self._precision, 0), - np.repeat([sigmas], self._precision, 0)) + samples = np.random.normal(np.repeat([means], self._precision, 0), np.repeat([sigmas], self._precision, 0)) max_idx = np.argmax(samples, axis=1) max_idx, max_count = np.unique(max_idx, return_counts=True) count = np.zeros(means.size) diff --git a/mushroom_rl/approximators/parametric/cmac.py b/mushroom_rl/approximators/parametric/cmac.py index f0af55ee9..90b7c70c4 100644 --- a/mushroom_rl/approximators/parametric/cmac.py +++ b/mushroom_rl/approximators/parametric/cmac.py @@ -16,21 +16,18 @@ def __init__(self, tilings, weights=None, output_shape=(1,), **kwargs): Args: tilings (list): list of tilings to discretize the input space. - weights (np.ndarray): array of weights to initialize the weights - of the approximator; - input_shape (np.ndarray, None): the shape of the input of the - model; - output_shape (np.ndarray, (1,)): the shape of the output of the - model; + weights (np.ndarray): array of weights to initialize the weights of the approximator; + input_shape (np.ndarray, None): the shape of the input of the model; + output_shape (np.ndarray, (1,)): the shape of the output of the model; **kwargs: other params of the approximator. 
""" - self._phi = Features(tilings=tilings) + phi = Features(tilings=tilings) self._n = len(tilings) - super().__init__(weights=weights, input_shape=(self._phi.size,), output_shape=output_shape) + super().__init__(weights=weights, input_shape=(phi.size,), output_shape=output_shape, phi=phi) - self._add_save_attr(_phi='pickle', _n='primitive') + self._add_save_attr(_n='primitive') def fit(self, x, y, alpha=1.0, **kwargs): """ @@ -40,8 +37,7 @@ def fit(self, x, y, alpha=1.0, **kwargs): x (np.ndarray): input; y (np.ndarray): target; alpha (float): learning rate; - **kwargs: other parameters used by the fit method of the - regressor. + **kwargs: other parameters used by the fit method of the regressor. """ y_hat = self.predict(x) @@ -64,8 +60,7 @@ def predict(self, x, **predict_params): Args: x (np.ndarray): input; - **predict_params: other parameters used by the predict method - the regressor. + **predict_params: other parameters used by the predict method the regressor. Returns: The predictions of the model. @@ -84,16 +79,14 @@ def predict(self, x, **predict_params): def diff(self, state, action=None): """ - Compute the derivative of the output w.r.t. ``state``, and ``action`` - if provided. + Compute the derivative of the output w.r.t. ``state``, and ``action`` if provided. Args: state (np.ndarray): the state; action (np.ndarray, None): the action. Returns: - The derivative of the output w.r.t. ``state``, and ``action`` - if provided. + The derivative of the output w.r.t. ``state``, and ``action`` if provided. """ diff --git a/mushroom_rl/approximators/parametric/linear.py b/mushroom_rl/approximators/parametric/linear.py index 942a8c5bc..55c4b0fcb 100644 --- a/mushroom_rl/approximators/parametric/linear.py +++ b/mushroom_rl/approximators/parametric/linear.py @@ -8,18 +8,16 @@ class LinearApproximator(Serializable): This class implements a linear approximator. """ - def __init__(self, weights=None, input_shape=None, output_shape=(1,), + def __init__(self, weights=None, input_shape=None, output_shape=(1,), phi=None, **kwargs): """ Constructor. Args: - weights (np.ndarray): array of weights to initialize the weights - of the approximator; - input_shape (np.ndarray, None): the shape of the input of the - model; - output_shape (np.ndarray, (1,)): the shape of the output of the - model; + weights (np.ndarray): array of weights to initialize the weights of the approximator; + input_shape (np.ndarray, None): the shape of the input of the model; + output_shape (np.ndarray, (1,)): the shape of the output of the model; + phi (object, None): features to extract from the state; **kwargs: other params of the approximator. """ @@ -36,7 +34,11 @@ def __init__(self, weights=None, input_shape=None, output_shape=(1,), raise ValueError('You should specify the initial parameter vector' ' or the input dimension') - self._add_save_attr(_w='numpy') + self._phi = phi + self._add_save_attr( + _w='numpy', + _phi='pickle' + ) def fit(self, x, y, **fit_params): """ @@ -45,11 +47,11 @@ def fit(self, x, y, **fit_params): Args: x (np.ndarray): input; y (np.ndarray): target; - **fit_params: other parameters used by the fit method of the - regressor. + **fit_params: other parameters used by the fit method of the regressor. 
""" - self._w = np.atleast_2d(np.linalg.pinv(x).dot(y).T) + phi = np.atleast_2d(self.phi(x)) + self._w = np.atleast_2d(np.linalg.pinv(phi).dot(y).T) def predict(self, x, **predict_params): """ @@ -57,16 +59,17 @@ def predict(self, x, **predict_params): Args: x (np.ndarray): input; - **predict_params: other parameters used by the predict method - the regressor. + **predict_params: other parameters used by the predict method the regressor. Returns: The predictions of the model. """ - prediction = np.ones((x.shape[0], self._w.shape[0])) - for i, x_i in enumerate(x): - prediction[i] = x_i.dot(self._w.T) + phi = np.atleast_2d(self.phi(x)) + + prediction = np.ones((phi.shape[0], self._w.shape[0])) + for i, phi_i in enumerate(phi): + prediction[i] = phi_i.dot(self._w.T) return prediction @@ -99,10 +102,15 @@ def set_weights(self, w): """ self._w = w.reshape(self._w.shape) + def phi(self, x): + if self._phi is not None: + return self._phi(x) + else: + return x + def diff(self, state, action=None): """ - Compute the derivative of the output w.r.t. ``state``, and ``action`` - if provided. + Compute the derivative of the output w.r.t. ``state``, and ``action`` if provided. Args: state (np.ndarray): the state; @@ -114,7 +122,7 @@ def diff(self, state, action=None): """ if len(self._w.shape) == 1 or self._w.shape[0] == 1: - return state + return self.phi(state) else: n_phi = self._w.shape[1] n_outs = self._w.shape[0] @@ -125,13 +133,14 @@ def diff(self, state, action=None): start = 0 for i in range(n_outs): stop = start + n_phi - df[start:stop, i] = state + df[start:stop, i] = self.phi(state) start = stop else: shape = (n_phi * n_outs) df = np.zeros(shape) start = action[0] * n_phi stop = start + n_phi - df[start:stop] = state + df[start:stop] = self.phi(state) return df + diff --git a/mushroom_rl/approximators/parametric/torch_approximator.py b/mushroom_rl/approximators/parametric/torch_approximator.py index 51a896451..b3a5e4c31 100644 --- a/mushroom_rl/approximators/parametric/torch_approximator.py +++ b/mushroom_rl/approximators/parametric/torch_approximator.py @@ -10,9 +10,8 @@ class TorchApproximator(Serializable): """ Class to interface a pytorch model to the mushroom Regressor interface. - This class implements all is needed to use a generic pytorch model and train - it using a specified optimizer and objective function. - This class supports also minibatches. + This class implements all is needed to use a generic pytorch model and train it using a specified optimizer and + objective function. This class supports also minibatches. 
""" def __init__(self, input_shape, output_shape, network, optimizer=None, diff --git a/mushroom_rl/core/__init__.py b/mushroom_rl/core/__init__.py index ebf9082db..16d34d519 100644 --- a/mushroom_rl/core/__init__.py +++ b/mushroom_rl/core/__init__.py @@ -1,9 +1,10 @@ from .core import Core +from .dataset import Dataset from .environment import Environment, MDPInfo -from .agent import Agent +from .agent import Agent, AgentInfo from .serialization import Serializable from .logger import Logger import mushroom_rl.environments -__all__ = ['Core', 'Environment', 'MDPInfo', 'Agent', 'Serializable', 'Logger'] +__all__ = ['Core', 'Dataset', 'Environment', 'MDPInfo', 'Agent', 'AgentInfo', 'Serializable', 'Logger'] diff --git a/mushroom_rl/core/_impl/__init__.py b/mushroom_rl/core/_impl/__init__.py new file mode 100644 index 000000000..b43e6afa1 --- /dev/null +++ b/mushroom_rl/core/_impl/__init__.py @@ -0,0 +1,5 @@ +from .numpy_dataset import NumpyDataset +from .torch_dataset import TorchDataset +from .list_dataset import ListDataset +from .type_conversions import DataConversion, NumpyConversion, TorchConversion, ListConversion +from .core_logic import CoreLogic \ No newline at end of file diff --git a/mushroom_rl/core/_impl/core_logic.py b/mushroom_rl/core/_impl/core_logic.py new file mode 100644 index 000000000..babb41e27 --- /dev/null +++ b/mushroom_rl/core/_impl/core_logic.py @@ -0,0 +1,81 @@ +from tqdm import tqdm + + +class CoreLogic(object): + def __init__(self): + self.fit_required = None + self.move_required = None + + self._total_episodes_counter = 0 + self._total_steps_counter = 0 + self._current_episodes_counter = 0 + self._current_steps_counter = 0 + + self._n_episodes = None + self._n_steps_per_fit = None + self._n_episodes_per_fit = None + + self._steps_progress_bar = None + self._episodes_progress_bar = None + + def initialize_fit(self, n_steps_per_fit, n_episodes_per_fit): + assert (n_episodes_per_fit is not None and n_steps_per_fit is None) \ + or (n_episodes_per_fit is None and n_steps_per_fit is not None) + + self._n_steps_per_fit = n_steps_per_fit + self._n_episodes_per_fit = n_episodes_per_fit + + if n_steps_per_fit is not None: + self.fit_required = lambda: self._current_steps_counter >= self._n_steps_per_fit + else: + self.fit_required = lambda: self._current_episodes_counter >= self._n_episodes_per_fit + + def initialize_evaluate(self): + self.fit_required = lambda: False + + def initialize_run(self, n_steps, n_episodes, initial_states, quiet): + assert n_episodes is not None and n_steps is None and initial_states is None\ + or n_episodes is None and n_steps is not None and initial_states is None\ + or n_episodes is None and n_steps is None and initial_states is not None + + self._n_episodes = len(initial_states) if initial_states is not None else n_episodes + + if n_steps is not None: + self.move_required = lambda: self._total_steps_counter < n_steps + + self._steps_progress_bar = tqdm(total=n_steps, dynamic_ncols=True, disable=quiet, leave=False) + self._episodes_progress_bar = tqdm(disable=True) + else: + self.move_required = lambda: self._total_episodes_counter < self._n_episodes + + self._steps_progress_bar = tqdm(disable=True) + self._episodes_progress_bar = tqdm(total=self._n_episodes, dynamic_ncols=True, disable=quiet, leave=False) + + self._total_episodes_counter = 0 + self._total_steps_counter = 0 + self._current_episodes_counter = 0 + self._current_steps_counter = 0 + + def get_initial_state(self, initial_states): + if initial_states is None or 
self._total_episodes_counter == self._n_episodes: + return None + else: + return initial_states[self._total_episodes_counter] + + def after_step(self, last): + self._total_steps_counter += 1 + self._current_steps_counter += 1 + self._steps_progress_bar.update(1) + + if last: + self._total_episodes_counter += 1 + self._current_episodes_counter += 1 + self._episodes_progress_bar.update(1) + + def after_fit(self): + self._current_episodes_counter = 0 + self._current_steps_counter = 0 + + def terminate_run(self): + self._steps_progress_bar.close() + self._episodes_progress_bar.close() \ No newline at end of file diff --git a/mushroom_rl/core/_impl/list_dataset.py b/mushroom_rl/core/_impl/list_dataset.py new file mode 100644 index 000000000..f88e3c721 --- /dev/null +++ b/mushroom_rl/core/_impl/list_dataset.py @@ -0,0 +1,104 @@ +from copy import deepcopy + +import numpy as np + +from mushroom_rl.core.serialization import Serializable + + +class ListDataset(Serializable): + def __init__(self, is_stateful): + self._dataset = list() + self._policy_dataset = list() + self._is_stateful = is_stateful + + self._add_save_attr( + _dataset='pickle', + _policy_dataset='pickle', + _is_stateful='primitive' + ) + + @classmethod + def from_array(cls, states, actions, rewards, next_states, absorbings, lasts, policy_states=None, + policy_next_states=None): + is_stateful = (policy_states is not None) and (policy_next_states is not None) + + dataset = cls(is_stateful) + + if dataset._is_stateful: + for s, a, r, ss, ab, last, ps, pss in zip(states, actions, rewards, next_states, + absorbings.astype(bool), lasts.astype(bool), + policy_states, policy_next_states): + dataset.append(s, a, r.item(), ss, ab.item(), last.item(), ps.item(), pss.item()) + else: + for s, a, r, ss, ab, last in zip(states, actions, rewards, next_states, + absorbings.astype(bool), lasts.astype(bool)): + dataset.append(s, a, r.item(), ss, ab.item(), last.item()) + + return dataset + + def __len__(self): + return len(self._dataset) + + def append(self, *step): + step_copy = deepcopy(step) + self._dataset.append(step_copy[:6]) + if self._is_stateful: + self._policy_dataset.append(step_copy[6:]) + + def clear(self): + self._dataset = list() + + def get_view(self, index): + view = self.copy() + + if isinstance(index, (int, slice)): + view._dataset = self._dataset[index] + else: + view._dataset = [self._dataset[i] for i in index] + + return view + + def __getitem__(self, index): + return self._dataset[index] + + def __add__(self, other): + result = self.copy() + last_step = result._dataset[-1] + modified_last_step = last_step[:-1] + (True,) + result._dataset[-1] = modified_last_step + result._dataset = result._dataset + other._dataset + result._policy_dataset = result._policy_dataset + other._policy_dataset + + return result + + @property + def state(self): + return [step[0] for step in self._dataset] + + @property + def action(self): + return [step[1] for step in self._dataset] + + @property + def reward(self): + return [step[2] for step in self._dataset] + + @property + def next_state(self): + return [step[3] for step in self._dataset] + + @property + def absorbing(self): + return [step[4] for step in self._dataset] + + @property + def last(self): + return [step[5] for step in self._dataset] + + @property + def policy_state(self): + return [step[6] for step in self._dataset] + + @property + def policy_next_state(self): + return [step[7] for step in self._dataset] diff --git a/mushroom_rl/core/_impl/numpy_dataset.py 
b/mushroom_rl/core/_impl/numpy_dataset.py new file mode 100644 index 000000000..d8986d2c5 --- /dev/null +++ b/mushroom_rl/core/_impl/numpy_dataset.py @@ -0,0 +1,196 @@ +import numpy as np + +from mushroom_rl.core.serialization import Serializable + + +class NumpyDataset(Serializable): + def __init__(self, state_type, state_shape, action_type, action_shape, reward_shape, policy_state_shape): + flags_len = action_shape[0] + + self._state_type = state_type + self._action_type = action_type + + self._states = np.empty(state_shape, dtype=self._state_type) + self._actions = np.empty(action_shape, dtype=self._action_type) + self._rewards = np.empty(reward_shape, dtype=float) + self._next_states = np.empty(state_shape, dtype=self._state_type) + self._absorbing = np.empty(flags_len, dtype=bool) + self._last = np.empty(flags_len, dtype=bool) + self._len = 0 + + if policy_state_shape is None: + self._policy_states = None + self._policy_next_states = None + else: + self._policy_states = np.empty(policy_state_shape, dtype=float) + self._policy_next_states = np.empty(policy_state_shape, dtype=float) + + self._add_save_attr( + _state_type='primitive', + _action_type='primitive', + _states='numpy', + _actions='numpy', + _rewards='numpy', + _next_states='numpy', + _absorbing='numpy', + _last='numpy', + _policy_states='numpy', + _policy_next_states='numpy', + _len='primitive' + ) + + @classmethod + def from_array(cls, states, actions, rewards, next_states, absorbings, lasts, + policy_states=None, policy_next_states=None): + if not isinstance(states, np.ndarray): + states = states.numpy() + actions = actions.numpy() + rewards = rewards.numpy() + next_states = next_states.numpy() + absorbings = absorbings.numpy() + lasts = lasts.numpy() + + dataset = cls.__new__(cls) + + dataset._state_type = states.dtype + dataset._action_type = actions.dtype + + dataset._states = states + dataset._actions = actions + dataset._rewards = rewards + dataset._next_states = next_states + dataset._absorbing = absorbings + dataset._last = lasts + dataset._len = len(lasts) + + if policy_states is not None and policy_next_states is not None: + if not isinstance(policy_states, np.ndarray): + policy_states = policy_states.numpy() + policy_next_states = policy_next_states.numpy() + + dataset._policy_states = policy_states + dataset._policy_next_states = policy_next_states + + dataset._add_save_attr( + _state_type='primitive', + _action_type='primitive', + _states='numpy', + _actions='numpy', + _rewards='numpy', + _next_states='numpy', + _absorbing='numpy', + _last='numpy', + _policy_states='numpy', + _policy_next_states='numpy', + _len='primitive' + ) + + return dataset + + def __len__(self): + return self._len + + def append(self, state, action, reward, next_state, absorbing, last, policy_state=None, policy_next_state=None): + i = self._len + + self._states[i] = state + self._actions[i] = action + self._rewards[i] = reward + self._next_states[i] = next_state + self._absorbing[i] = absorbing + self._last[i] = last + + if self._is_stateful: + self._policy_states[i] = policy_state + self._policy_next_states[i] = policy_next_state + + self._len += 1 + + def clear(self): + self._states = np.empty_like(self._states) + self._actions = np.empty_like(self._actions) + self._rewards = np.empty_like(self._rewards) + self._next_states = np.empty_like(self._next_states) + self._absorbing = np.empty_like(self._absorbing) + self._last = np.empty_like(self._last) + + if self._is_stateful: + self._policy_states = np.empty_like(self._policy_states) + 
self._policy_next_states = np.empty_like(self._policy_next_states) + + self._len = 0 + + def get_view(self, index): + view = self.copy() + + view._states = self.state[index, ...] + view._actions = self.action[index, ...] + view._rewards = self.reward[index, ...] + view._next_states = self.next_state[index, ...] + view._absorbing = self.absorbing[index, ...] + view._last = self.last[index, ...] + view._len = view._states.shape[0] + + if self._is_stateful: + view._policy_states = self._policy_states[index, ...] + view._policy_next_states = self._policy_next_states[index, ...] + + return view + + def __getitem__(self, index): + return self._states[index], self._actions[index], self._rewards[index], self._next_states[index], \ + self._absorbing[index], self._last[index] + + def __add__(self, other): + result = self.copy() + + result._states = np.concatenate((self.state, other.state)) + result._actions = np.concatenate((self.action, other.action)) + result._rewards = np.concatenate((self.reward, other.reward)) + result._next_states = np.concatenate((self.next_state, other.next_state)) + result._absorbing = np.concatenate((self.absorbing, other.absorbing)) + result._last = np.concatenate((self.last, other.last)) + result._last[len(self)-1] = True + result._len = len(self) + len(other) + + if self._is_stateful: + result._policy_states = np.concatenate((self.policy_state, other.policy_state)) + result._policy_next_states = np.concatenate((self.policy_next_state, other.policy_next_state)) + + return result + + @property + def state(self): + return self._states[:len(self)] + + @property + def action(self): + return self._actions[:len(self)] + + @property + def reward(self): + return self._rewards[:len(self)] + + @property + def next_state(self): + return self._next_states[:len(self)] + + @property + def absorbing(self): + return self._absorbing[:len(self)] + + @property + def last(self): + return self._last[:len(self)] + + @property + def policy_state(self): + return self._policy_states[:len(self)] + + @property + def policy_next_state(self): + return self._policy_next_states[:len(self)] + + @property + def _is_stateful(self): + return self._policy_states is not None \ No newline at end of file diff --git a/mushroom_rl/core/_impl/torch_dataset.py b/mushroom_rl/core/_impl/torch_dataset.py new file mode 100644 index 000000000..7e7309e88 --- /dev/null +++ b/mushroom_rl/core/_impl/torch_dataset.py @@ -0,0 +1,196 @@ +import torch + +from mushroom_rl.core.serialization import Serializable + + +class TorchDataset(Serializable): + def __init__(self, state_type, state_shape, action_type, action_shape, reward_shape, policy_state_shape): + flags_len = action_shape[0] + + self._state_type = state_type + self._action_type = action_type + + self._states = torch.empty(*state_shape, dtype=self._state_type) + self._actions = torch.empty(*action_shape, dtype=self._action_type) + self._rewards = torch.empty(*reward_shape, dtype=torch.float) + self._next_states = torch.empty(*state_shape, dtype=self._state_type) + self._absorbing = torch.empty(flags_len, dtype=torch.bool) + self._last = torch.empty(flags_len, dtype=torch.bool) + self._len = 0 + + if policy_state_shape is None: + self._policy_states = None + self._policy_next_states = None + else: + self._policy_states = torch.empty(policy_state_shape, dtype=torch.float) + self._policy_next_states = torch.empty(policy_state_shape, dtype=torch.float) + + self._add_save_attr( + _state_type='primitive', + _action_type='primitive', + _states='torch', + _actions='torch', + 
_rewards='torch', + _next_states='torch', + _absorbing='torch', + _last='torch', + _policy_states='numpy', + _policy_next_states='numpy', + _len='primitive' + ) + + @classmethod + def from_array(cls, states, actions, rewards, next_states, absorbings, lasts, + policy_states=None, policy_next_states=None): + if not isinstance(states, torch.Tensor): + states = torch.as_tensor(states) + actions = torch.as_tensor(actions) + rewards = torch.as_tensor(rewards) + next_states = torch.as_tensor(next_states) + absorbings = torch.as_tensor(absorbings) + lasts = torch.as_tensor(lasts) + + dataset = cls.__new__(cls) + + dataset._state_type = states.dtype + dataset._action_type = actions.dtype + + dataset._states = torch.as_tensor(states) + dataset._actions = torch.as_tensor(actions) + dataset._rewards = torch.as_tensor(rewards) + dataset._next_states = torch.as_tensor(next_states) + dataset._absorbing = torch.as_tensor(absorbings, dtype=torch.bool) + dataset._last = torch.as_tensor(lasts, dtype=torch.bool) + dataset._len = len(lasts) + + if policy_states is not None and policy_next_states is not None: + if not isinstance(policy_states, torch.Tensor): + policy_states = torch.as_tensor(policy_states) + policy_next_states = torch.as_tensor(policy_next_states) + + dataset._policy_states = policy_states + dataset._policy_next_states = policy_next_states + + dataset._add_save_attr( + _state_type='primitive', + _action_type='primitive', + _states='torch', + _actions='torch', + _rewards='torch', + _next_states='torch', + _absorbing='torch', + _last='torch', + _policy_states='numpy', + _policy_next_states='numpy', + _len='primitive' + ) + + return dataset + + def __len__(self): + return self._len + + def append(self, state, action, reward, next_state, absorbing, last, policy_state=None, policy_next_state=None): + i = self._len + + self._states[i] = state + self._actions[i] = action + self._rewards[i] = reward + self._next_states[i] = next_state + self._absorbing[i] = absorbing + self._last[i] = last + + if self._is_stateful: + self._policy_states[i] = policy_state + self._policy_next_states[i] = policy_next_state + + self._len += 1 + + def clear(self): + self._states = torch.empty_like(self._states) + self._actions = torch.empty_like(self._actions) + self._rewards = torch.empty_like(self._rewards) + self._next_states = torch.empty_like(self._next_states) + self._absorbing = torch.empty_like(self._absorbing) + self._last = torch.empty_like(self._last) + + if self._is_stateful: + self._policy_states = torch.empty_like(self._policy_states) + self._policy_next_states = torch.empty_like(self._policy_next_states) + + self._len = 0 + + def get_view(self, index): + view = self.copy() + + view._states = self._states[index, ...] + view._actions = self._actions[index, ...] + view._rewards = self._rewards[index, ...] + view._next_states = self._next_states[index, ...] + view._absorbing = self._absorbing[index, ...] + view._last = self._last[index, ...] + view._len = view._states.shape[0] + + if self._is_stateful: + view._policy_states = self._policy_states[index, ...] + view._policy_next_states = self._policy_next_states[index, ...] 
+ + return view + + def __getitem__(self, index): + return self._states[index], self._actions[index], self._rewards[index], self._next_states[index], \ + self._absorbing[index], self._last[index] + + def __add__(self, other): + result = self.copy() + + result._states = torch.concatenate((self.state, other.state)) + result._actions = torch.concatenate((self.action, other.action)) + result._rewards = torch.concatenate((self.reward, other.reward)) + result._next_states = torch.concatenate((self.next_state, other.next_state)) + result._absorbing = torch.concatenate((self.absorbing, other.absorbing)) + result._last = torch.concatenate((self.last, other.last)) + result._last[len(self) - 1] = True + result._len = len(self) + len(other) + + if self._is_stateful: + result._policy_states = torch.concatenate((self.policy_state, other.policy_state)) + result._policy_next_states = torch.concatenate((self.policy_next_state, other.policy_next_state)) + + return result + + @property + def state(self): + return self._states[:len(self)] + + @property + def action(self): + return self._actions[:len(self)] + + @property + def reward(self): + return self._rewards[:len(self)] + + @property + def next_state(self): + return self._next_states[:len(self)] + + @property + def absorbing(self): + return self._absorbing[:len(self)] + + @property + def last(self): + return self._last[:len(self)] + + @property + def policy_state(self): + return self._policy_states[:len(self)] + + @property + def policy_next_state(self): + return self._policy_next_states[:len(self)] + + @property + def _is_stateful(self): + return self._policy_states is not None diff --git a/mushroom_rl/core/_impl/type_conversions.py b/mushroom_rl/core/_impl/type_conversions.py new file mode 100644 index 000000000..2a7410c6e --- /dev/null +++ b/mushroom_rl/core/_impl/type_conversions.py @@ -0,0 +1,88 @@ +import numpy +import torch + + +class DataConversion(object): + @staticmethod + def get_converter(backend): + if backend == 'numpy': + return NumpyConversion + elif backend == 'torch': + return TorchConversion + else: + return ListConversion + + @classmethod + def convert(cls, *arrays, to='numpy'): + if to == 'numpy': + return cls.arrays_to_numpy(*arrays) + elif to == 'torch': + return cls.arrays_to_torch(*arrays) + else: + return NotImplementedError + + @classmethod + def arrays_to_numpy(cls, *arrays): + return (cls.to_numpy(array) for array in arrays) + + @classmethod + def arrays_to_torch(cls, *arrays): + return (cls.to_torch(array) for array in arrays) + + @staticmethod + def to_numpy(array): + return NotImplementedError + + @staticmethod + def to_torch(array): + raise NotImplementedError + + @staticmethod + def to_backend_array(cls, array): + raise NotImplementedError + + +class NumpyConversion(DataConversion): + @staticmethod + def to_numpy(array): + return array + + @staticmethod + def to_torch(array): + return torch.from_numpy(array) + + @staticmethod + def to_backend_array(cls, array): + return cls.to_numpy(array) + + +class TorchConversion(DataConversion): + @staticmethod + def to_numpy(array): + return array.detach().cpu().numpy() + + @staticmethod + def to_torch(array): + return array + + @staticmethod + def to_backend_array(cls, array): + return cls.to_torch(array) + + +class ListConversion(DataConversion): + @staticmethod + def to_numpy(array): + return numpy.array(array) + + @staticmethod + def to_torch(array): + return torch.as_tensor(array) + + @staticmethod + def to_backend_array(cls, array): + return cls.to_numpy(array) + + + + diff 
--git a/mushroom_rl/core/agent.py b/mushroom_rl/core/agent.py index c8e18f2a6..136edf874 100644 --- a/mushroom_rl/core/agent.py +++ b/mushroom_rl/core/agent.py @@ -1,88 +1,117 @@ from mushroom_rl.core.serialization import Serializable +from ._impl import * + + +class AgentInfo(Serializable): + def __init__(self, is_episodic, policy_state_shape, backend): + assert isinstance(is_episodic, bool) + assert policy_state_shape is None or isinstance(policy_state_shape, tuple) + assert isinstance(backend, str) + + self.is_episodic = is_episodic + self.is_stateful = policy_state_shape is not None + self.policy_state_shape = policy_state_shape + self.backend = backend + + self._add_save_attr( + is_episodic='primitive', + is_stateful='primitive', + policy_state_shape='primitive', + backend='primitive' + ) + class Agent(Serializable): """ - This class implements the functions to manage the agent (e.g. move the agent - following its policy). + This class implements the functions to manage the agent (e.g. move the agent following its policy). """ - def __init__(self, mdp_info, policy, features=None): + def __init__(self, mdp_info, policy, is_episodic=False, backend='numpy'): """ Constructor. Args: mdp_info (MDPInfo): information about the MDP; policy (Policy): the policy followed by the agent; - features (object, None): features to extract from the state. + is_episodic (bool, False): whether the agent is learning in an episodic fashion or not; + backend (str, 'numpy'): array backend to be used by the algorithm. """ self.mdp_info = mdp_info - self.policy = policy - - self.phi = features + self._info = AgentInfo( + is_episodic=is_episodic, + policy_state_shape=policy.policy_state_shape, + backend=backend + ) + self.policy = policy self.next_action = None + self._agent_converter = DataConversion.get_converter(backend) + self._env_converter = DataConversion.get_converter(self.mdp_info.backend) self._preprocessors = list() + self._logger = None self._add_save_attr( - mdp_info='pickle', policy='mushroom', - phi='pickle', - next_action='numpy', + next_action='none', + mdp_info='mushroom', + _info='mushroom', + _agent_converter='primitive', + _env_converter='primitive', _preprocessors='mushroom', _logger='none' ) - def fit(self, dataset, **info): + def fit(self, dataset): """ Fit step. Args: - dataset (list): the dataset. + dataset (Dataset): the dataset. """ raise NotImplementedError('Agent is an abstract class') - def draw_action(self, state): + def draw_action(self, state, policy_state=None): """ - Return the action to execute in the given state. It is the action - returned by the policy or the action set by the algorithm (e.g. in the - case of SARSA). + Return the action to execute in the given state. It is the action returned by the policy or the action set by + the algorithm (e.g. in the case of SARSA). Args: - state (np.ndarray): the state where the agent is. + state: the state where the agent is; + policy_state: the policy internal state. Returns: The action to be executed. """ - if self.phi is not None: - state = self.phi(state) - if self.next_action is None: - return self.policy.draw_action(state) + action, next_policy_state = self.policy.draw_action(state, policy_state) else: action = self.next_action + next_policy_state = None # FIXME self.next_action = None - return action + return self._convert_to_env_backend(action), self._convert_to_env_backend(next_policy_state) - def episode_start(self): + def episode_start(self, episode_info): """ Called by the agent when a new episode starts. 
+ Args: + episode_info (dict): a dictionary containing the information at reset, such as context. + """ - self.policy.reset() + return self.policy.reset(), None def stop(self): """ - Method used to stop an agent. Useful when dealing with real world - environments, simulators, or to cleanup environments internals after - a core learn/evaluate to enforce consistency. + Method used to stop an agent. Useful when dealing with real world environments, simulators, or to cleanup + environments internals after a core learn/evaluate to enforce consistency. """ pass @@ -99,8 +128,7 @@ def set_logger(self, logger): def add_preprocessor(self, preprocessor): """ - Add preprocessor to the preprocessor list. - The preprocessors are applied in order. + Add preprocessor to the preprocessor list. The preprocessors are applied in order. Args: preprocessor (object): state preprocessors to be applied @@ -116,3 +144,14 @@ def preprocessors(self): """ return self._preprocessors + + def _convert_to_env_backend(self, array): + return self._env_converter.to_backend_array(self._agent_converter, array) + + def _convert_to_agent_backend(self, array): + return self._agent_converter.to_backend_array(self._env_converter, array) + + @property + def info(self): + return self._info + diff --git a/mushroom_rl/core/core.py b/mushroom_rl/core/core.py index 22d6e20e3..1dc5154ac 100644 --- a/mushroom_rl/core/core.py +++ b/mushroom_rl/core/core.py @@ -1,8 +1,8 @@ -from tqdm import tqdm - -from collections import defaultdict +from mushroom_rl.core.dataset import Dataset from mushroom_rl.utils.record import VideoRecorder +from ._impl import CoreLogic + class Core(object): """ @@ -26,15 +26,11 @@ def __init__(self, agent, mdp, callbacks_fit=None, callback_step=None, record_di self.callback_step = callback_step if callback_step is not None else lambda x: None self._state = None - - self._total_episodes_counter = 0 - self._total_steps_counter = 0 - self._current_episodes_counter = 0 - self._current_steps_counter = 0 + self._policy_state = None + self._current_theta = None self._episode_steps = None - self._n_episodes = None - self._n_steps_per_fit = None - self._n_episodes_per_fit = None + + self._core_logic = CoreLogic() if record_dictionary is None: record_dictionary = dict() @@ -61,23 +57,14 @@ def learn(self, n_steps=None, n_episodes=None, n_steps_per_fit=None, should be set to True. 
""" - assert (n_episodes_per_fit is not None and n_steps_per_fit is None)\ - or (n_episodes_per_fit is None and n_steps_per_fit is not None) - assert (render and record) or (not record), "To record, the render flag must be set to true" + self._core_logic.initialize_fit(n_steps_per_fit, n_episodes_per_fit) - self._n_steps_per_fit = n_steps_per_fit - self._n_episodes_per_fit = n_episodes_per_fit + dataset = Dataset(self.mdp.info, self.agent.info, n_steps_per_fit, n_episodes_per_fit) - if n_steps_per_fit is not None: - fit_condition = lambda: self._current_steps_counter >= self._n_steps_per_fit - else: - fit_condition = lambda: self._current_episodes_counter >= self._n_episodes_per_fit + self._run(dataset, n_steps, n_episodes, render, quiet, record) - self._run(n_steps, n_episodes, fit_condition, render, quiet, record, get_env_info=False) - - def evaluate(self, initial_states=None, n_steps=None, n_episodes=None, - render=False, quiet=False, record=False, get_env_info=False): + def evaluate(self, initial_states=None, n_steps=None, n_episodes=None, render=False, quiet=False, record=False): """ This function moves the agent in the environment using its policy. The agent is moved for a provided number of steps, episodes, or from a set of initial states for the whole @@ -90,102 +77,55 @@ def evaluate(self, initial_states=None, n_steps=None, n_episodes=None, render (bool, False): whether to render the environment or not; quiet (bool, False): whether to show the progress bar or not; record (bool, False): whether to record a video of the environment or not. If True, also the render flag - should be set to True; - get_env_info (bool, False): whether to return the environment info list or not. + should be set to True. Returns: - The collected dataset and, optionally, an extra dataset of - environment info, collected at each step. + The collected dataset. 
""" assert (render and record) or (not record), "To record, the render flag must be set to true" - fit_condition = lambda: False - - return self._run(n_steps, n_episodes, fit_condition, render, quiet, record, get_env_info, initial_states) - - def _run(self, n_steps, n_episodes, fit_condition, render, quiet, record, get_env_info, initial_states=None): - assert n_episodes is not None and n_steps is None and initial_states is None\ - or n_episodes is None and n_steps is not None and initial_states is None\ - or n_episodes is None and n_steps is None and initial_states is not None + self._core_logic.initialize_evaluate() - self._n_episodes = len( initial_states) if initial_states is not None else n_episodes + n_episodes_dataset = len(initial_states) if initial_states is not None else n_episodes + dataset = Dataset(self.mdp.info, self.agent.info, n_steps, n_episodes_dataset) - if n_steps is not None: - move_condition = lambda: self._total_steps_counter < n_steps + return self._run(dataset, n_steps, n_episodes, render, quiet, record, initial_states) - steps_progress_bar = tqdm(total=n_steps, dynamic_ncols=True, disable=quiet, leave=False) - episodes_progress_bar = tqdm(disable=True) - else: - move_condition = lambda: self._total_episodes_counter < self._n_episodes - - steps_progress_bar = tqdm(disable=True) - episodes_progress_bar = tqdm(total=self._n_episodes, dynamic_ncols=True, disable=quiet, leave=False) - - dataset, dataset_info = self._run_impl(move_condition, fit_condition, steps_progress_bar, episodes_progress_bar, - render, record, initial_states) - - if get_env_info: - return dataset, dataset_info - else: - return dataset - - def _run_impl(self, move_condition, fit_condition, steps_progress_bar, episodes_progress_bar, render, record, - initial_states): - self._total_episodes_counter = 0 - self._total_steps_counter = 0 - self._current_episodes_counter = 0 - self._current_steps_counter = 0 - - dataset = list() - dataset_info = defaultdict(list) + def _run(self, dataset, n_steps, n_episodes, render, quiet, record, initial_states=None): + self._core_logic.initialize_run(n_steps, n_episodes, initial_states, quiet) last = True - while move_condition(): + while self._core_logic.move_required(): if last: - self.reset(initial_states) + self._reset(initial_states) + if self.agent.info.is_episodic: + dataset.append_theta(self._current_theta) sample, step_info = self._step(render, record) - self.callback_step([sample]) - - self._total_steps_counter += 1 - self._current_steps_counter += 1 - steps_progress_bar.update(1) - - if sample[-1]: - self._total_episodes_counter += 1 - self._current_episodes_counter += 1 - episodes_progress_bar.update(1) - - dataset.append(sample) + self.callback_step(sample) + self._core_logic.after_step(sample[5]) - for key, value in step_info.items(): - dataset_info[key].append(value) + dataset.append(sample, step_info) - if fit_condition(): - self.agent.fit(dataset, **dataset_info) - self._current_episodes_counter = 0 - self._current_steps_counter = 0 + if self._core_logic.fit_required(): + self.agent.fit(dataset) + self._core_logic.after_fit() for c in self.callbacks_fit: c(dataset) - dataset = list() - dataset_info = defaultdict(list) + dataset.clear() - last = sample[-1] + last = sample[5] self.agent.stop() self.mdp.stop() - if record: - self._record.stop() + self._end(record) - steps_progress_bar.close() - episodes_progress_bar.close() - - return dataset, dataset_info + return dataset def _step(self, render, record): """ @@ -199,42 +139,52 @@ def _step(self, render, 
record): state, the absorbing flag of the reached state and the last step flag. """ - action = self.agent.draw_action(self._state) + action, policy_next_state = self.agent.draw_action(self._state, self._policy_state) next_state, reward, absorbing, step_info = self.mdp.step(action) - self._episode_steps += 1 - if render: frame = self.mdp.render(record) if record: self._record(frame) - last = not( - self._episode_steps < self.mdp.info.horizon and not absorbing) + self._episode_steps += 1 + + last = self._episode_steps >= self.mdp.info.horizon or absorbing state = self._state - next_state = self._preprocess(next_state.copy()) + policy_state = self._policy_state + next_state = self._preprocess(next_state) self._state = next_state + self._policy_state = policy_next_state - return (state, action, reward, next_state, absorbing, last), step_info + return (state, action, reward, next_state, absorbing, last, policy_state, policy_next_state), step_info - def reset(self, initial_states=None): + def _reset(self, initial_states): """ Reset the state of the agent. """ - if initial_states is None or self._total_episodes_counter == self._n_episodes: - initial_state = None - else: - initial_state = initial_states[self._total_episodes_counter] - - self.agent.episode_start() - - self._state = self._preprocess(self.mdp.reset(initial_state).copy()) + initial_state = self._core_logic.get_initial_state(initial_states) + + state, episode_info = self.mdp.reset(initial_state) + self._policy_state, self._current_theta = self.agent.episode_start(episode_info) + self._state = self._preprocess(state) self.agent.next_action = None + self._episode_steps = 0 + def _end(self, record): + self._state = None + self._policy_state = None + self._current_theta = None + self._episode_steps = None + + if record: + self._record.stop() + + self._core_logic.terminate_run() + def _preprocess(self, state): """ Method to apply state preprocessors. 
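The refactoring in the hunks above changes the agent-facing API in several places at once: fit() now receives a Dataset object instead of a raw list plus **info, draw_action() returns an (action, policy_state) pair, episode_start() takes the episode_info dictionary produced by the environment reset and returns the initial policy state together with the (possibly None) episode parameters theta, and feature extraction moves from the agent into the approximators. The sketch below is not part of the patch; it is a minimal, hypothetical agent (the class name and the omitted update logic are placeholders) written against the new interface to show how these pieces fit together, in the same way the FQI, SARSA and DQN hunks above use them.

from mushroom_rl.core import Agent


class SketchAgent(Agent):
    """Hypothetical agent illustrating the refactored interface."""

    def __init__(self, mdp_info, policy):
        # backend declares which array type the algorithm works with;
        # conversions to/from the environment backend happen in Agent.
        super().__init__(mdp_info, policy, backend='numpy')

    def fit(self, dataset):
        # fit() now receives a Dataset; parse() replaces the removed
        # mushroom_rl.utils.dataset.parse_dataset() helper and yields the
        # state, action, reward, next_state, absorbing and last arrays.
        state, action, reward, next_state, absorbing, last = dataset.parse()
        # ... update the approximator from the parsed arrays; on-policy
        # algorithms use `self.next_action, _ = self.draw_action(next_state)`
        # because draw_action() now returns (action, policy_state) ...

    def episode_start(self, episode_info):
        # episode_info is the dict returned by the environment reset; the
        # base class resets the policy and returns (policy_state, theta),
        # which is what the new Core loop expects.
        return super().episode_start(episode_info)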
diff --git a/mushroom_rl/core/dataset.py b/mushroom_rl/core/dataset.py new file mode 100644 index 000000000..18deb2192 --- /dev/null +++ b/mushroom_rl/core/dataset.py @@ -0,0 +1,419 @@ +import numpy as np + +from collections import defaultdict + +from mushroom_rl.core.serialization import Serializable + +from ._impl import * + + +class Dataset(Serializable): + def __init__(self, mdp_info, agent_info, n_steps=None, n_episodes=None): + assert (n_steps is not None and n_episodes is None) or (n_steps is None and n_episodes is not None) + + if n_steps is not None: + n_samples = n_steps + else: + horizon = mdp_info.horizon + assert np.isfinite(horizon) + + n_samples = horizon * n_episodes + + state_shape = (n_samples,) + mdp_info.observation_space.shape + action_shape = (n_samples,) + mdp_info.action_space.shape + reward_shape = (n_samples,) + + if agent_info.is_stateful: + policy_state_shape = (n_samples,) + agent_info.policy_state_shape + else: + policy_state_shape = None + + state_type = mdp_info.observation_space.data_type + action_type = mdp_info.action_space.data_type + + self._info = defaultdict(list) + self._episode_info = defaultdict(list) + self._theta_list = list() + + if mdp_info.backend == 'numpy': + self._data = NumpyDataset(state_type, state_shape, action_type, action_shape, reward_shape, + policy_state_shape) + elif mdp_info.backend == 'torch': + self._data = TorchDataset(state_type, state_shape, action_type, action_shape, reward_shape, + policy_state_shape) + else: + self._data = ListDataset() + + self._converter = DataConversion.get_converter(mdp_info.backend) + + self._gamma = mdp_info.gamma + + self._add_save_attr( + _info='pickle', + _episode_info='pickle', + _theta_list='pickle', + _data='mushroom', + _converter='primitive', + _gamma='primitive', + ) + + @classmethod + def from_array(cls, states, actions, rewards, next_states, absorbings, lasts, + policy_state=None, policy_next_state=None, info=None, episode_info=None, theta_list=None, + gamma=0.99, backend='numpy'): + """ + Creates a dataset of transitions from the provided arrays. + + Args: + states (np.ndarray): array of states; + actions (np.ndarray): array of actions; + rewards (np.ndarray): array of rewards; + next_states (np.ndarray): array of next_states; + absorbings (np.ndarray): array of absorbing flags; + lasts (np.ndarray): array of last flags; + policy_state (np.ndarray, None): array of policy internal states; + policy_next_state (np.ndarray, None): array of next policy internal states; + info (dict, None): dictionary of step info; + episode_info (dict, None): dictionary of episode info; + theta_list (list, None): list of policy parameters; + gamma (float, 0.99): discount factor; + backend (str, 'numpy'): backend to be used by the dataset. + + Returns: + The dataset of transitions.
+ + """ + assert len(states) == len(actions) == len(rewards) == len(next_states) == len(absorbings) == len(lasts) + + if policy_state is not None: + assert len(states) == len(policy_state) == len(policy_next_state) + + dataset = cls.__new__(cls) + dataset._gamma = gamma + + if info is None: + dataset._info = defaultdict(list) + else: + dataset._info = info.copy() + + if episode_info is None: + dataset._episode_info = defaultdict(list) + else: + dataset._episode_info = episode_info.copy() + + if theta_list is None: + dataset._theta_list = list() + else: + dataset._theta_list = theta_list + + if backend == 'numpy': + dataset._data = NumpyDataset.from_array(states, actions, rewards, next_states, absorbings, lasts) + dataset._converter = NumpyConversion + elif backend == 'torch': + dataset._data = TorchDataset.from_array(states, actions, rewards, next_states, absorbings, lasts) + dataset._converter = TorchConversion + else: + dataset._data = ListDataset.from_array(states, actions, rewards, next_states, absorbings, lasts) + dataset._converter = ListConversion + + dataset._add_save_attr( + _info='pickle', + _episode_info='pickle', + _theta_list='pickle', + _data='mushroom', + _converter='primitive', + _gamma='primitive' + ) + + return dataset + + def append(self, step, info): + self._data.append(*step) + self._append_info(self._info, info) + + def append_episode_info(self, info): + self._append_info(self._episode_info, info) + + def append_theta(self, theta): + self._theta_list.append(theta) + + def get_info(self, field, index=None): + if index is None: + return self._info[field] + else: + return self._info[field][index] + + def clear(self): + self._info = defaultdict(list) + self._data.clear() + + def get_view(self, index): + dataset = self.copy() + + info_slice = defaultdict(list) + for key in self._info.keys(): + info_slice[key] = self._info[key][index] + + dataset._info = info_slice + dataset._episode_info = defaultdict(list) + dataset._data = self._data.get_view(index) + + return dataset + + def item(self): + assert len(self) == 1 + return self[0] + + def __getitem__(self, index): + if isinstance(index, (slice, np.ndarray)): + return self.get_view(index) + elif isinstance(index, int) and index < len(self._data): + return self._data[index] + else: + raise IndexError + + def __add__(self, other): + result = self.copy() + + new_info = self._merge_info(result.info, other.info) + new_episode_info = self._merge_info(result.episode_info, other.episode_info) + + result._info = new_info + result._episode_info = new_episode_info + result.theta_list = result._theta_list + other._theta_list + result._data = self._data + other._data + + return result + + def __len__(self): + return len(self._data) + + @property + def state(self): + return self._data.state + + @property + def action(self): + return self._data.action + + @property + def reward(self): + return self._data.reward + + @property + def next_state(self): + return self._data.next_state + + @property + def absorbing(self): + return self._data.absorbing + + @property + def last(self): + return self._data.last + + @property + def policy_state(self): + return self._data.policy_state + + @property + def policy_next_state(self): + return self._data.policy_next_state + + @property + def info(self): + return self._info + + @property + def episode_info(self): + return self._episode_info + + @property + def theta_list(self): + return self._theta_list + + @property + def episodes_length(self): + """ + Compute the length of each episode in the dataset. 
+ + Args: + dataset (list): the dataset to consider. + + Returns: + A list of length of each episode in the dataset. + + """ + lengths = list() + l = 0 + for sample in self: + l += 1 + if sample[-1] == 1: + lengths.append(l) + l = 0 + + return lengths + + @property + def undiscounted_return(self): + return self.compute_J() + + @property + def discounted_return(self): + return self.compute_J(self._gamma) + + def parse(self, to='numpy'): + """ + Return the dataset as set of arrays. + + to (str, numpy): the backend to be used for the returned arrays. + + Returns: + A tuple containing the arrays that define the dataset, i.e. state, action, next state, absorbing and last + + """ + return self._converter.convert(self.state, self.action, self.reward, self.next_state, + self.absorbing, self.last, to=to) + + def parse_policy_state(self, to='numpy'): + """ + Return the dataset as set of arrays. + + to (str, numpy): the backend to be used for the returned arrays. + + Returns: + A tuple containing the arrays that define the dataset, i.e. state, action, next state, absorbing and last + + """ + return self._converter.convert(self.policy_state, self.policy_next_state, to=to) + + def select_first_episodes(self, n_episodes): + """ + Return the first ``n_episodes`` episodes in the provided dataset. + + Args: + dataset (list): the dataset to consider; + n_episodes (int): the number of episodes to pick from the dataset; + + Returns: + A subset of the dataset containing the first ``n_episodes`` episodes. + + """ + assert n_episodes > 0, 'Number of episodes must be greater than zero.' + + last_idxs = np.argwhere(self.last).ravel() + return self[:last_idxs[n_episodes - 1] + 1] + + def select_random_samples(self, n_samples): + """ + Return the randomly picked desired number of samples in the provided + dataset. + + Args: + dataset (list): the dataset to consider; + n_samples (int): the number of samples to pick from the dataset. + + Returns: + A subset of the dataset containing randomly picked ``n_samples`` + samples. + + """ + assert n_samples >= 0, 'Number of samples must be greater than or equal to zero.' + + if n_samples == 0: + return np.array([[]]) + + idxs = np.random.randint(len(self), size=n_samples) + + return self[idxs] + + def get_init_states(self): + """ + Get the initial states of a dataset + + Args: + dataset (list): the dataset to consider. + + Returns: + An array of initial states of the considered dataset. + + """ + pick = True + x_0 = list() + for step in self: + if pick: + x_0.append(step[0]) + pick = step[-1] + return np.array(x_0) + + def compute_J(self, gamma=1.): + """ + Compute the cumulative discounted reward of each episode in the dataset. + + Args: + dataset (list): the dataset to consider; + gamma (float, 1.): discount factor. + + Returns: + The cumulative discounted reward of each episode in the dataset. + + """ + js = list() + + j = 0. + episode_steps = 0 + for i in range(len(self)): + j += gamma ** episode_steps * self.reward[i] + episode_steps += 1 + if self.last[i] or i == len(self) - 1: + js.append(j) + j = 0. + episode_steps = 0 + + if len(js) == 0: + return [0.] + return js + + def compute_metrics(self, gamma=1.): + """ + Compute the metrics of each complete episode in the dataset. + + Args: + dataset (list): the dataset to consider; + gamma (float, 1.): the discount factor. + + Returns: + The minimum score reached in an episode, + the maximum score reached in an episode, + the mean score reached, + the median score reached, + the number of completed episodes. 
+ + If no episode has been completed, it returns 0 for all values. + + """ + i = 0 + for i in reversed(range(len(self))): + if self.last[i]: + i += 1 + break + + dataset = self[:i] + + if len(dataset) > 0: + J = dataset.compute_J(gamma) + return np.min(J), np.max(J), np.mean(J), np.median(J), len(J) + else: + return 0, 0, 0, 0, 0 + + @staticmethod + def _append_info(info, step_info): + for key, value in step_info.items(): + info[key].append(value) + + @staticmethod + def _merge_info(info, other_info): + new_info = defaultdict(list) + for key in info.keys(): + new_info[key] = info[key] + other_info[key] + return new_info diff --git a/mushroom_rl/core/environment.py b/mushroom_rl/core/environment.py index 7a2fa485b..1f83ce585 100644 --- a/mushroom_rl/core/environment.py +++ b/mushroom_rl/core/environment.py @@ -9,7 +9,7 @@ class MDPInfo(Serializable): This class is used to store the information of the environment. """ - def __init__(self, observation_space, action_space, gamma, horizon, dt=1e-1): + def __init__(self, observation_space, action_space, gamma, horizon, dt=1e-1, backend='numpy'): """ Constructor. @@ -18,7 +18,8 @@ def __init__(self, observation_space, action_space, gamma, horizon, dt=1e-1): action_space ([Box, Discrete]): the action space; gamma (float): the discount factor; horizon (int): the horizon; - dt (float, 1e-1): the control timestep of the environment. + dt (float, 1e-1): the control timestep of the environment; + backend (str, 'numpy'): the type of data library used to generate state and actions. """ self.observation_space = observation_space @@ -26,13 +27,15 @@ def __init__(self, observation_space, action_space, gamma, horizon, dt=1e-1): self.gamma = gamma self.horizon = horizon self.dt = dt + self.backend = backend self._add_save_attr( observation_space='mushroom', action_space='mushroom', gamma='primitive', horizon='primitive', - dt='primitive' + dt='primitive', + backend='primitive' ) @property @@ -145,7 +148,7 @@ def reset(self, state=None): state (np.ndarray, None): the state to set to the current state. Returns: - The current state. + The current state and a dictionary containing the info for the episode. """ raise NotImplementedError diff --git a/mushroom_rl/core/parallel_environment.py b/mushroom_rl/core/parallel_environment.py new file mode 100644 index 000000000..a1bc90d3e --- /dev/null +++ b/mushroom_rl/core/parallel_environment.py @@ -0,0 +1,127 @@ +from multiprocessing import Pipe +from multiprocessing import Process + +from .vectorized_env import VectorizedEnvironment + + +def _parallel_env_worker(remote, env_class, use_generator, args, kwargs): + + if use_generator: + env = env_class.generate(*args, **kwargs) + else: + env = env_class(*args, **kwargs) + + try: + while True: + cmd, data = remote.recv() + if cmd == 'step': + action = data[0] + res = env.step(action) + remote.send(res) + elif cmd == 'reset': + init_states = data[0] + res = env.reset(init_states) + remote.send(res) + elif cmd in 'stop': + env.stop() + elif cmd == 'info': + remote.send(env.info) + elif cmd == 'seed': + env.seed(int(data)) + else: + raise NotImplementedError() + finally: + remote.close() + + +class ParallelEnvironment(VectorizedEnvironment): + """ + Basic interface to run in parallel multiple copies of the same environment. + This class assumes that the environments are homogeneus, i.e. have the same type and MDP info. + + """ + def __init__(self, env_class, *args, n_envs=-1, use_generator=False, **kwargs): + """ + Constructor. 
+
+        Args:
+            env_class (class): The environment class to be used;
+            *args: the positional arguments to give to the constructor or to the generator of the class;
+            n_envs (int, -1): number of parallel copies of environment to construct;
+            use_generator (bool, False): whether to use the generator to build the environment or not;
+            **kwargs: keyword arguments to pass to the constructor or to the generator;
+
+        """
+        assert n_envs > 1
+
+        self._remotes, self._work_remotes = zip(*[Pipe() for _ in range(n_envs)])
+        self._processes = [Process(target=_parallel_env_worker,
+                                   args=(work_remote, env_class, use_generator, args, kwargs))
+                           for work_remote in self._work_remotes]
+
+        for p in self._processes:
+            p.start()
+
+        self._remotes[0].send(('info', None))
+        mdp_info = self._remotes[0].recv()
+
+        super().__init__(mdp_info, n_envs)
+
+    def step_all(self, env_mask, action):
+        for i, remote in enumerate(self._remotes):
+            if env_mask[i]:
+                remote.send(('step', action[i, :]))
+
+        results = []
+        for i, remote in enumerate(self._remotes):
+            if env_mask[i]:
+                results.extend(remote.recv())
+
+        return zip(*results)  # FIXME!!!
+
+    def reset_all(self, env_mask, state=None):
+        for i in range(self._n_envs):
+            state_i = state[i, :] if state is not None else None
+            self._remotes[i].send(('reset', state_i))
+
+        results = []
+        for i, remote in enumerate(self._remotes):
+            if env_mask[i]:
+                results.extend(remote.recv())
+
+        return zip(*results)  # FIXME!!!
+
+    def seed(self, seed):
+        for remote in self._remotes:
+            remote.send(('seed', seed))
+
+        for remote in self._remotes:
+            remote.recv()
+
+    def stop(self):
+        for remote in self._remotes:
+            remote.send(('stop', None))
+
+    def __del__(self):
+        for remote in self._remotes:
+            remote.send(('close', None))
+        for p in self._processes:
+            p.join()
+
+    @staticmethod
+    def generate(env, *args, n_envs=-1, **kwargs):
+        """
+        Method to generate multiple copies of the same environment in parallel, calling the generate method n_envs times.
+
+        Args:
+            env (class): the environment to be constructed;
+            *args: positional arguments to be passed to the constructor;
+            n_envs (int, -1): number of environments to generate;
+            **kwargs: keyword arguments to be passed to the constructor
+
+        Returns:
+            A ParallelEnvironment running n_envs copies of the environment.
+
+        """
+        use_generator = hasattr(env, 'generate')
+        return ParallelEnvironment(env, *args, n_envs=n_envs, use_generator=use_generator, **kwargs)
\ No newline at end of file
diff --git a/mushroom_rl/core/vectorized_env.py b/mushroom_rl/core/vectorized_env.py
new file mode 100644
index 000000000..ad5bf38ab
--- /dev/null
+++ b/mushroom_rl/core/vectorized_env.py
@@ -0,0 +1,29 @@
+import numpy as np
+
+from .environment import Environment
+
+
+class VectorizedEnvironment(Environment):
+    """
+    Interface for Mushroom vectorized environments, i.e. environments running multiple copies of the same environment in parallel.
+ + """ + def __init__(self, mdp_info, n_envs): + self._n_envs = n_envs + super().__init__(mdp_info) + + def reset(self, state=None): + env_mask = np.zeros(dtype=bool) + env_mask[0] = True + return self.reset_all(env_mask, state) + + def step(self, action): + env_mask = np.zeros(dtype=bool) + env_mask[0] = True + return self.step_all(env_mask, action) + + def step_all(self, env_mask, action): + raise NotImplementedError + + def reset_all(self, env_mask, state=None): + raise NotImplementedError diff --git a/mushroom_rl/environments/atari.py b/mushroom_rl/environments/atari.py index c11321ea9..15e778eb7 100644 --- a/mushroom_rl/environments/atari.py +++ b/mushroom_rl/environments/atari.py @@ -109,7 +109,7 @@ def reset(self, state=None): self._current_no_op = np.random.randint(self._max_no_op_actions + 1) - return LazyFrames(list(self._state), self._history_length) + return LazyFrames(list(self._state), self._history_length), {} def step(self, action): action = action[0] @@ -129,13 +129,11 @@ def step(self, action): if self._episode_ends_at_life: absorbing = True self._lives = info['lives'] - self._force_fire = self.env.unwrapped.get_action_meanings()[ - 1] == 'FIRE' + self._force_fire = self.env.unwrapped.get_action_meanings()[1] == 'FIRE' self._state.append(preprocess_frame(obs, self._img_size)) - return LazyFrames(list(self._state), - self._history_length), reward, absorbing, info + return LazyFrames(list(self._state), self._history_length), reward, absorbing, info def render(self, record=False): self.env.render(mode='human') diff --git a/mushroom_rl/environments/car_on_hill.py b/mushroom_rl/environments/car_on_hill.py index 32eacdc02..f03368426 100644 --- a/mushroom_rl/environments/car_on_hill.py +++ b/mushroom_rl/environments/car_on_hill.py @@ -46,7 +46,7 @@ def reset(self, state=None): else: self._state = state - return self._state + return self._state, {} def step(self, action): action = self._discrete_actions[action[0]] @@ -55,12 +55,10 @@ def step(self, action): self._state = new_state[-1, :-1] - if self._state[0] < -self.max_pos or \ - np.abs(self._state[1]) > self.max_velocity: + if self._state[0] < -self.max_pos or np.abs(self._state[1]) > self.max_velocity: reward = -1. absorbing = True - elif self._state[0] > self.max_pos and \ - np.abs(self._state[1]) <= self.max_velocity: + elif self._state[0] > self.max_pos and np.abs(self._state[1]) <= self.max_velocity: reward = 1. 
absorbing = True else: diff --git a/mushroom_rl/environments/cart_pole.py b/mushroom_rl/environments/cart_pole.py index fe15afe37..0dd934516 100644 --- a/mushroom_rl/environments/cart_pole.py +++ b/mushroom_rl/environments/cart_pole.py @@ -63,7 +63,7 @@ def reset(self, state=None): self._state[0] = normalize_angle(self._state[0]) self._last_u = 0 - return self._state + return self._state, {} def step(self, action): if action == 0: @@ -103,8 +103,7 @@ def render(self, record=False): direction = -np.sign(self._last_u) * np.array([1, 0]) value = np.abs(self._last_u) - self._viewer.force_arrow(start, direction, value, - self._max_u, self._l / 5) + self._viewer.force_arrow(start, direction, value, self._max_u, self._l / 5) frame = self._viewer.get_frame() if record else None @@ -120,9 +119,9 @@ def _dynamics(self, state, t, u): omega = state[1] d_theta = omega - d_omega = (self._g * np.sin(theta) - self._alpha * self._m * self._l * .5 * - d_theta ** 2 * np.sin(2 * theta) * .5 - self._alpha * np.cos( - theta) * u) / (2 / 3 * self._l - self._alpha * self._m * - self._l * .5 * np.cos(theta) ** 2) + d_omega = (self._g * np.sin(theta) + - self._alpha * self._m * self._l * .5 * d_theta ** 2 * np.sin(2 * theta) * .5 + - self._alpha * np.cos(theta) * u) / (2 / 3 * self._l - + self._alpha * self._m * self._l * .5 * np.cos(theta) ** 2) return d_theta, d_omega diff --git a/mushroom_rl/environments/dm_control_env.py b/mushroom_rl/environments/dm_control_env.py index 030d4977d..c5844f68d 100644 --- a/mushroom_rl/environments/dm_control_env.py +++ b/mushroom_rl/environments/dm_control_env.py @@ -79,7 +79,7 @@ def reset(self, state=None): else: raise NotImplementedError - return self._state + return self._state, {} def step(self, action): step = self.env.step(action) diff --git a/mushroom_rl/environments/finite_mdp.py b/mushroom_rl/environments/finite_mdp.py index b49995c00..3296e98dc 100644 --- a/mushroom_rl/environments/finite_mdp.py +++ b/mushroom_rl/environments/finite_mdp.py @@ -49,7 +49,7 @@ def reset(self, state=None): else: self._state = state - return self._state + return self._state, {} def step(self, action): p = self.p[self._state[0], action[0], :] diff --git a/mushroom_rl/environments/grid_world.py b/mushroom_rl/environments/grid_world.py index 360aead01..923440c0c 100644 --- a/mushroom_rl/environments/grid_world.py +++ b/mushroom_rl/environments/grid_world.py @@ -32,8 +32,7 @@ def __init__(self, mdp_info, height, width, start, goal): self._goal = goal # Visualization - self._viewer = Viewer(self._width, self._height, 500, - self._height * 500 // self._width) + self._viewer = Viewer(self._width, self._height, 500, self._height * 500 // self._width) super().__init__(mdp_info) @@ -43,7 +42,7 @@ def reset(self, state=None): self._state = state - return self._state + return self._state, {} def step(self, action): state = self.convert_to_grid(self._state, self._width) @@ -56,23 +55,18 @@ def step(self, action): def render(self, record=False): for row in range(1, self._height): for col in range(1, self._width): - self._viewer.line(np.array([col, 0]), - np.array([col, self._height])) - self._viewer.line(np.array([0, row]), - np.array([self._width, row])) + self._viewer.line(np.array([col, 0]), np.array([col, self._height])) + self._viewer.line(np.array([0, row]), np.array([self._width, row])) - goal_center = np.array([.5 + self._goal[1], - self._height - (.5 + self._goal[0])]) + goal_center = np.array([.5 + self._goal[1], self._height - (.5 + self._goal[0])]) self._viewer.square(goal_center, 0, 1, (0, 
255, 0)) start_grid = self.convert_to_grid(self._start, self._width) - start_center = np.array([.5 + start_grid[1], - self._height - (.5 + start_grid[0])]) + start_center = np.array([.5 + start_grid[1], self._height - (.5 + start_grid[0])]) self._viewer.square(start_center, 0, 1, (255, 0, 0)) state_grid = self.convert_to_grid(self._state, self._width) - state_center = np.array([.5 + state_grid[1], - self._height - (.5 + state_grid[0])]) + state_center = np.array([.5 + state_grid[1], self._height - (.5 + state_grid[0])]) self._viewer.circle(state_center, .4, (0, 0, 255)) frame = self._viewer.get_frame() if record else None diff --git a/mushroom_rl/environments/gym_env.py b/mushroom_rl/environments/gym_env.py index 66f3e12bf..aa4015ad1 100644 --- a/mushroom_rl/environments/gym_env.py +++ b/mushroom_rl/environments/gym_env.py @@ -83,12 +83,12 @@ def __init__(self, name, horizon=None, gamma=0.99, wrappers=None, wrappers_args= def reset(self, state=None): if state is None: - return np.atleast_1d(self.env.reset()) + return np.atleast_1d(self.env.reset()), {} else: self.env.reset() self.env.state = state - return np.atleast_1d(state) + return np.atleast_1d(state), {} def step(self, action): action = self._convert_action(action) diff --git a/mushroom_rl/environments/habitat_env.py b/mushroom_rl/environments/habitat_env.py index 685550517..f52c0cb3d 100644 --- a/mushroom_rl/environments/habitat_env.py +++ b/mushroom_rl/environments/habitat_env.py @@ -251,7 +251,7 @@ def __init__(self, wrapper, config_file, base_config_file=None, horizon=None, ga def reset(self, state=None): assert state is None, 'Cannot set Habitat state' obs = self._convert_observation(np.atleast_1d(self.env.reset())) - return obs + return obs, {} def step(self, action): action = self._convert_action(action) diff --git a/mushroom_rl/environments/igibson_env.py b/mushroom_rl/environments/igibson_env.py index a9d7b80e6..3293919fc 100644 --- a/mushroom_rl/environments/igibson_env.py +++ b/mushroom_rl/environments/igibson_env.py @@ -113,7 +113,7 @@ def __init__(self, config_file, horizon=None, gamma=0.99, is_discrete=False, def reset(self, state=None): assert state is None, 'Cannot set iGibson state' - return self._convert_observation(np.atleast_1d(self.env.reset())) + return self._convert_observation(np.atleast_1d(self.env.reset())), {} def step(self, action): action = self._convert_action(action) diff --git a/mushroom_rl/environments/inverted_pendulum.py b/mushroom_rl/environments/inverted_pendulum.py index 3fc7ed1b9..9f7f773f0 100644 --- a/mushroom_rl/environments/inverted_pendulum.py +++ b/mushroom_rl/environments/inverted_pendulum.py @@ -65,11 +65,10 @@ def reset(self, state=None): else: self._state = state self._state[0] = normalize_angle(self._state[0]) - self._state[1] = self._bound(self._state[1], -self._max_omega, - self._max_omega) + self._state[1] = self._bound(self._state[1], -self._max_omega, self._max_omega) self._last_u = 0.0 - return self._state + return self._state, {} def step(self, action): u = self._bound(action[0], -self._max_u, self._max_u) diff --git a/mushroom_rl/environments/lqr.py b/mushroom_rl/environments/lqr.py index 8e317d693..8b7fb4046 100644 --- a/mushroom_rl/environments/lqr.py +++ b/mushroom_rl/environments/lqr.py @@ -114,26 +114,21 @@ def generate(dimensions=None, s_dim=None, a_dim=None, max_pos=np.inf, max_action def reset(self, state=None): if state is None: if self.random_init: - self._state = self._bound( - np.random.uniform(-3, 3, size=self.A.shape[0]), - self.info.observation_space.low, - 
self.info.observation_space.high - ) + rand_state = np.random.uniform(-3, 3, size=self.A.shape[0]) + self._state = self._bound(rand_state, self.info.observation_space.low, self.info.observation_space.high) elif self._initial_state is not None: self._state = self._initial_state else: - init_value = .9 * self._max_pos if np.isfinite( - self._max_pos) else 10 + init_value = .9 * self._max_pos if np.isfinite(self._max_pos) else 10 self._state = init_value * np.ones(self.A.shape[0]) else: self._state = state - return self._state + return self._state, {} def step(self, action): x = self._state - u = self._bound(action, self.info.action_space.low, - self.info.action_space.high) + u = self._bound(action, self.info.action_space.low, self.info.action_space.high) reward = -(x.dot(self.Q).dot(x) + u.dot(self.R).dot(u)) self._state = self.A.dot(x) + self.B.dot(u) diff --git a/mushroom_rl/environments/minigrid_env.py b/mushroom_rl/environments/minigrid_env.py index acbb9084e..daf688f05 100644 --- a/mushroom_rl/environments/minigrid_env.py +++ b/mushroom_rl/environments/minigrid_env.py @@ -82,13 +82,13 @@ def reset(self, state=None): self._state) for _ in range(self._history_length)], maxlen=self._history_length ) - return LazyFrames(list(self._state), self._history_length) + return LazyFrames(list(self._state), self._history_length), {} def step(self, action): obs, reward, absorbing, info = self.env.step(action) - reward *= 1. # Int to float + reward = float(reward) if reward > 0: - reward = 1. # MiniGrid discounts rewards based on timesteps, but we need raw rewards + reward = 1. # MiniGrid discounts rewards based on timesteps, but we need raw rewards self._state.append(preprocess_frame(obs, self._img_size)) diff --git a/mushroom_rl/environments/mujoco.py b/mushroom_rl/environments/mujoco.py index cbc59802f..71884d3c9 100644 --- a/mushroom_rl/environments/mujoco.py +++ b/mushroom_rl/environments/mujoco.py @@ -131,7 +131,7 @@ def reset(self, obs=None): self.setup(obs) self._obs = self._create_observation(self.obs_helper._build_obs(self._data)) - return self._modify_observation(self._obs) + return self._modify_observation(self._obs), {} def step(self, action): cur_obs = self._obs.copy() diff --git a/mushroom_rl/environments/puddle_world.py b/mushroom_rl/environments/puddle_world.py index 140623572..2f12360d1 100644 --- a/mushroom_rl/environments/puddle_world.py +++ b/mushroom_rl/environments/puddle_world.py @@ -68,24 +68,24 @@ def reset(self, state=None): else: self._state = state - return self._state + return self._state, {} def step(self, action): idx = action[0] - self._state += self._actions[idx] + np.random.uniform( - low=-self._noise_step, high=self._noise_step, size=(2,)) - self._state = np.clip(self._state, 0., 1.) + noise = np.random.uniform(low=-self._noise_step, high=self._noise_step, size=(2,)) + next_state = self._state + self._actions[idx] + noise + next_state = np.clip(next_state, 0., 1.) 
- absorbing = np.linalg.norm((self._state - self._goal), - ord=1) < self._goal_threshold + absorbing = np.linalg.norm((next_state - self._goal), ord=1) < self._goal_threshold if not absorbing: - reward = np.random.randn() * self._noise_reward + self._get_reward( - self._state) + reward = np.random.randn() * self._noise_reward + self._get_reward(next_state) else: reward = self._reward_goal - return self._state, reward, absorbing, {} + self._state = next_state + + return next_state, reward, absorbing, {} def render(self, record=False): if self._pixels is None: @@ -95,16 +95,14 @@ def render(self, record=False): for j in range(img_size): x = i / img_size y = j / img_size - pixels[i, img_size - 1 - j] = self._get_reward( - np.array([x, y])) + pixels[i, img_size - 1 - j] = self._get_reward(np.array([x, y])) pixels -= pixels.min() pixels *= 255. / pixels.max() self._pixels = np.floor(255 - pixels) self._viewer.background_image(self._pixels) - self._viewer.circle(self._state, 0.01, - color=(0, 255, 0)) + self._viewer.circle(self._state, 0.01, color=(0, 255, 0)) goal_area = [ [-self._goal_threshold, 0], @@ -112,8 +110,7 @@ def render(self, record=False): [self._goal_threshold, 0], [0, -self._goal_threshold] ] - self._viewer.polygon(self._goal, 0, goal_area, - color=(255, 0, 0), width=1) + self._viewer.polygon(self._goal, 0, goal_area, color=(255, 0, 0), width=1) frame = self._viewer.get_frame() if record else None @@ -128,7 +125,6 @@ def stop(self): def _get_reward(self, state): reward = -1. for cen, wid in zip(self._puddle_center, self._puddle_width): - reward -= 2. * norm.pdf(state[0], cen[0], wid[0]) * norm.pdf( - state[1], cen[1], wid[1]) + reward -= 2. * norm.pdf(state[0], cen[0], wid[0]) * norm.pdf(state[1], cen[1], wid[1]) return reward diff --git a/mushroom_rl/environments/pybullet.py b/mushroom_rl/environments/pybullet.py index 11931ae88..8bb276b8e 100644 --- a/mushroom_rl/environments/pybullet.py +++ b/mushroom_rl/environments/pybullet.py @@ -93,7 +93,7 @@ def reset(self, state=None): self._state = self._indexer.create_sim_state() observation = self._create_observation(self._state) - return observation + return observation, {} def render(self, record=False): frame = self._viewer.display() diff --git a/mushroom_rl/environments/segway.py b/mushroom_rl/environments/segway.py index f63958b4b..bba32c6f6 100644 --- a/mushroom_rl/environments/segway.py +++ b/mushroom_rl/environments/segway.py @@ -66,7 +66,7 @@ def reset(self, state=None): self._last_x = 0 - return self._state + return self._state, {} def step(self, action): u = self._bound(action[0], -self._max_u, self._max_u) @@ -101,12 +101,10 @@ def _dynamics(self, state, t, u): omegaP = d_alpha - dOmegaP = -(h2 * self._l * self._Mp * self._r * np.sin( - alpha) * omegaP**2 - self._g * h1 * self._l * self._Mp * np.sin( - alpha) + (h2 + h1) * u) / (h1 * h3 - h2**2) - dOmegaR = (h3 * self._l * self._Mp * self._r * np.sin( - alpha) * omegaP**2 - self._g * h2 * self._l * self._Mp * np.sin( - alpha) + (h3 + h2) * u) / (h1 * h3 - h2**2) + dOmegaP = -(h2 * self._l * self._Mp * self._r * np.sin( alpha) * omegaP**2 + - self._g * h1 * self._l * self._Mp * np.sin(alpha) + (h2 + h1) * u) / (h1 * h3 - h2**2) + dOmegaR = (h3 * self._l * self._Mp * self._r * np.sin(alpha) * omegaP**2 + - self._g * h2 * self._l * self._Mp * np.sin(alpha) + (h3 + h2) * u) / (h1 * h3 - h2**2) dx = list() dx.append(omegaP) @@ -124,8 +122,7 @@ def render(self, record=False): self._last_x += dx if self._last_x > 2.5 * self._l or self._last_x < -2.5 * self._l: - self._last_x = 
(2.5 * self._l + self._last_x) % ( - 5 * self._l) - 2.5 * self._l + self._last_x = (2.5 * self._l + self._last_x) % (5 * self._l) - 2.5 * self._l start[0] += self._last_x end[0] += -2 * self._l * np.sin(self._state[0]) + self._last_x diff --git a/mushroom_rl/environments/ship_steering.py b/mushroom_rl/environments/ship_steering.py index e2cc7c073..4502f4c8a 100644 --- a/mushroom_rl/environments/ship_steering.py +++ b/mushroom_rl/environments/ship_steering.py @@ -67,7 +67,7 @@ def reset(self, state=None): else: self._state = state - return self._state + return self._state, {} def step(self, action): @@ -83,10 +83,7 @@ def step(self, action): new_state[2] = normalize_angle(state[2] + state[3] * self.info.dt) new_state[3] = state[3] + (r - state[3]) * self.info.dt / self._T - if new_state[0] > self.field_size \ - or new_state[1] > self.field_size \ - or new_state[0] < 0 or new_state[1] < 0: - + if new_state[0] > self.field_size or new_state[1] > self.field_size or new_state[0] < 0 or new_state[1] < 0: new_state[0] = self._bound(new_state[0], 0, self.field_size) new_state[1] = self._bound(new_state[1], 0, self.field_size) diff --git a/mushroom_rl/policy/__init__.py b/mushroom_rl/policy/__init__.py index 74aabee61..ed96bdc54 100644 --- a/mushroom_rl/policy/__init__.py +++ b/mushroom_rl/policy/__init__.py @@ -5,6 +5,7 @@ StateStdGaussianPolicy, StateLogStdGaussianPolicy from .deterministic_policy import DeterministicPolicy from .torch_policy import TorchPolicy, GaussianTorchPolicy, BoltzmannTorchPolicy +from .recurrent_torch_policy import RecurrentGaussianTorchPolicy from .promps import ProMP from .dmp import DMP diff --git a/mushroom_rl/policy/deterministic_policy.py b/mushroom_rl/policy/deterministic_policy.py index 065de0ca1..62097de78 100644 --- a/mushroom_rl/policy/deterministic_policy.py +++ b/mushroom_rl/policy/deterministic_policy.py @@ -10,7 +10,7 @@ class DeterministicPolicy(ParametricPolicy): differentiable, even if the mean value approximator is differentiable. """ - def __init__(self, mu): + def __init__(self, mu, policy_state_shape=None): """ Constructor. @@ -19,11 +19,13 @@ def __init__(self, mu): in each state. """ + super().__init__(policy_state_shape) + self._approximator = mu self._predict_params = dict() - self._add_save_attr(_approximator='mushroom') - self._add_save_attr(_predict_params='pickle') + self._add_save_attr(_approximator='mushroom', + _predict_params='pickle') def get_regressor(self): """ @@ -35,13 +37,13 @@ def get_regressor(self): """ return self._approximator - def __call__(self, state, action): + def __call__(self, state, action, policy_state=None): policy_action = self._approximator.predict(state, **self._predict_params) return 1. if np.array_equal(action, policy_action) else 0. - def draw_action(self, state): - return self._approximator.predict(state, **self._predict_params) + def draw_action(self, state, policy_state=None): + return self._approximator.predict(state, **self._predict_params), None def set_weights(self, weights): self._approximator.set_weights(weights) diff --git a/mushroom_rl/policy/gaussian_policy.py b/mushroom_rl/policy/gaussian_policy.py index c7f8623cf..f92f4d90e 100644 --- a/mushroom_rl/policy/gaussian_policy.py +++ b/mushroom_rl/policy/gaussian_policy.py @@ -9,15 +9,22 @@ class AbstractGaussianPolicy(ParametricPolicy): Abstract class of Gaussian policies. """ - def __call__(self, state, action): + def __init__(self, policy_state_shape=None): + """ + Constructor. 
+ + """ + super().__init__(policy_state_shape) + + def __call__(self, state, action, policy_state=None): mu, sigma = self._compute_multivariate_gaussian(state)[:2] return multivariate_normal.pdf(action, mu, sigma) - def draw_action(self, state): + def draw_action(self, state, policy_state=None): mu, sigma = self._compute_multivariate_gaussian(state)[:2] - return np.random.multivariate_normal(mu, sigma) + return np.random.multivariate_normal(mu, sigma), None class GaussianPolicy(AbstractGaussianPolicy): @@ -29,7 +36,7 @@ class GaussianPolicy(AbstractGaussianPolicy): matrix is fixed. """ - def __init__(self, mu, sigma): + def __init__(self, mu, sigma, policy_state_shape=None): """ Constructor. @@ -41,6 +48,8 @@ def __init__(self, mu, sigma): where n is the action dimensionality. """ + super().__init__(policy_state_shape) + self._approximator = mu self._predict_params = dict() self._inv_sigma = np.linalg.inv(sigma) @@ -65,7 +74,7 @@ def set_sigma(self, sigma): self._sigma = sigma self._inv_sigma = np.linalg.inv(sigma) - def diff_log(self, state, action): + def diff_log(self, state, action, policy_state=None): mu, _, inv_sigma = self._compute_multivariate_gaussian(state) delta = action - mu @@ -97,16 +106,12 @@ def _compute_multivariate_gaussian(self, state): class DiagonalGaussianPolicy(AbstractGaussianPolicy): """ - Gaussian policy with learnable standard deviation. - The Covariance matrix is - constrained to be a diagonal matrix, where the diagonal is the squared - standard deviation vector. - This is a differentiable policy for continuous action spaces. - This policy is similar to the gaussian policy, but the weights includes - also the standard deviation. + Gaussian policy with learnable standard deviation. The Covariance matrix is constrained to be a diagonal matrix, + where the diagonal is the squared standard deviation vector. This is a differentiable policy for continuous action + spaces. This policy is similar to the gaussian policy, but the weights includes also the standard deviation. """ - def __init__(self, mu, std): + def __init__(self, mu, std, policy_state_shape=None): """ Constructor. @@ -117,6 +122,8 @@ def __init__(self, mu, std): this vector must be equal to the action dimensionality. """ + super().__init__(policy_state_shape) + self._approximator = mu self._predict_params = dict() self._std = std @@ -138,7 +145,7 @@ def set_std(self, std): """ self._std = std - def diff_log(self, state, action): + def diff_log(self, state, action, policy_state=None): mu, _, inv_sigma = self._compute_multivariate_gaussian(state) delta = action - mu @@ -189,7 +196,7 @@ class StateStdGaussianPolicy(AbstractGaussianPolicy): deviation depends on the current state. """ - def __init__(self, mu, std, eps=1e-6): + def __init__(self, mu, std, eps=1e-6, policy_state_shape=None): """ Constructor. @@ -205,6 +212,8 @@ def __init__(self, mu, std, eps=1e-6): """ assert(eps > 0) + super().__init__(policy_state_shape) + self._mu_approximator = mu self._std_approximator = std self._predict_params = dict() @@ -217,7 +226,7 @@ def __init__(self, mu, std, eps=1e-6): _eps='primitive' ) - def diff_log(self, state, action): + def diff_log(self, state, action, policy_state=None): mu, sigma, std = self._compute_multivariate_gaussian(state) diag_sigma = np.diag(sigma) @@ -282,7 +291,7 @@ class StateLogStdGaussianPolicy(AbstractGaussianPolicy): regressor represents the logarithm of the standard deviation. 
""" - def __init__(self, mu, log_std): + def __init__(self, mu, log_std, policy_state_shape=None): """ Constructor. @@ -294,6 +303,8 @@ def __init__(self, mu, log_std): regressor must be equal to the action dimensionality. """ + super().__init__(policy_state_shape) + self._mu_approximator = mu self._log_std_approximator = log_std self._predict_params = dict() @@ -304,7 +315,7 @@ def __init__(self, mu, log_std): _predict_params='pickle' ) - def diff_log(self, state, action): + def diff_log(self, state, action, policy_state=None): mu, sigma = self._compute_multivariate_gaussian(state) diag_sigma = np.diag(sigma) @@ -343,8 +354,7 @@ def get_weights(self): @property def weights_size(self): - return self._mu_approximator.weights_size + \ - self._log_std_approximator.weights_size + return self._mu_approximator.weights_size + self._log_std_approximator.weights_size def _compute_multivariate_gaussian(self, state): mu = np.reshape(self._mu_approximator.predict( diff --git a/mushroom_rl/policy/noise_policy.py b/mushroom_rl/policy/noise_policy.py index 317d4eabd..6f2143c93 100644 --- a/mushroom_rl/policy/noise_policy.py +++ b/mushroom_rl/policy/noise_policy.py @@ -42,23 +42,21 @@ def __init__(self, mu, sigma, theta, dt, x0=None): _sigma='numpy', _theta='primitive', _dt='primitive', - _x0='numpy', - _x_prev='numpy' + _x0='numpy' ) - def __call__(self, state, action): + super().__init__(self._approximator.output_shape) + + def __call__(self, state, action=None, policy_state=None): raise NotImplementedError - def draw_action(self, state): + def draw_action(self, state, policy_state): mu = self._approximator.predict(state, **self._predict_params) - x = self._x_prev - self._theta * self._x_prev * self._dt +\ - self._sigma * np.sqrt(self._dt) * np.random.normal( - size=self._approximator.output_shape - ) - self._x_prev = x + x = policy_state - self._theta * policy_state * self._dt +\ + self._sigma * np.sqrt(self._dt) * np.random.normal(size=self._approximator.output_shape) - return mu + x + return mu + x, x def set_weights(self, weights): self._approximator.set_weights(weights) @@ -71,7 +69,7 @@ def weights_size(self): return self._approximator.weights_size def reset(self): - self._x_prev = self._x0 if self._x0 is not None else np.zeros(self._approximator.output_shape) + return self._x0 if self._x0 is not None else np.zeros(self._approximator.output_shape) class ClippedGaussianPolicy(ParametricPolicy): @@ -89,7 +87,7 @@ class ClippedGaussianPolicy(ParametricPolicy): if the value is bigger than the boundaries. Thus, the non-differentiability. """ - def __init__(self, mu, sigma, low, high): + def __init__(self, mu, sigma, low, high, policy_state_shape=None): """ Constructor. @@ -105,6 +103,8 @@ def __init__(self, mu, sigma, low, high): component. 
""" + super().__init__(policy_state_shape) + self._approximator = mu self._predict_params = dict() self._sigma = sigma @@ -119,15 +119,15 @@ def __init__(self, mu, sigma, low, high): _high='numpy' ) - def __call__(self, state, action): + def __call__(self, state, action=None, policy_state=None): raise NotImplementedError - def draw_action(self, state): + def draw_action(self, state, policy_state=None): mu = np.reshape(self._approximator.predict(np.expand_dims(state, axis=0), **self._predict_params), -1) action_raw = np.random.multivariate_normal(mu, self._sigma) - return np.clip(action_raw, self._low, self._high) + return np.clip(action_raw, self._low, self._high), None def set_weights(self, weights): self._approximator.set_weights(weights) diff --git a/mushroom_rl/policy/policy.py b/mushroom_rl/policy/policy.py index ef5cf0a3b..0d55b6c00 100644 --- a/mushroom_rl/policy/policy.py +++ b/mushroom_rl/policy/policy.py @@ -9,32 +9,44 @@ class Policy(Serializable): A policy is used by mushroom agents to interact with the environment. """ - def __call__(self, *args): + def __init__(self, policy_state_shape=None): + """ + Constructor. + + Args: + policy_state_shape (tuple, None): the shape of the internal state of the policy. + + """ + self.policy_state_shape = policy_state_shape + + def __call__(self, state, action, policy_state): """ Compute the probability of taking action in a certain state following the policy. Args: - *args (list): list containing a state or a state and an action. + state: state where you want to evaluate the policy density; + action: action where you want to evaluate the policy density; + policy_state: internal_state where you want to evaluate the policy density. Returns: - The probability of all actions following the policy in the given - state if the list contains only the state, else the probability - of the given action in the given state following the policy. If - the action space is continuous, state and action must be provided + The probability of all actions following the policy in the given state if the list contains only the state, + else the probability of the given action in the given state following the policy. If the action space is + continuous, state and action must be provided """ raise NotImplementedError - def draw_action(self, state): + def draw_action(self, state, policy_state): """ Sample an action in ``state`` using the policy. Args: - state (np.ndarray): the state where the agent is. + state: the state where the agent is; + policy_state: the internal state of the policy. Returns: - The action sampled from the policy. + The action sampled from the policy and optionally the next policy state. """ raise NotImplementedError @@ -44,20 +56,36 @@ def reset(self): Useful when the policy needs a special initialization at the beginning of an episode. + Returns: + The initial policy state (by default None). + """ - pass + return None + + @property + def is_stateful(self): + return self.policy_state_shape is not None class ParametricPolicy(Policy): """ Interface for a generic parametric policy. - A parametric policy is a policy that depends on set of parameters, - called the policy weights. - If the policy is differentiable, the derivative of the probability for a - specified state-action pair can be provided. + A parametric policy is a policy that depends on set of parameters, called the policy weights. + For differentiable policies, the derivative of the probability for a specified state-action pair can be provided. 
+ """ - def diff_log(self, state, action): + def __init__(self, policy_state_shape=None): + """ + Constructor. + + Args: + policy_state_shape (tuple, None): the shape of the internal state of the policy. + + """ + super().__init__(policy_state_shape) + + def diff_log(self, state, action, policy_state): """ Compute the gradient of the logarithm of the probability density function, in the specified state and action pair, i.e.: @@ -67,15 +95,16 @@ def diff_log(self, state, action): Args: - state (np.ndarray): the state where the gradient is computed - action (np.ndarray): the action where the gradient is computed + state: the state where the gradient is computed; + action: the action where the gradient is computed; + policy_state: the internal state of the policy. Returns: The gradient of the logarithm of the pdf w.r.t. the policy weights """ raise RuntimeError('The policy is not differentiable') - def diff(self, state, action): + def diff(self, state, action, policy_state=None): """ Compute the derivative of the probability density function, in the specified state and action pair. Normally it is computed w.r.t. the @@ -87,13 +116,14 @@ def diff(self, state, action): Args: - state (np.ndarray): the state where the derivative is computed - action (np.ndarray): the action where the derivative is computed + state: the state where the derivative is computed; + action: the action where the derivative is computed; + policy_state: the internal state of the policy. Returns: The derivative w.r.t. the policy weights """ - return self(state, action) * self.diff_log(state, action) + return self(state, action, policy_state) * self.diff_log(state, action, policy_state) def set_weights(self, weights): """ diff --git a/mushroom_rl/policy/promps.py b/mushroom_rl/policy/promps.py index e34d2137f..d074a3e83 100644 --- a/mushroom_rl/policy/promps.py +++ b/mushroom_rl/policy/promps.py @@ -30,21 +30,20 @@ def __init__(self, mu, phi, duration, sigma=None, periodic=False): """ assert sigma is None or (len(sigma.shape) == 2 and sigma.shape[0] == sigma.shape[1]) + super().__init__(policy_state_shape=(1,)) + self._approximator = mu self._phi = phi self._duration = duration self._sigma = sigma self._periodic = periodic - self._step = 0 - self._add_save_attr( _approximator='mushroom', _phi='mushroom', _duration='primitive', _sigma='numpy', - _periodic='primitive', - _step='primitive' + _periodic='primitive' ) def __call__(self, state, action): @@ -56,19 +55,19 @@ def __call__(self, state, action): else: return multivariate_normal.pdf(action, mu, self._sigma) - def draw_action(self, state): + def draw_action(self, state, policy_state): z = self._compute_phase(state) - self.update_time(state) - mu = self._approximator(self._phi(z)) + next_policy_state = self.update_time(state, policy_state) + if self._sigma is None: - return mu + return mu, next_policy_state else: - return np.random.multivariate_normal(mu, self._sigma) + return np.random.multivariate_normal(mu, self._sigma), next_policy_state - def update_time(self, state): + def update_time(self, state, policy_state): """ Method that updates the time counter. Can be overridden to introduce complex state-dependant behaviors. @@ -76,12 +75,14 @@ def update_time(self, state): state (np.ndarray): The current state of the system. 
""" - self._step += 1 + policy_state += 1 + + if not self._periodic and policy_state >= self._duration: + policy_state = self._duration - if not self._periodic and self._step >= self._duration: - self._step = self._duration + return policy_state - def _compute_phase(self, state): + def _compute_phase(self, state, policy_state): """ Method that updates the state variable. It can be overridden to implement state dependent phase. @@ -92,7 +93,7 @@ def _compute_phase(self, state): The current value of the phase variable """ - return self._step / self._duration + return policy_state / self._duration def set_weights(self, weights): self._approximator.set_weights(weights) @@ -113,4 +114,4 @@ def set_duration(self, duration): self._duration = duration - 1 def reset(self): - self._step = 0 + return 0 diff --git a/mushroom_rl/policy/recurrent_torch_policy.py b/mushroom_rl/policy/recurrent_torch_policy.py new file mode 100644 index 000000000..0849d3cc3 --- /dev/null +++ b/mushroom_rl/policy/recurrent_torch_policy.py @@ -0,0 +1,63 @@ +import torch +import numpy as np + +from mushroom_rl.policy import GaussianTorchPolicy +from mushroom_rl.utils.torch import to_float_tensor +from mushroom_rl.utils.parameters import to_parameter + + +class RecurrentGaussianTorchPolicy(GaussianTorchPolicy): + def __init__(self, policy_state_shape, log_std_min=-20, log_std_max=2, **kwargs): + + super().__init__(policy_state_shape=policy_state_shape, **kwargs) + + self._log_std_min = to_parameter(log_std_min) + self._log_std_max = to_parameter(log_std_max) + + def reset(self): + return torch.zeros(self.policy_state_shape) + + def draw_action(self, state, policy_state): + with torch.no_grad(): + state = to_float_tensor(state) + policy_state = torch.as_tensor(policy_state) + a, policy_state = self.draw_action_t(state, policy_state) + return torch.squeeze(a, dim=0).detach().cpu().numpy(), policy_state + + def draw_action_t(self, state, policy_state): + lengths = torch.tensor([1]) + state = torch.atleast_2d(state).view(1, 1, -1) + policy_state = torch.atleast_2d(policy_state) + + dist, policy_state = self.distribution_and_policy_state_t(state, policy_state, lengths) + action = dist.sample().detach() + + return action, policy_state + + def log_prob_t(self, state, action, policy_state, lengths): + return self.distribution_t(state, policy_state, lengths).log_prob(action.squeeze())[:, None] + + def entropy_t(self, state=None): + return self._action_dim / 2 * np.log(2 * np.pi * np.e) + torch.sum(self._log_sigma) + + def distribution(self, state, policy_state, lengths): + s = to_float_tensor(state, self._use_cuda) + + return self.distribution_t(s, policy_state, lengths) + + def distribution_t(self, state, policy_state, lengths): + mu, sigma, _ = self.get_mean_and_covariance_and_policy_state(state, policy_state, lengths) + return torch.distributions.MultivariateNormal(loc=mu, covariance_matrix=sigma) + + def distribution_and_policy_state_t(self, state, policy_state, lengths): + mu, sigma, policy_state = self.get_mean_and_covariance_and_policy_state(state, policy_state, lengths) + return torch.distributions.MultivariateNormal(loc=mu, covariance_matrix=sigma), policy_state + + def get_mean_and_covariance_and_policy_state(self, state, policy_state, lengths): + mu, next_hidden_state = self._mu(state, policy_state, lengths, **self._predict_params, output_tensor=True) + + # Bound the log_std + log_sigma = torch.clamp(self._log_sigma, self._log_std_min(), self._log_std_max()) + + covariance = torch.diag(torch.exp(2 * log_sigma)) + return mu, 
covariance, next_hidden_state diff --git a/mushroom_rl/policy/td_policy.py b/mushroom_rl/policy/td_policy.py index 8f79d27c4..fbadca2a0 100644 --- a/mushroom_rl/policy/td_policy.py +++ b/mushroom_rl/policy/td_policy.py @@ -7,11 +7,13 @@ class TDPolicy(Policy): - def __init__(self): + def __init__(self, policy_state_shape=None): """ Constructor. """ + super().__init__(policy_state_shape) + self._approximator = None self._predict_params = dict() @@ -40,7 +42,7 @@ class EpsGreedy(TDPolicy): Epsilon greedy policy. """ - def __init__(self, epsilon): + def __init__(self, epsilon, policy_state_shape=None): """ Constructor. @@ -50,7 +52,7 @@ def __init__(self, epsilon): step. """ - super().__init__() + super().__init__(policy_state_shape) self._epsilon = to_parameter(epsilon) @@ -75,7 +77,7 @@ def __call__(self, *args): return probs - def draw_action(self, state): + def draw_action(self, state, policy_state=None): if not np.random.uniform() < self._epsilon(state): q = self._approximator.predict(state, **self._predict_params) max_a = np.argwhere(q == np.max(q)).ravel() @@ -83,9 +85,9 @@ def draw_action(self, state): if len(max_a) > 1: max_a = np.array([np.random.choice(max_a)]) - return max_a + return max_a, None - return np.array([np.random.choice(self._approximator.n_actions)]) + return np.array([np.random.choice(self._approximator.n_actions)]), None def set_epsilon(self, epsilon): """ @@ -116,7 +118,7 @@ class Boltzmann(TDPolicy): Boltzmann softmax policy. """ - def __init__(self, beta): + def __init__(self, beta, policy_state_shape=None): """ Constructor. @@ -127,7 +129,7 @@ def __init__(self, beta): more and more greedy. """ - super().__init__() + super().__init__(policy_state_shape) self._beta = to_parameter(beta) self._add_save_attr(_beta='mushroom') @@ -145,9 +147,8 @@ def __call__(self, *args): else: return qs / np.sum(qs) - def draw_action(self, state): - return np.array([np.random.choice(self._approximator.n_actions, - p=self(state))]) + def draw_action(self, state, policy_state=None): + return np.array([np.random.choice(self._approximator.n_actions, p=self(state))]), None def set_beta(self, beta): """ @@ -213,7 +214,7 @@ def f(beta): except ValueError: return 0. - def __init__(self, omega, beta_min=-10., beta_max=10.): + def __init__(self, omega, beta_min=-10., beta_max=10., policy_state_shape=None): """ Constructor. @@ -228,7 +229,7 @@ def __init__(self, omega, beta_min=-10., beta_max=10.): """ beta_mellow = self.MellowmaxParameter(self, omega, beta_min, beta_max) - super().__init__(beta_mellow) + super().__init__(beta_mellow, policy_state_shape) def set_beta(self, beta): raise RuntimeError('Cannot change the beta parameter of Mellowmax policy') diff --git a/mushroom_rl/policy/torch_policy.py b/mushroom_rl/policy/torch_policy.py index 3a1a6f32b..e9bd9a86a 100644 --- a/mushroom_rl/policy/torch_policy.py +++ b/mushroom_rl/policy/torch_policy.py @@ -20,7 +20,7 @@ class TorchPolicy(Policy): required. """ - def __init__(self, use_cuda): + def __init__(self, use_cuda, policy_state_shape=None): """ Constructor. @@ -28,22 +28,24 @@ def __init__(self, use_cuda): use_cuda (bool): whether to use cuda or not. 
""" + super().__init__(policy_state_shape) + self._use_cuda = use_cuda self._add_save_attr(_use_cuda='primitive') - def __call__(self, state, action): + def __call__(self, state, action, policy_state=None): s = to_float_tensor(np.atleast_2d(state), self._use_cuda) a = to_float_tensor(np.atleast_2d(action), self._use_cuda) return np.exp(self.log_prob_t(s, a).item()) - def draw_action(self, state): + def draw_action(self, state, policy_state=None): with torch.no_grad(): s = to_float_tensor(np.atleast_2d(state), self._use_cuda) a = self.draw_action_t(s) - return torch.squeeze(a, dim=0).detach().cpu().numpy() + return torch.squeeze(a, dim=0).detach().cpu().numpy(), None def distribution(self, state): """ @@ -167,9 +169,6 @@ def parameters(self): """ raise NotImplementedError - def reset(self): - pass - @property def use_cuda(self): """ @@ -185,7 +184,7 @@ class GaussianTorchPolicy(TorchPolicy): """ def __init__(self, network, input_shape, output_shape, std_0=1., - use_cuda=False, **params): + use_cuda=False, policy_state_shape=None, **params): """ Constructor. @@ -198,7 +197,7 @@ def __init__(self, network, input_shape, output_shape, std_0=1., params (dict): parameters used by the network constructor. """ - super().__init__(use_cuda) + super().__init__(use_cuda, policy_state_shape) self._action_dim = output_shape[0] @@ -260,7 +259,7 @@ class BoltzmannTorchPolicy(TorchPolicy): Torch policy implementing a Boltzmann policy. """ - def __init__(self, network, input_shape, output_shape, beta, use_cuda=False, **params): + def __init__(self, network, input_shape, output_shape, beta, use_cuda=False, policy_state_shape=None, **params): """ Constructor. @@ -276,13 +275,13 @@ def __init__(self, network, input_shape, output_shape, beta, use_cuda=False, **p params (dict): parameters used by the network constructor. """ - super().__init__(use_cuda) + super().__init__(use_cuda, policy_state_shape) self._action_dim = output_shape[0] self._predict_params = dict() self._logits = Regressor(TorchApproximator, input_shape, output_shape, - network=network, use_cuda=use_cuda, **params) + network=network, use_cuda=use_cuda, **params) self._beta = to_parameter(beta) self._add_save_attr( diff --git a/mushroom_rl/utils/dataset.py b/mushroom_rl/utils/dataset.py deleted file mode 100644 index 5c8acfb6b..000000000 --- a/mushroom_rl/utils/dataset.py +++ /dev/null @@ -1,235 +0,0 @@ -import numpy as np - -from mushroom_rl.utils.frames import LazyFrames - - -def parse_dataset(dataset, features=None): - """ - Split the dataset in its different components and return them. - - Args: - dataset (list): the dataset to parse; - features (object, None): features to apply to the states. - - Returns: - The np.ndarray of state, action, reward, next_state, absorbing flag and - last step flag. Features are applied to ``state`` and ``next_state``, - when provided. - - """ - assert len(dataset) > 0 - - shape = dataset[0][0].shape if features is None else (features.size,) - - state = np.ones((len(dataset),) + shape) - action = np.ones((len(dataset),) + dataset[0][1].shape) - reward = np.ones(len(dataset)) - next_state = np.ones((len(dataset),) + shape) - absorbing = np.ones(len(dataset)) - last = np.ones(len(dataset)) - - if features is not None: - for i in range(len(dataset)): - state[i, ...] = features(dataset[i][0]) - action[i, ...] = dataset[i][1] - reward[i] = dataset[i][2] - next_state[i, ...] = features(dataset[i][3]) - absorbing[i] = dataset[i][4] - last[i] = dataset[i][5] - else: - for i in range(len(dataset)): - state[i, ...] 
= dataset[i][0] - action[i, ...] = dataset[i][1] - reward[i] = dataset[i][2] - next_state[i, ...] = dataset[i][3] - absorbing[i] = dataset[i][4] - last[i] = dataset[i][5] - - return np.array(state), np.array(action), np.array(reward), np.array( - next_state), np.array(absorbing), np.array(last) - - -def arrays_as_dataset(states, actions, rewards, next_states, absorbings, lasts): - """ - Creates a dataset of transitions from the provided arrays. - - Args: - states (np.ndarray): array of states; - actions (np.ndarray): array of actions; - rewards (np.ndarray): array of rewards; - next_states (np.ndarray): array of next_states; - absorbings (np.ndarray): array of absorbing flags; - lasts (np.ndarray): array of last flags. - - Returns: - The list of transitions. - - """ - assert (len(states) == len(actions) == len(rewards) - == len(next_states) == len(absorbings) == len(lasts)) - - dataset = list() - for s, a, r, ss, ab, last in zip(states, actions, rewards, next_states, - absorbings.astype(bool), lasts.astype(bool) - ): - dataset.append((s, a, r.item(0), ss, ab.item(0), last.item(0))) - - return dataset - - -def compute_episodes_length(dataset): - """ - Compute the length of each episode in the dataset. - - Args: - dataset (list): the dataset to consider. - - Returns: - A list of length of each episode in the dataset. - - """ - lengths = list() - l = 0 - for sample in dataset: - l += 1 - if sample[-1] == 1: - lengths.append(l) - l = 0 - - return lengths - - -def select_first_episodes(dataset, n_episodes, parse=False): - """ - Return the first ``n_episodes`` episodes in the provided dataset. - - Args: - dataset (list): the dataset to consider; - n_episodes (int): the number of episodes to pick from the dataset; - parse (bool, False): whether to parse the dataset to return. - - Returns: - A subset of the dataset containing the first ``n_episodes`` episodes. - - """ - assert n_episodes >= 0, 'Number of episodes must be greater than or equal' \ - 'to zero.' - if n_episodes == 0: - return np.array([[]]) - - dataset = np.array(dataset, dtype=object) - last_idxs = np.argwhere(dataset[:, -1] == 1).ravel() - sub_dataset = dataset[:last_idxs[n_episodes - 1] + 1, :] - - return sub_dataset if not parse else parse_dataset(sub_dataset) - - -def select_random_samples(dataset, n_samples, parse=False): - """ - Return the randomly picked desired number of samples in the provided - dataset. - - Args: - dataset (list): the dataset to consider; - n_samples (int): the number of samples to pick from the dataset; - parse (bool, False): whether to parse the dataset to return. - - Returns: - A subset of the dataset containing randomly picked ``n_samples`` - samples. - - """ - assert n_samples >= 0, 'Number of samples must be greater than or equal' \ - 'to zero.' - if n_samples == 0: - return np.array([[]]) - - dataset = np.array(dataset, dtype=object) - idxs = np.random.randint(dataset.shape[0], size=n_samples) - sub_dataset = dataset[idxs, ...] - - return sub_dataset if not parse else parse_dataset(sub_dataset) - - -def get_init_states(dataset): - """ - Get the initial states of a dataset - - Args: - dataset (list): the dataset to consider. - - Returns: - An array of initial states of the considered dataset. 
- - """ - pick = True - x_0 = list() - for d in dataset: - if pick: - if isinstance(d[0], LazyFrames): - x_0.append(np.array(d[0])) - else: - x_0.append(d[0]) - pick = d[-1] - return np.array(x_0) - - -def compute_J(dataset, gamma=1.): - """ - Compute the cumulative discounted reward of each episode in the dataset. - - Args: - dataset (list): the dataset to consider; - gamma (float, 1.): discount factor. - - Returns: - The cumulative discounted reward of each episode in the dataset. - - """ - js = list() - - j = 0. - episode_steps = 0 - for i in range(len(dataset)): - j += gamma ** episode_steps * dataset[i][2] - episode_steps += 1 - if dataset[i][-1] or i == len(dataset) - 1: - js.append(j) - j = 0. - episode_steps = 0 - - if len(js) == 0: - return [0.] - return js - - -def compute_metrics(dataset, gamma=1.): - """ - Compute the metrics of each complete episode in the dataset. - - Args: - dataset (list): the dataset to consider; - gamma (float, 1.): the discount factor. - - Returns: - The minimum score reached in an episode, - the maximum score reached in an episode, - the mean score reached, - the median score reached, - the number of completed episodes. - - If no episode has been completed, it returns 0 for all values. - - """ - for i in reversed(range(len(dataset))): - if dataset[i][-1]: - i += 1 - break - - dataset = dataset[:i] - - if len(dataset) > 0: - J = compute_J(dataset, gamma) - return np.min(J), np.max(J), np.mean(J), np.median(J), len(J) - else: - return 0, 0, 0, 0, 0 diff --git a/mushroom_rl/utils/spaces.py b/mushroom_rl/utils/spaces.py index c5e252cb7..990ba0fa8 100644 --- a/mushroom_rl/utils/spaces.py +++ b/mushroom_rl/utils/spaces.py @@ -9,7 +9,7 @@ class Box(Serializable): spaces. It is similar to the ``Box`` class in ``gym.spaces.box``. """ - def __init__(self, low, high, shape=None): + def __init__(self, low, high, shape=None, data_type=float): """ Constructor. @@ -26,6 +26,7 @@ def __init__(self, low, high, shape=None): of the i-th dimension; shape (np.ndarray, None): the dimension of the space. Must match the shape of ``low`` and ``high``, if they are np.ndarray. + data_type (class, float): the data type to be used. 
""" if shape is None: @@ -42,9 +43,12 @@ def __init__(self, low, high, shape=None): assert self._low.shape == self._high.shape + self._data_type = data_type + self._add_save_attr( _low='numpy', - _high='numpy' + _high='numpy', + _data_type='primitive' ) @property @@ -74,6 +78,10 @@ def shape(self): """ return self._shape + @property + def data_type(self): + return self._data_type + def _post_load(self): self._shape = self._low.shape @@ -117,3 +125,7 @@ def shape(self): """ return 1, + + @property + def data_type(self): + return int diff --git a/mushroom_rl/utils/value_functions.py b/mushroom_rl/utils/value_functions.py index 0b5e8e7bb..496aa558d 100644 --- a/mushroom_rl/utils/value_functions.py +++ b/mushroom_rl/utils/value_functions.py @@ -1,4 +1,4 @@ -import numpy as np +import torch def compute_advantage_montecarlo(V, s, ss, r, absorbing, gamma): @@ -9,31 +9,32 @@ def compute_advantage_montecarlo(V, s, ss, r, absorbing, gamma): Args: V (Regressor): the current value function regressor; - s (numpy.ndarray): the set of states in which we want + s (torch.tensor): the set of states in which we want to evaluate the advantage; - ss (numpy.ndarray): the set of next states in which we want + ss (torch.tensor): the set of next states in which we want to evaluate the advantage; - r (numpy.ndarray): the reward obtained in each transition + r (torch.tensor): the reward obtained in each transition from state s to state ss; - absorbing (numpy.ndarray): an array of boolean flags indicating + absorbing (torch.tensor): an array of boolean flags indicating if the reached state is absorbing; gamma (float): the discount factor of the considered problem. Returns: The new estimate for the value function of the next state and the advantage function. """ - r = r.squeeze() - q = np.zeros(len(r)) - v = V(s).squeeze() + with torch.no_grad(): + r = r.squeeze() + q = torch.zeros(len(r)) + v = V(s, output_tensor=True).squeeze() - q_next = V(ss[-1]).squeeze().item() - for rev_k in range(len(r)): - k = len(r) - rev_k - 1 - q_next = r[k] + gamma * q_next * (1. - absorbing[k]) - q[k] = q_next + q_next = V(ss[-1]).squeeze().item() + for rev_k in range(len(r)): + k = len(r) - rev_k - 1 + q_next = r[k] + gamma * q_next * (1 - absorbing[k].int()) + q[k] = q_next - adv = q - v - return q[:, np.newaxis], adv[:, np.newaxis] + adv = q - v + return q[:, None], adv[:, None] def compute_advantage(V, s, ss, r, absorbing, gamma): @@ -43,25 +44,26 @@ def compute_advantage(V, s, ss, r, absorbing, gamma): Args: V (Regressor): the current value function regressor; - s (numpy.ndarray): the set of states in which we want + s (torch.tensor): the set of states in which we want to evaluate the advantage; - ss (numpy.ndarray): the set of next states in which we want + ss (torch.tensor): the set of next states in which we want to evaluate the advantage; - r (numpy.ndarray): the reward obtained in each transition + r (torch.tensor): the reward obtained in each transition from state s to state ss; - absorbing (numpy.ndarray): an array of boolean flags indicating + absorbing (torch.tensor): an array of boolean flags indicating if the reached state is absorbing; gamma (float): the discount factor of the considered problem. Returns: The new estimate for the value function of the next state and the advantage function. 
""" - v = V(s).squeeze() - v_next = V(ss).squeeze() * (1 - absorbing) + with torch.no_grad(): + v = V(s, output_tensor=True).squeeze() + v_next = V(ss).squeeze() * (1 - absorbing.int()) - q = r + gamma * v_next - adv = q - v - return q[:, np.newaxis], adv[:, np.newaxis] + q = r + gamma * v_next + adv = q - v + return q[:, None], adv[:, None] def compute_gae(V, s, ss, r, absorbing, last, gamma, lam): @@ -75,15 +77,15 @@ def compute_gae(V, s, ss, r, absorbing, last, gamma, lam): Args: V (Regressor): the current value function regressor; - s (numpy.ndarray): the set of states in which we want + s (torch.tensor): the set of states in which we want to evaluate the advantage; - ss (numpy.ndarray): the set of next states in which we want + ss (torch.tensor): the set of next states in which we want to evaluate the advantage; - r (numpy.ndarray): the reward obtained in each transition + r (torch.tensor): the reward obtained in each transition from state s to state ss; - absorbing (numpy.ndarray): an array of boolean flags indicating + absorbing (torch.tensor): an array of boolean flags indicating if the reached state is absorbing; - last (numpy.ndarray): an array of boolean flags indicating + last (torch.tensor): an array of boolean flags indicating if the reached state is the last of the trajectory; gamma (float): the discount factor of the considered problem; lam (float): the value for the lamba coefficient used by GEA @@ -92,15 +94,16 @@ def compute_gae(V, s, ss, r, absorbing, last, gamma, lam): The new estimate for the value function of the next state and the estimated generalized advantage. """ - v = V(s) - v_next = V(ss) - gen_adv = np.empty_like(v) - for rev_k in range(len(v)): - k = len(v) - rev_k - 1 - if last[k] or rev_k == 0: - gen_adv[k] = r[k] - v[k] - if not absorbing[k]: - gen_adv[k] += gamma * v_next[k] - else: - gen_adv[k] = r[k] + gamma * v_next[k] - v[k] + gamma * lam * gen_adv[k + 1] - return gen_adv + v, gen_adv \ No newline at end of file + with torch.no_grad(): + v = V(s, output_tensor=True) + v_next = V(ss, output_tensor=True) + gen_adv = torch.empty_like(v) + for rev_k in range(len(v)): + k = len(v) - rev_k - 1 + if last[k] or rev_k == 0: + gen_adv[k] = r[k] - v[k] + if not absorbing[k]: + gen_adv[k] += gamma * v_next[k] + else: + gen_adv[k] = r[k] + gamma * v_next[k] - v[k] + gamma * lam * gen_adv[k + 1] + return gen_adv + v, gen_adv \ No newline at end of file diff --git a/tests/algorithms/helper/utils.py b/tests/algorithms/helper/utils.py index 2fed7444a..7f2b9634c 100644 --- a/tests/algorithms/helper/utils.py +++ b/tests/algorithms/helper/utils.py @@ -4,7 +4,7 @@ import itertools import mushroom_rl -from mushroom_rl.core import MDPInfo +from mushroom_rl.core import MDPInfo, AgentInfo from mushroom_rl.policy.td_policy import TDPolicy from mushroom_rl.policy.torch_policy import TorchPolicy from mushroom_rl.policy.policy import ParametricPolicy @@ -60,6 +60,8 @@ def assert_eq(cls, this, that): assert cls.eq_chain(this, that) elif cls._check_type(this, that, MDPInfo): assert cls.eq_mdp_info(this, that) + elif cls._check_type(this, that, AgentInfo): + assert cls.eq_agent_info(this, that) elif cls._check_type(this, that, ReplayMemory): assert cls.eq_replay_memory(this, that) elif cls._check_type(this, that, PrioritizedReplayMemory): @@ -170,6 +172,18 @@ def eq_mdp_info(cls, this, that): res &= this.horizon == that.horizon return res + @classmethod + def eq_agent_info(cls, this, that): + """ + Compare two mdp_info objects for equality + """ + res = this.is_episodic == 
that.is_episodic + res &= this.is_stateful == that.is_stateful + res &= this.policy_state_shape == that.policy_state_shape + res &= this.backend == that.backend + + return res + @classmethod def eq_ornstein_uhlenbeck_policy(cls, this, that): """ diff --git a/tests/algorithms/test_dpg.py b/tests/algorithms/test_dpg.py index 203427d98..ef8de329e 100644 --- a/tests/algorithms/test_dpg.py +++ b/tests/algorithms/test_dpg.py @@ -35,16 +35,12 @@ def learn_copdac_q(): input_shape = (phi.size,) - mu = Regressor(LinearApproximator, input_shape=input_shape, - output_shape=mdp.info.action_space.shape) + mu = Regressor(LinearApproximator, input_shape=input_shape, output_shape=mdp.info.action_space.shape, phi=phi) sigma = 1e-1 * np.eye(1) policy = GaussianPolicy(mu, sigma) - agent = COPDAC_Q(mdp.info, policy, mu, - alpha_theta, alpha_omega, alpha_v, - value_function_features=phi, - policy_features=phi) + agent = COPDAC_Q(mdp.info, policy, mu, alpha_theta, alpha_omega, alpha_v, value_function_features=phi) # Train core = Core(agent, mdp) diff --git a/tests/algorithms/test_fqi.py b/tests/algorithms/test_fqi.py index a92fc71a5..253e4b127 100644 --- a/tests/algorithms/test_fqi.py +++ b/tests/algorithms/test_fqi.py @@ -9,7 +9,6 @@ from mushroom_rl.core import Core from mushroom_rl.environments import * from mushroom_rl.policy import EpsGreedy -from mushroom_rl.utils.dataset import compute_J from mushroom_rl.utils.parameters import Parameter @@ -31,8 +30,7 @@ def learn(alg, alg_params): approximator = ExtraTreesRegressor # Agent - agent = alg(mdp.info, pi, approximator, - approximator_params=approximator_params, **alg_params) + agent = alg(mdp.info, pi, approximator, approximator_params=approximator_params, **alg_params) # Algorithm core = Core(agent, mdp) @@ -44,7 +42,7 @@ def learn(alg, alg_params): agent.policy.set_epsilon(test_epsilon) dataset = core.evaluate(n_episodes=2) - return agent, np.mean(compute_J(dataset, mdp.info.gamma)) + return agent, np.mean(dataset.compute_J(mdp.info.gamma)) def test_fqi(): diff --git a/tests/algorithms/test_lspi.py b/tests/algorithms/test_lspi.py index 93a37ca98..f47399133 100644 --- a/tests/algorithms/test_lspi.py +++ b/tests/algorithms/test_lspi.py @@ -30,9 +30,9 @@ def learn_lspi(): fit_params = dict() approximator_params = dict(input_shape=(features.size,), output_shape=(mdp.info.action_space.n,), - n_actions=mdp.info.action_space.n) - agent = LSPI(mdp.info, pi, approximator_params=approximator_params, - fit_params=fit_params, features=features) + n_actions=mdp.info.action_space.n, + phi=features) + agent = LSPI(mdp.info, pi, approximator_params=approximator_params, fit_params=fit_params) # Algorithm core = Core(agent, mdp) diff --git a/tests/algorithms/test_stochastic_ac.py b/tests/algorithms/test_stochastic_ac.py index 2a1266c44..503be0423 100644 --- a/tests/algorithms/test_stochastic_ac.py +++ b/tests/algorithms/test_stochastic_ac.py @@ -40,11 +40,9 @@ def learn(alg): input_shape = (phi.size,) - mu = Regressor(LinearApproximator, input_shape=input_shape, - output_shape=mdp.info.action_space.shape) + mu = Regressor(LinearApproximator, input_shape=input_shape, output_shape=mdp.info.action_space.shape, phi=phi) - std = Regressor(LinearApproximator, input_shape=input_shape, - output_shape=mdp.info.action_space.shape) + std = Regressor(LinearApproximator, input_shape=input_shape, output_shape=mdp.info.action_space.shape, phi=phi) std_0 = np.sqrt(1.) 
std.set_weights(np.log(std_0) / n_tilings * np.ones(std.weights_size)) @@ -52,12 +50,11 @@ def learn(alg): policy = StateLogStdGaussianPolicy(mu, std) if alg is StochasticAC: - agent = alg(mdp.info, policy, alpha_theta, alpha_v, lambda_par=.5, - value_function_features=psi, policy_features=phi) + agent = alg(mdp.info, policy, alpha_theta, alpha_v, lambda_par=.5, value_function_features=psi) elif alg is StochasticAC_AVG: - agent = alg(mdp.info, policy, alpha_theta, alpha_v, alpha_r, - lambda_par=.5, value_function_features=psi, - policy_features=phi) + agent = alg(mdp.info, policy, alpha_theta, alpha_v, alpha_r, lambda_par=.5, value_function_features=psi) + else: + assert False core = Core(agent, mdp) diff --git a/tests/algorithms/test_td.py b/tests/algorithms/test_td.py index 9acbbaee2..fb043e733 100644 --- a/tests/algorithms/test_td.py +++ b/tests/algorithms/test_td.py @@ -18,6 +18,14 @@ from mushroom_rl.utils.parameters import Parameter +def assert_properly_loaded(agent_save, agent_load): + for att, method in vars(agent_save).items(): + if att != 'next_action': + save_attr = getattr(agent_save, att) + load_attr = getattr(agent_load, att) + tu.assert_eq(save_attr, load_attr) + + class Network(nn.Module): def __init__(self, input_shape, output_shape, **kwargs): super().__init__() @@ -80,10 +88,7 @@ def test_q_learning_save(tmpdir): agent_save.save(agent_path) agent_load = Agent.load(agent_path) - for att, method in vars(agent_save).items(): - save_attr = getattr(agent_save, att) - load_attr = getattr(agent_load, att) - tu.assert_eq(save_attr, load_attr) + assert_properly_loaded(agent_save, agent_load) def test_double_q_learning(): @@ -122,11 +127,7 @@ def test_double_q_learning_save(tmpdir): agent_save.save(agent_path) agent_load = Agent.load(agent_path) - for att, method in vars(agent_save).items(): - save_attr = getattr(agent_save, att) - load_attr = getattr(agent_load, att) - - tu.assert_eq(save_attr, load_attr) + assert_properly_loaded(agent_save, agent_load) def test_weighted_q_learning(): @@ -160,11 +161,7 @@ def test_weighted_q_learning_save(tmpdir): agent_save.save(agent_path) agent_load = Agent.load(agent_path) - for att, method in vars(agent_save).items(): - save_attr = getattr(agent_save, att) - load_attr = getattr(agent_load, att) - - tu.assert_eq(save_attr, load_attr) + assert_properly_loaded(agent_save, agent_load) def test_maxmin_q_learning(): @@ -198,11 +195,7 @@ def test_maxmin_q_learning_save(tmpdir): agent_save.save(agent_path) agent_load = Agent.load(agent_path) - for att, method in vars(agent_save).items(): - save_attr = getattr(agent_save, att) - load_attr = getattr(agent_load, att) - - tu.assert_eq(save_attr, load_attr) + assert_properly_loaded(agent_save, agent_load) def test_speedy_q_learning(): @@ -236,11 +229,7 @@ def test_speedy_q_learning_save(tmpdir): agent_save.save(agent_path) agent_load = Agent.load(agent_path) - for att, method in vars(agent_save).items(): - save_attr = getattr(agent_save, att) - load_attr = getattr(agent_load, att) - - tu.assert_eq(save_attr, load_attr) + assert_properly_loaded(agent_save, agent_load) def test_sarsa(): @@ -274,11 +263,7 @@ def test_sarsa_save(tmpdir): agent_save.save(agent_path) agent_load = Agent.load(agent_path) - for att, method in vars(agent_save).items(): - save_attr = getattr(agent_save, att) - load_attr = getattr(agent_load, att) - - tu.assert_eq(save_attr, load_attr) + assert_properly_loaded(agent_save, agent_load) def test_q_lambda(): @@ -312,11 +297,7 @@ def test_q_lambda_save(tmpdir): 
agent_save.save(agent_path) agent_load = Agent.load(agent_path) - for att, method in vars(agent_save).items(): - save_attr = getattr(agent_save, att) - load_attr = getattr(agent_load, att) - - tu.assert_eq(save_attr, load_attr) + assert_properly_loaded(agent_save, agent_load) def test_sarsa_lambda_discrete(): @@ -350,11 +331,7 @@ def test_sarsa_lambda_discrete_save(tmpdir): agent_save.save(agent_path) agent_load = Agent.load(agent_path) - for att, method in vars(agent_save).items(): - save_attr = getattr(agent_save, att) - load_attr = getattr(agent_load, att) - - tu.assert_eq(save_attr, load_attr) + assert_properly_loaded(agent_save, agent_load) def test_sarsa_lambda_continuous_linear(): @@ -369,11 +346,11 @@ def test_sarsa_lambda_continuous_linear(): approximator_params = dict( input_shape=(features.size,), output_shape=(mdp_continuous.info.action_space.n,), - n_actions=mdp_continuous.info.action_space.n + n_actions=mdp_continuous.info.action_space.n, + phi=features ) agent = SARSALambdaContinuous(mdp_continuous.info, pi, LinearApproximator, - Parameter(.1), .9, features=features, - approximator_params=approximator_params) + Parameter(.1), .9, approximator_params=approximator_params) core = Core(agent, mdp_continuous) @@ -402,11 +379,11 @@ def test_sarsa_lambda_continuous_linear_save(tmpdir): approximator_params = dict( input_shape=(features.size,), output_shape=(mdp_continuous.info.action_space.n,), - n_actions=mdp_continuous.info.action_space.n + n_actions=mdp_continuous.info.action_space.n, + phi=features, ) - agent_save = SARSALambdaContinuous(mdp_continuous.info, pi, LinearApproximator, - Parameter(.1), .9, features=features, - approximator_params=approximator_params) + agent_save = SARSALambdaContinuous(mdp_continuous.info, pi, LinearApproximator, Parameter(.1), .9, + approximator_params=approximator_params) core = Core(agent_save, mdp_continuous) @@ -416,28 +393,19 @@ def test_sarsa_lambda_continuous_linear_save(tmpdir): agent_save.save(agent_path) agent_load = Agent.load(agent_path) - for att, method in vars(agent_save).items(): - save_attr = getattr(agent_save, att) - load_attr = getattr(agent_load, att) - - tu.assert_eq(save_attr, load_attr) + assert_properly_loaded(agent_save, agent_load) def test_sarsa_lambda_continuous_nn(): pi, _, mdp_continuous = initialize() - - features = Features( - n_outputs=mdp_continuous.info.observation_space.shape[0] - ) approximator_params = dict( - input_shape=(features.size,), + input_shape=mdp_continuous.info.observation_space.shape, output_shape=(mdp_continuous.info.action_space.n,), network=Network, - n_actions=mdp_continuous.info.action_space.n + n_actions=mdp_continuous.info.action_space.n, ) - agent = SARSALambdaContinuous(mdp_continuous.info, pi, TorchApproximator, - Parameter(.1), .9, features=features, + agent = SARSALambdaContinuous(mdp_continuous.info, pi, TorchApproximator, Parameter(.1), .9, approximator_params=approximator_params) core = Core(agent, mdp_continuous) @@ -457,19 +425,14 @@ def test_sarsa_lambda_continuous_nn_save(tmpdir): pi, _, mdp_continuous = initialize() - features = Features( - n_outputs=mdp_continuous.info.observation_space.shape[0] - ) - approximator_params = dict( - input_shape=(features.size,), + input_shape=mdp_continuous.info.observation_space.shape, output_shape=(mdp_continuous.info.action_space.n,), network=Network, n_actions=mdp_continuous.info.action_space.n ) - agent_save = SARSALambdaContinuous(mdp_continuous.info, pi, TorchApproximator, - Parameter(.1), .9, features=features, - 
approximator_params=approximator_params) + agent_save = SARSALambdaContinuous(mdp_continuous.info, pi, TorchApproximator, Parameter(.1), .9, + approximator_params=approximator_params) core = Core(agent_save, mdp_continuous) @@ -479,11 +442,7 @@ def test_sarsa_lambda_continuous_nn_save(tmpdir): agent_save.save(agent_path) agent_load = Agent.load(agent_path) - for att, method in vars(agent_save).items(): - save_attr = getattr(agent_save, att) - load_attr = getattr(agent_load, att) - - tu.assert_eq(save_attr, load_attr) + assert_properly_loaded(agent_save, agent_load) def test_expected_sarsa(): @@ -517,11 +476,7 @@ def test_expected_sarsa_save(tmpdir): agent_save.save(agent_path) agent_load = Agent.load(agent_path) - for att, method in vars(agent_save).items(): - save_attr = getattr(agent_save, att) - load_attr = getattr(agent_load, att) - - tu.assert_eq(save_attr, load_attr) + assert_properly_loaded(agent_save, agent_load) def test_true_online_sarsa_lambda(): @@ -536,10 +491,10 @@ def test_true_online_sarsa_lambda(): approximator_params = dict( input_shape=(features.size,), output_shape=(mdp_continuous.info.action_space.n,), - n_actions=mdp_continuous.info.action_space.n + n_actions=mdp_continuous.info.action_space.n, + phi=features, ) - agent = TrueOnlineSARSALambda(mdp_continuous.info, pi, - Parameter(.1), .9, features=features, + agent = TrueOnlineSARSALambda(mdp_continuous.info, pi, Parameter(.1), .9, approximator_params=approximator_params) core = Core(agent, mdp_continuous) @@ -571,11 +526,11 @@ def test_true_online_sarsa_lambda_save(tmpdir): approximator_params = dict( input_shape=(features.size,), output_shape=(mdp_continuous.info.action_space.n,), - n_actions=mdp_continuous.info.action_space.n + n_actions=mdp_continuous.info.action_space.n, + phi=features, ) - agent_save = TrueOnlineSARSALambda(mdp_continuous.info, pi, - Parameter(.1), .9, features=features, - approximator_params=approximator_params) + agent_save = TrueOnlineSARSALambda(mdp_continuous.info, pi, Parameter(.1), .9, + approximator_params=approximator_params) core = Core(agent_save, mdp_continuous) @@ -585,11 +540,7 @@ def test_true_online_sarsa_lambda_save(tmpdir): agent_save.save(agent_path) agent_load = Agent.load(agent_path) - for att, method in vars(agent_save).items(): - save_attr = getattr(agent_save, att) - load_attr = getattr(agent_load, att) - - tu.assert_eq(save_attr, load_attr) + assert_properly_loaded(agent_save, agent_load) def test_r_learning(): @@ -623,11 +574,7 @@ def test_r_learning_save(tmpdir): agent_save.save(agent_path) agent_load = Agent.load(agent_path) - for att, method in vars(agent_save).items(): - save_attr = getattr(agent_save, att) - load_attr = getattr(agent_load, att) - - tu.assert_eq(save_attr, load_attr) + assert_properly_loaded(agent_save, agent_load) def test_rq_learning(): @@ -707,7 +654,4 @@ def test_rq_learning_save(tmpdir): agent_save.save(agent_path) agent_load = Agent.load(agent_path) - for att, method in vars(agent_save).items(): - save_attr = getattr(agent_save, att) - load_attr = getattr(agent_load, att) - tu.assert_eq(save_attr, load_attr) + assert_properly_loaded(agent_save, agent_load) diff --git a/tests/core/test_core.py b/tests/core/test_core.py index da821fc05..aa6fc8778 100644 --- a/tests/core/test_core.py +++ b/tests/core/test_core.py @@ -8,10 +8,11 @@ class RandomDiscretePolicy(Policy): def __init__(self, n): + super().__init__() self._n = n - def draw_action(self, state): - return [np.random.randint(self._n)] + def draw_action(self, state, policy_state=None): + 
return [np.random.randint(self._n)], None class DummyAgent(Agent): @@ -19,7 +20,7 @@ def __init__(self, mdp_info): policy = RandomDiscretePolicy(mdp_info.action_space.n) super().__init__(mdp_info, policy) - def fit(self, dataset, **info): + def fit(self, dataset): pass @@ -35,13 +36,13 @@ def test_core(): core.learn(n_steps=100, n_steps_per_fit=1) - dataset, info = core.evaluate(n_steps=20, get_env_info=True) + dataset = core.evaluate(n_steps=20) - assert 'lives' in info - assert 'episode_frame_number' in info - assert 'frame_number' in info + assert 'lives' in dataset.info + assert 'episode_frame_number' in dataset.info + assert 'frame_number' in dataset.info - info_lives = np.array(info['lives']) + info_lives = np.array(dataset.info['lives']) print(info_lives) lives_gt = np.array([5, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]) diff --git a/tests/core/test_dataset.py b/tests/core/test_dataset.py new file mode 100644 index 000000000..474205537 --- /dev/null +++ b/tests/core/test_dataset.py @@ -0,0 +1,123 @@ +import numpy as np +import torch + +from mushroom_rl.core import Core, Dataset +from mushroom_rl.algorithms.value import SARSA +from mushroom_rl.environments import GridWorld +from mushroom_rl.utils.parameters import Parameter +from mushroom_rl.policy import EpsGreedy + + +def generate_dataset(mdp, n_episodes): + epsilon = Parameter(value=0.) + alpha = Parameter(value=0.) + pi = EpsGreedy(epsilon=epsilon) + + agent = SARSA(mdp.info, pi, alpha) + core = Core(agent, mdp) + + return core.evaluate(n_episodes=n_episodes) + + +def test_dataset(): + np.random.seed(42) + mdp = GridWorld(3, 3, (2, 2)) + dataset = generate_dataset(mdp, 10) + + J = dataset.compute_J(mdp.info.gamma) + J_test = np.array([4.304672100000001, 2.287679245496101, 3.138105960900001, 0.13302794647291147, + 7.290000000000001, 1.8530201888518416, 1.3508517176729928, 0.011790184577738602, + 1.3508517176729928, 7.290000000000001]) + assert np.allclose(J, J_test) + + L = dataset.episodes_length + L_test = np.array([9, 15, 12, 42, 4, 17, 20, 65, 20, 4]) + assert np.array_equal(L, L_test) + + dataset_ep = dataset.select_first_episodes(3) + J = dataset_ep.compute_J(mdp.info.gamma) + assert np.allclose(J, J_test[:3]) + + L = dataset_ep.episodes_length + assert np.allclose(L, L_test[:3]) + + samples = dataset.select_random_samples(2) + s, a, r, ss, ab, last = samples.parse() + s_test = np.array([[5.], [6.]]) + a_test = np.array([[3.], [0.]]) + r_test = np.zeros(2) + ss_test = np.array([[5], [3]]) + ab_test = np.zeros(2) + last_test = np.zeros(2) + assert np.array_equal(s, s_test) + assert np.array_equal(a, a_test) + assert np.array_equal(r, r_test) + assert np.array_equal(ss, ss_test) + assert np.array_equal(ab, ab_test) + assert np.array_equal(last, last_test) + + s0 = dataset.get_init_states() + s0_test = np.zeros((10, 1)) + assert np.array_equal(s0, s0_test) + + index = np.sum(L_test[:2]) + L_test[2]//2 + min_J, max_J, mean_J, median_J, n_episodes = dataset[:index].compute_metrics(mdp.info.gamma) + assert min_J == 2.287679245496101 + assert max_J == 4.304672100000001 + assert mean_J == 3.296175672748051 + assert median_J == 3.296175672748051 + assert n_episodes == 2 + + +def test_dataset_creation(): + np.random.seed(42) + + mdp = GridWorld(3, 3, (2, 2)) + dataset = generate_dataset(mdp, 5) + + parsed = tuple(dataset.parse()) + parsed_torch = (torch.from_numpy(array) for array in parsed) + + print(len(parsed)) + + new_numpy_dataset = Dataset.from_array(*parsed, gamma=mdp.info.gamma) + new_list_dataset = 
Dataset.from_array(*parsed, gamma=mdp.info.gamma, backend='list') + new_torch_dataset = Dataset.from_array(*parsed, gamma=mdp.info.gamma, backend='torch') + + assert vars(dataset).keys() == vars(new_numpy_dataset).keys() + assert vars(dataset).keys() == vars(new_list_dataset).keys() + assert vars(dataset).keys() == vars(new_torch_dataset).keys() + + for array_1, array_2 in zip(parsed, new_numpy_dataset.parse()): + assert np.array_equal(array_1, array_2) + + for array_1, array_2 in zip(parsed, new_list_dataset.parse()): + assert np.array_equal(array_1, array_2) + + for array_1, array_2 in zip(parsed_torch, new_torch_dataset.parse(to='torch')): + assert torch.equal(array_1, array_2) + + +def test_dataset_loading(tmpdir): + np.random.seed(42) + + mdp = GridWorld(3, 3, (2, 2)) + dataset = generate_dataset(mdp, 20) + + path = tmpdir / 'dataset_test.msh' + dataset.save(path) + + new_dataset = dataset.load(path) + + assert vars(dataset).keys() == vars(new_dataset).keys() + + assert np.array_equal(dataset.state, new_dataset.state) and \ + np.array_equal(dataset.action, new_dataset.action) and \ + np.array_equal(dataset.reward, new_dataset.reward) and \ + np.array_equal(dataset.next_state, new_dataset.next_state) and \ + np.array_equal(dataset.absorbing, new_dataset.absorbing) and \ + np.array_equal(dataset.last, new_dataset.last) + + assert dataset._gamma == new_dataset._gamma + + diff --git a/tests/environments/mujoco_envs/test_ball_in_a_cup.py b/tests/environments/mujoco_envs/test_ball_in_a_cup.py index fd5eebc60..c20982e1b 100644 --- a/tests/environments/mujoco_envs/test_ball_in_a_cup.py +++ b/tests/environments/mujoco_envs/test_ball_in_a_cup.py @@ -14,10 +14,10 @@ def test_ball_in_a_cup(): p_gains = np.array([200, 300, 100, 100, 10, 10, 2.5])/5 d_gains = np.array([7, 15, 5, 2.5, 0.3, 0.3, 0.05])/10 - obs_0 = env.reset() + obs_0, _ = env.reset() for _ in [1,2]: - obs = env.reset() + obs, _ = env.reset() assert np.array_equal(obs, obs_0) done = False diff --git a/tests/policy/test_deterministic_policy.py b/tests/policy/test_deterministic_policy.py index c881e2c92..25753183a 100644 --- a/tests/policy/test_deterministic_policy.py +++ b/tests/policy/test_deterministic_policy.py @@ -36,5 +36,6 @@ def test_deterministic_policy(): assert pi(s_test_2, a_test) == 0 a_stored = np.array([-1.86941072, -0.1789696]) - assert np.allclose(pi.draw_action(s_test_1), a_stored) + action, _ = pi.draw_action(s_test_1) + assert np.allclose(action, a_stored) diff --git a/tests/policy/test_gaussian_policy.py b/tests/policy/test_gaussian_policy.py index 54e0b539f..66756703b 100644 --- a/tests/policy/test_gaussian_policy.py +++ b/tests/policy/test_gaussian_policy.py @@ -22,7 +22,7 @@ def test_univariate_gaussian(): for x_i in x: state = np.atleast_1d(x_i) - action = pi.draw_action(state) + action, _ = pi.draw_action(state) exact_diff = pi.diff(state, action) numerical_diff = numerical_diff_policy(pi, state, action) @@ -50,7 +50,7 @@ def test_multivariate_gaussian(): for x_i in x: state = np.atleast_1d(x_i) - action = pi.draw_action(state) + action, _ = pi.draw_action(state) exact_diff = pi.diff(state, action) numerical_diff = numerical_diff_policy(pi, state, action) @@ -76,7 +76,7 @@ def test_multivariate_diagonal_gaussian(): for x_i in x: state = np.atleast_1d(x_i) - action = pi.draw_action(state) + action, _ = pi.draw_action(state) exact_diff = pi.diff(state, action) numerical_diff = numerical_diff_policy(pi, state, action) @@ -104,7 +104,7 @@ def test_multivariate_state_std_gaussian(): for x_i in x: state = 
np.atleast_1d(x_i) - action = pi.draw_action(state) + action, _ = pi.draw_action(state) exact_diff = pi.diff(state, action) numerical_diff = numerical_diff_policy(pi, state, action) @@ -132,7 +132,7 @@ def test_multivariate_state_log_std_gaussian(): for x_i in x: state = np.atleast_1d(x_i) - action = pi.draw_action(state) + action, _ = pi.draw_action(state) exact_diff = pi.diff(state, action) numerical_diff = numerical_diff_policy(pi, state, action) diff --git a/tests/policy/test_noise_policy.py b/tests/policy/test_noise_policy.py index aff1f3254..85213c2c4 100644 --- a/tests/policy/test_noise_policy.py +++ b/tests/policy/test_noise_policy.py @@ -17,12 +17,14 @@ def test_ornstein_uhlenbeck_policy(): state = np.random.randn(5) - action = pi.draw_action(state) + policy_state = pi.reset() + + action, policy_state = pi.draw_action(state, policy_state) action_test = np.array([-1.95896171, 1.91292747]) assert np.allclose(action, action_test) - pi.reset() - action = pi.draw_action(state) + policy_state = pi.reset() + action, policy_state = pi.draw_action(state, policy_state) action_test = np.array([-1.94161061, 1.92233358]) assert np.allclose(action, action_test) diff --git a/tests/policy/test_policy_interface.py b/tests/policy/test_policy_interface.py index a61607b7a..dee48ba63 100644 --- a/tests/policy/test_policy_interface.py +++ b/tests/policy/test_policy_interface.py @@ -12,15 +12,15 @@ def abstract_method_tester(f, ex, *args): def test_policy_interface(): tmp = Policy() - abstract_method_tester(tmp.__call__, NotImplementedError) - abstract_method_tester(tmp.draw_action, NotImplementedError, None) + abstract_method_tester(tmp.__call__, NotImplementedError, None, None, None) + abstract_method_tester(tmp.draw_action, NotImplementedError, None, None) tmp.reset() def test_parametric_policy(): tmp = ParametricPolicy() - abstract_method_tester(tmp.diff_log, RuntimeError, None, None) - abstract_method_tester(tmp.diff, RuntimeError, None, None) + abstract_method_tester(tmp.diff_log, RuntimeError, None, None, None) + abstract_method_tester(tmp.diff, RuntimeError, None, None, None) abstract_method_tester(tmp.set_weights, NotImplementedError, None) abstract_method_tester(tmp.get_weights, NotImplementedError) try: diff --git a/tests/policy/test_td_policy.py b/tests/policy/test_td_policy.py index e1890f278..0563f7dc8 100644 --- a/tests/policy/test_td_policy.py +++ b/tests/policy/test_td_policy.py @@ -33,7 +33,7 @@ def test_eps_greedy(): p_sa_test = np.array([0.93333333]) assert np.allclose(p_sa, p_sa_test) - a = pi.draw_action(s) + a, _ = pi.draw_action(s) a_test = 1 assert a.item() == a_test @@ -70,7 +70,7 @@ def test_boltzmann(): p_sa_test = np.array([0.36223227]) assert np.allclose(p_sa, p_sa_test) - a = pi.draw_action(s) + a, _ = pi.draw_action(s) a_test = 2 assert a.item() == a_test @@ -106,7 +106,7 @@ def test_mellowmax(): p_sa_test = np.array([0.69215916]) assert np.allclose(p_sa, p_sa_test) - a = pi.draw_action(s) + a, _ = pi.draw_action(s) a_test = 2 assert a.item() == a_test diff --git a/tests/policy/test_torch_policy.py b/tests/policy/test_torch_policy.py index 17614ad08..ce5dac90b 100644 --- a/tests/policy/test_torch_policy.py +++ b/tests/policy/test_torch_policy.py @@ -61,7 +61,7 @@ def test_gaussian_torch_policy(): pi = GaussianTorchPolicy(Network, (3,), (2,), n_features=50) state = np.random.rand(3) - action = pi.draw_action(state) + action, _ = pi.draw_action(state) action_test = np.array([-0.21276927, 0.27437747]) assert np.allclose(action, action_test) @@ -81,7 +81,7 @@ def 
test_boltzmann_torch_policy(): pi = BoltzmannTorchPolicy(Network, (3,), (2,), beta, n_features=50) state = np.random.rand(3, 3) - action = pi.draw_action(state) + action, _ = pi.draw_action(state) action_test = np.array([1, 0, 0]) assert np.allclose(action, action_test) diff --git a/tests/utils/test_dataset.py b/tests/utils/test_dataset.py deleted file mode 100644 index 37dce8581..000000000 --- a/tests/utils/test_dataset.py +++ /dev/null @@ -1,67 +0,0 @@ -from mushroom_rl.core import Core -from mushroom_rl.algorithms.value import SARSA -from mushroom_rl.environments import GridWorld -from mushroom_rl.utils.parameters import Parameter -from mushroom_rl.policy import EpsGreedy - -from mushroom_rl.utils.dataset import * - - -def test_dataset_utils(): - np.random.seed(88) - - mdp = GridWorld(3, 3, (2,2)) - epsilon = Parameter(value=0.) - alpha = Parameter(value=0.) - pi = EpsGreedy(epsilon=epsilon) - - agent = SARSA(mdp.info, pi, alpha) - core = Core(agent, mdp) - - dataset = core.evaluate(n_episodes=10) - - J = compute_J(dataset, mdp.info.gamma) - J_test = np.array([1.16106307e-03, 2.78128389e-01, 1.66771817e+00, 3.09031544e-01, - 1.19725152e-01, 9.84770902e-01, 1.06111661e-02, 2.05891132e+00, - 2.28767925e+00, 4.23911583e-01]) - assert np.allclose(J, J_test) - - L = compute_episodes_length(dataset) - L_test = np.array([87, 35, 18, 34, 43, 23, 66, 16, 15, 31]) - assert np.array_equal(L, L_test) - - dataset_ep = select_first_episodes(dataset, 3) - J = compute_J(dataset_ep, mdp.info.gamma) - assert np.allclose(J, J_test[:3]) - - L = compute_episodes_length(dataset_ep) - assert np.allclose(L, L_test[:3]) - - samples = select_random_samples(dataset, 2) - s, a, r, ss, ab, last = parse_dataset(samples) - s_test = np.array([[6.], [1.]]) - a_test = np.array([[0.], [1.]]) - r_test = np.zeros(2) - ss_test = np.array([[3], [4]]) - ab_test = np.zeros(2) - last_test = np.zeros(2) - assert np.array_equal(s, s_test) - assert np.array_equal(a, a_test) - assert np.array_equal(r, r_test) - assert np.array_equal(ss, ss_test) - assert np.array_equal(ab, ab_test) - assert np.array_equal(last, last_test) - - s0 = get_init_states(dataset) - s0_test = np.zeros((10, 1)) - assert np.array_equal(s0, s0_test) - - index = np.sum(L_test[:2]) + L_test[2]//2 - min_J, max_J, mean_J, median_J, n_episodes = compute_metrics(dataset[:index], mdp.info.gamma) - assert min_J == 0.0011610630703530948 - assert max_J == 0.2781283894436937 - assert mean_J == 0.1396447262570234 - assert median_J == 0.1396447262570234 - assert n_episodes == 2 - - diff --git a/tests/utils/test_preprocessors.py b/tests/utils/test_preprocessors.py index dbe584f10..ed7987dac 100644 --- a/tests/utils/test_preprocessors.py +++ b/tests/utils/test_preprocessors.py @@ -64,20 +64,18 @@ def test_normalizing_preprocessor(tmpdir): alg_params = dict(batch_size=5, initial_replay_size=10, max_replay_size=500, target_update_frequency=50) - agent = DQN(mdp.info, pi, TorchApproximator, - approximator_params=approximator_params, **alg_params) + agent = DQN(mdp.info, pi, TorchApproximator, approximator_params=approximator_params, **alg_params) - norm_box = MinMaxPreprocessor(mdp_info=mdp.info, - clip_obs=5.0, alpha=0.001) + norm_box = MinMaxPreprocessor(mdp_info=mdp.info, clip_obs=5.0, alpha=0.001) agent.add_preprocessor(norm_box) core = Core(agent, mdp) core.learn(n_steps=100, n_steps_per_fit=1, quiet=True) + dataset = core.evaluate(n_steps=1000) # training correctly - assert (core._state.min() >= -norm_box._clip_obs - and core._state.max() <= norm_box._clip_obs) + assert 
(dataset.state.min() >= -norm_box._clip_obs and dataset.state.max() <= norm_box._clip_obs)
 
     # save current dict
     state_dict1 = deepcopy(norm_box.__dict__)
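For code that imported the deleted mushroom_rl.utils.dataset helpers, the replacement API lives on the Dataset object that Core.evaluate and Core.learn now return, as exercised by the new tests/core/test_dataset.py above. A minimal migration sketch, reusing the GridWorld/SARSA setup from that test purely for illustration:

import numpy as np

from mushroom_rl.core import Core
from mushroom_rl.algorithms.value import SARSA
from mushroom_rl.environments import GridWorld
from mushroom_rl.policy import EpsGreedy
from mushroom_rl.utils.parameters import Parameter

mdp = GridWorld(3, 3, (2, 2))
agent = SARSA(mdp.info, EpsGreedy(epsilon=Parameter(0.)), Parameter(0.))
core = Core(agent, mdp)

# evaluate() returns a Dataset instead of a list of transitions
dataset = core.evaluate(n_episodes=10)

# old: compute_J(dataset, gamma), compute_episodes_length(dataset), parse_dataset(dataset)
J = dataset.compute_J(mdp.info.gamma)            # per-episode discounted return
lengths = dataset.episodes_length                # per-episode length
s, a, r, ss, absorbing, last = dataset.parse()   # flat arrays, as the old parse_dataset

# episode selection and initial states are methods of the dataset as well
first_three = dataset.select_first_episodes(3)
s0 = dataset.get_init_states()
print(np.mean(J), lengths)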
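A second change that runs through these hunks is the stateful policy interface: Policy.reset() now returns the initial policy state, draw_action() takes an optional policy_state and returns an (action, policy_state) pair, and policy constructors accept a policy_state_shape argument. Below is a sketch of a custom stateless policy under the new interface; the class name and the uniform action sampling are hypothetical, modelled on RandomDiscretePolicy in tests/core/test_core.py:

import numpy as np

from mushroom_rl.policy import Policy


class RandomContinuousPolicy(Policy):
    def __init__(self, action_dim):
        super().__init__()  # no policy_state_shape -> stateless policy
        self._action_dim = action_dim

    def draw_action(self, state, policy_state=None):
        action = np.random.uniform(-1., 1., self._action_dim)
        return action, None  # stateless policies return None as the next policy state


# Stateful policies follow the pattern from the updated test_noise_policy.py:
#     policy_state = pi.reset()
#     action, policy_state = pi.draw_action(state, policy_state)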
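Finally, mushroom_rl.utils.value_functions now works on torch tensors and wraps the computation in torch.no_grad(). The GAE recursion itself is unchanged; the standalone helper below (gae_from_values is a hypothetical name, not part of the patch) restates it on plain tensors, which may help when porting numpy-based callers:

import torch


def gae_from_values(v, v_next, r, absorbing, last, gamma, lam):
    # Same backward recursion as compute_gae, but taking the value
    # predictions v and v_next directly instead of a Regressor.
    gen_adv = torch.empty_like(v)
    for rev_k in range(len(v)):
        k = len(v) - rev_k - 1
        if last[k] or rev_k == 0:
            # at episode ends, bootstrap only from the immediate next state
            gen_adv[k] = r[k] - v[k]
            if not absorbing[k]:
                gen_adv[k] += gamma * v_next[k]
        else:
            gen_adv[k] = r[k] + gamma * v_next[k] - v[k] + gamma * lam * gen_adv[k + 1]
    return gen_adv + v, gen_adv


# Illustrative call with dummy tensors (shapes only):
v = torch.zeros(5)
v_next = torch.zeros(5)
r = torch.ones(5)
absorbing = torch.zeros(5, dtype=torch.bool)
last = torch.tensor([False, False, False, False, True])
v_target, adv = gae_from_values(v, v_next, r, absorbing, last, gamma=0.99, lam=0.95)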