diff --git a/.gitignore b/.gitignore index bdd9cee37..c1ae69e5b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,9 +1,11 @@ *.DS_store build/ dist/ +examples/mushroom_rl_recordings/ examples/habitat/Replica-Dataset examples/habitat/data mushroom_rl.egg-info/ +mushroom_rl_recordings/ .idea/ *.pyc *.pyd diff --git a/Makefile b/Makefile index 3c30854f2..7f13efeef 100644 --- a/Makefile +++ b/Makefile @@ -13,3 +13,5 @@ upload: clean: rm -rf dist rm -rf build + +.NOTPARALLEL: diff --git a/TODO.txt b/TODO.txt index afc80021d..2208cca03 100644 --- a/TODO.txt +++ b/TODO.txt @@ -17,7 +17,6 @@ Approximator: * add neural network generator For Mushroom 2.0: - * Record method in environment and record option in the core * Simplify Regressor interface: drop GenericRegressor, remove facade pattern * vectorize basis functions and simplify interface, simplify facade pattern * remove custom save for plotting, use Serializable diff --git a/docs/source/tutorials/code/room_env.py b/docs/source/tutorials/code/room_env.py index cfc85a86b..8aea07b7d 100644 --- a/docs/source/tutorials/code/room_env.py +++ b/docs/source/tutorials/code/room_env.py @@ -7,7 +7,7 @@ class RoomToyEnv(Environment): - def __init__(self, size=5., goal=[2.5, 2.5], goal_radius=0.6): + def __init__(self, size=5., goal=(2.5, 2.5), goal_radius=0.6): # Save important environment information self._size = size @@ -23,7 +23,7 @@ def __init__(self, size=5., goal=[2.5, 2.5], goal_radius=0.6): observation_space = Box(0, size, shape) # Create the MDPInfo structure, needed by the environment interface - mdp_info = MDPInfo(observation_space, action_space, gamma=0.99, horizon=100) + mdp_info = MDPInfo(observation_space, action_space, gamma=0.99, horizon=100, dt=0.1) super().__init__(mdp_info) @@ -86,15 +86,20 @@ def step(self, action): # Return all the information + empty dictionary (used to pass additional information) return self._state, reward, absorbing, {} - def render(self): + def render(self, record=False): # Draw a red circle for the agent self._viewer.circle(self._state, 0.1, color=(255, 0, 0)) # Draw a green circle for the goal self._viewer.circle(self._goal, self._goal_radius, color=(0, 255, 0)) - # Display the image for 0.1 seconds - self._viewer.display(0.1) + # Get the image if the record flag is set to true + frame = self._viewer.get_frame() if record else None + + # Display the image for the control time (0.1 seconds) + self._viewer.display(self.info.dt) + + return frame # Register the class diff --git a/docs/source/tutorials/tutorials.5_environments.rst b/docs/source/tutorials/tutorials.5_environments.rst index ca50e0a8a..754b3f04a 100644 --- a/docs/source/tutorials/tutorials.5_environments.rst +++ b/docs/source/tutorials/tutorials.5_environments.rst @@ -159,14 +159,14 @@ visualization tool for 2D Reinforcement Learning algorithms. The viewer class ha simply draw two circles representing the agent and the goal area: .. literalinclude:: code/room_env.py - :lines: 89-97 + :lines: 89-102 For more information about the viewer, refer to the class documentation. To conclude our environment, it's also possible to register it as specified in the previous section of this tutorial: .. literalinclude:: code/room_env.py - :lines: 100-101 + :lines: 105-106 Learning in the toy environment @@ -179,17 +179,17 @@ We first import all necessary classes and utilities, then we construct the envir reproducibility). .. 
literalinclude:: code/room_env.py - :lines: 103-116 + :lines: 108-121 We now proceed then to create the agent policy, which is a linear policy using tiles features, similar to the one used by the Mountain Car experiment from R. Sutton book. .. literalinclude:: code/room_env.py - :lines: 118-139 + :lines: 123-144 Finally, using the ``Core`` class we set up an RL experiment. We first evaluate the initial policy for three episodes on the environment. Then we learn the task using the algorithm build above for 20000 steps. In the end, we evaluate the learned policy for 3 more episodes. .. literalinclude:: code/room_env.py - :lines: 141- + :lines: 146- diff --git a/mushroom_rl/__init__.py b/mushroom_rl/__init__.py index 4175d3998..52af183e5 100644 --- a/mushroom_rl/__init__.py +++ b/mushroom_rl/__init__.py @@ -1 +1 @@ -__version__ = '1.9.2' +__version__ = '1.10.0' diff --git a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/sac.py b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/sac.py index 157d4bb96..b4236f2b3 100644 --- a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/sac.py +++ b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/sac.py @@ -17,27 +17,20 @@ class SACPolicy(Policy): """ - Class used to implement the policy used by the Soft Actor-Critic - algorithm. The policy is a Gaussian policy squashed by a tanh. - This class implements the compute_action_and_log_prob and the - compute_action_and_log_prob_t methods, that are fundamental for - the internals calculations of the SAC algorithm. + Class used to implement the policy used by the Soft Actor-Critic algorithm. + The policy is a Gaussian policy squashed by a tanh. This class implements the compute_action_and_log_prob and the + compute_action_and_log_prob_t methods, that are fundamental for the internals calculations of the SAC algorithm. """ - def __init__(self, mu_approximator, sigma_approximator, min_a, max_a, - log_std_min, log_std_max): + def __init__(self, mu_approximator, sigma_approximator, min_a, max_a, log_std_min, log_std_max): """ Constructor. Args: - mu_approximator (Regressor): a regressor computing mean in given a - state; - sigma_approximator (Regressor): a regressor computing the variance - in given a state; - min_a (np.ndarray): a vector specifying the minimum action value - for each component; - max_a (np.ndarray): a vector specifying the maximum action value - for each component. + mu_approximator (Regressor): a regressor computing mean in given a state; + sigma_approximator (Regressor): a regressor computing the variance in given a state; + min_a (np.ndarray): a vector specifying the minimum action value for each component; + max_a (np.ndarray): a vector specifying the maximum action value for each component. log_std_min ([float, Parameter]): min value for the policy log std; log_std_max ([float, Parameter]): max value for the policy log std. @@ -78,8 +71,7 @@ def draw_action(self, state): def compute_action_and_log_prob(self, state): """ - Function that samples actions using the reparametrization trick and - the log probability for such actions. + Function that samples actions using the reparametrization trick and the log probability for such actions. Args: state (np.ndarray): the state in which the action is sampled. @@ -93,17 +85,15 @@ def compute_action_and_log_prob(self, state): def compute_action_and_log_prob_t(self, state, compute_log_prob=True): """ - Function that samples actions using the reparametrization trick and, - optionally, the log probability for such actions. 
+ Function that samples actions using the reparametrization trick and, optionally, the log probability for such + actions. Args: state (np.ndarray): the state in which the action is sampled; - compute_log_prob (bool, True): whether to compute the log - probability or not. + compute_log_prob (bool, True): whether to compute the log probability or not. Returns: - The actions sampled and, optionally, the log probability as torch - tensors. + The actions sampled and, optionally, the log probability as torch tensors. """ dist = self.distribution(state) @@ -123,8 +113,7 @@ def distribution(self, state): Compute the policy distribution in the given states. Args: - state (np.ndarray): the set of states where the distribution is - computed. + state (np.ndarray): the set of states where the distribution is computed. Returns: The torch distribution for the provided states. @@ -147,19 +136,15 @@ def entropy(self, state=None): The value of the entropy of the policy. """ - - return torch.mean(self.distribution(state).entropy()).detach().cpu().numpy().item() - - def reset(self): - pass + _, log_pi = self.compute_action_and_log_prob(state) + return -log_pi.mean() def set_weights(self, weights): """ Setter. Args: - weights (np.ndarray): the vector of the new weights to be used by - the policy. + weights (np.ndarray): the vector of the new weights to be used by the policy. """ mu_weights = weights[:self._mu_approximator.weights_size] @@ -190,8 +175,7 @@ def use_cuda(self): def parameters(self): """ - Returns the trainable policy parameters, as expected by torch - optimizers. + Returns the trainable policy parameters, as expected by torch optimizers. Returns: List of parameters to be optimized. @@ -208,38 +192,30 @@ class SAC(DeepAC): Haarnoja T. et al.. 2019. """ - def __init__(self, mdp_info, actor_mu_params, actor_sigma_params, - actor_optimizer, critic_params, batch_size, - initial_replay_size, max_replay_size, warmup_transitions, tau, - lr_alpha, log_std_min=-20, log_std_max=2, target_entropy=None, - critic_fit_params=None): + def __init__(self, mdp_info, actor_mu_params, actor_sigma_params, actor_optimizer, critic_params, batch_size, + initial_replay_size, max_replay_size, warmup_transitions, tau, lr_alpha, use_log_alpha_loss=False, + log_std_min=-20, log_std_max=2, target_entropy=None,critic_fit_params=None): """ Constructor. 
Args: - actor_mu_params (dict): parameters of the actor mean approximator - to build; - actor_sigma_params (dict): parameters of the actor sigm - approximator to build; - actor_optimizer (dict): parameters to specify the actor - optimizer algorithm; - critic_params (dict): parameters of the critic approximator to - build; + actor_mu_params (dict): parameters of the actor mean approximator to build; + actor_sigma_params (dict): parameters of the actor sigma approximator to build; + actor_optimizer (dict): parameters to specify the actor optimizer algorithm; + critic_params (dict): parameters of the critic approximator to build; batch_size ((int, Parameter)): the number of samples in a batch; - initial_replay_size (int): the number of samples to collect before - starting the learning; - max_replay_size (int): the maximum number of samples in the replay - memory; - warmup_transitions ([int, Parameter]): number of samples to accumulate in the - replay memory to start the policy fitting; + initial_replay_size (int): the number of samples to collect before starting the learning; + max_replay_size (int): the maximum number of samples in the replay memory; + warmup_transitions ([int, Parameter]): number of samples to accumulate in the replay memory to start the + policy fitting; tau ([float, Parameter]): value of coefficient for soft updates; lr_alpha ([float, Parameter]): Learning rate for the entropy coefficient; + use_log_alpha_loss (bool, False): whether to use the original implementation loss or the one from the + paper; log_std_min ([float, Parameter]): Min value for the policy log std; log_std_max ([float, Parameter]): Max value for the policy log std; - target_entropy (float, None): target entropy for the policy, if - None a default value is computed ; - critic_fit_params (dict, None): parameters of the fitting algorithm - of the critic approximator. + target_entropy (float, None): target entropy for the policy, if None a default value is computed; + critic_fit_params (dict, None): parameters of the fitting algorithm of the critic approximator. 
""" self._critic_fit_params = dict() if critic_fit_params is None else critic_fit_params @@ -248,6 +224,8 @@ def __init__(self, mdp_info, actor_mu_params, actor_sigma_params, self._warmup_transitions = to_parameter(warmup_transitions) self._tau = to_parameter(tau) + self._use_log_alpha_loss = use_log_alpha_loss + if target_entropy is None: self._target_entropy = -np.prod(mdp_info.action_space.shape).astype(np.float32) else: @@ -261,25 +239,16 @@ def __init__(self, mdp_info, actor_mu_params, actor_sigma_params, critic_params['n_models'] = 2 target_critic_params = deepcopy(critic_params) - self._critic_approximator = Regressor(TorchApproximator, - **critic_params) - self._target_critic_approximator = Regressor(TorchApproximator, - **target_critic_params) - - actor_mu_approximator = Regressor(TorchApproximator, - **actor_mu_params) - actor_sigma_approximator = Regressor(TorchApproximator, - **actor_sigma_params) - - policy = SACPolicy(actor_mu_approximator, - actor_sigma_approximator, - mdp_info.action_space.low, - mdp_info.action_space.high, - log_std_min, - log_std_max) - - self._init_target(self._critic_approximator, - self._target_critic_approximator) + self._critic_approximator = Regressor(TorchApproximator, **critic_params) + self._target_critic_approximator = Regressor(TorchApproximator, **target_critic_params) + + actor_mu_approximator = Regressor(TorchApproximator, **actor_mu_params) + actor_sigma_approximator = Regressor(TorchApproximator, **actor_sigma_params) + + policy = SACPolicy(actor_mu_approximator, actor_sigma_approximator, mdp_info.action_space.low, + mdp_info.action_space.high, log_std_min, log_std_max) + + self._init_target(self._critic_approximator, self._target_critic_approximator) self._log_alpha = torch.tensor(0., dtype=torch.float32) @@ -302,6 +271,7 @@ def __init__(self, mdp_info, actor_mu_params, actor_sigma_params, _replay_memory='mushroom', _critic_approximator='mushroom', _target_critic_approximator='mushroom', + _use_log_alpha_loss='primitive', _log_alpha='torch', _alpha_optim='torch' ) @@ -311,8 +281,7 @@ def __init__(self, mdp_info, actor_mu_params, actor_sigma_params, def fit(self, dataset, **info): self._replay_memory.add(dataset) if self._replay_memory.initialized: - state, action, reward, next_state, absorbing, _ = \ - self._replay_memory.get(self._batch_size()) + state, action, reward, next_state, absorbing, _ = self._replay_memory.get(self._batch_size()) if self._replay_memory.size > self._warmup_transitions(): action_new, log_prob = self.policy.compute_action_and_log_prob_t(state) @@ -323,24 +292,23 @@ def fit(self, dataset, **info): q_next = self._next_q(next_state, absorbing) q = reward + self.mdp_info.gamma * q_next - self._critic_approximator.fit(state, action, q, - **self._critic_fit_params) + self._critic_approximator.fit(state, action, q, **self._critic_fit_params) - self._update_target(self._critic_approximator, - self._target_critic_approximator) + self._update_target(self._critic_approximator, self._target_critic_approximator) def _loss(self, state, action_new, log_prob): - q_0 = self._critic_approximator(state, action_new, - output_tensor=True, idx=0) - q_1 = self._critic_approximator(state, action_new, - output_tensor=True, idx=1) + q_0 = self._critic_approximator(state, action_new, output_tensor=True, idx=0) + q_1 = self._critic_approximator(state, action_new, output_tensor=True, idx=1) q = torch.min(q_0, q_1) return (self._alpha * log_prob - q).mean() def _update_alpha(self, log_prob): - alpha_loss = - (self._log_alpha * (log_prob + 
self._target_entropy)).mean() + if self._use_log_alpha_loss: + alpha_loss = - (self._log_alpha * (log_prob + self._target_entropy)).mean() + else: + alpha_loss = - (self._alpha * (log_prob + self._target_entropy)).mean() self._alpha_optim.zero_grad() alpha_loss.backward() self._alpha_optim.step() @@ -348,14 +316,11 @@ def _update_alpha(self, log_prob): def _next_q(self, next_state, absorbing): """ Args: - next_state (np.ndarray): the states where next action has to be - evaluated; - absorbing (np.ndarray): the absorbing flag for the states in - ``next_state``. + next_state (np.ndarray): the states where next action has to be evaluated; + absorbing (np.ndarray): the absorbing flag for the states in ``next_state``. Returns: - Action-values returned by the critic for ``next_state`` and the - action returned by the actor. + Action-values returned by the critic for ``next_state`` and the action returned by the actor. """ a, log_prob_next = self.policy.compute_action_and_log_prob(next_state) diff --git a/mushroom_rl/approximators/parametric/torch_approximator.py b/mushroom_rl/approximators/parametric/torch_approximator.py index 2300e54b9..51a896451 100644 --- a/mushroom_rl/approximators/parametric/torch_approximator.py +++ b/mushroom_rl/approximators/parametric/torch_approximator.py @@ -96,7 +96,7 @@ def predict(self, *args, output_tensor=False, **kwargs): """ if not self._use_cuda: - torch_args = [torch.from_numpy(x) if isinstance(x, np.ndarray) else x + torch_args = [torch.as_tensor(x) if isinstance(x, np.ndarray) else x for x in args] val = self.network(*torch_args, **kwargs) @@ -107,7 +107,7 @@ def predict(self, *args, output_tensor=False, **kwargs): else: val = val.detach().numpy() else: - torch_args = [torch.from_numpy(x).cuda() + torch_args = [torch.as_tensor(x).cuda() if isinstance(x, np.ndarray) else x.cuda() for x in args] val = self.network(*torch_args, **kwargs) @@ -239,15 +239,15 @@ def _fit_batch(self, batch, use_weights, kwargs): def _compute_batch_loss(self, batch, use_weights, kwargs): if use_weights: - weights = torch.from_numpy(batch[-1]).type(torch.float) + weights = torch.as_tensor(batch[-1]).type(torch.float) if self._use_cuda: weights = weights.cuda() batch = batch[:-1] if not self._use_cuda: - torch_args = [torch.from_numpy(x) for x in batch] + torch_args = [torch.as_tensor(x) for x in batch] else: - torch_args = [torch.from_numpy(x).cuda() for x in batch] + torch_args = [torch.as_tensor(x).cuda() for x in batch] x = torch_args[:-self._n_fit_targets] @@ -317,9 +317,9 @@ def diff(self, *args, **kwargs): """ if not self._use_cuda: - torch_args = [torch.from_numpy(np.atleast_2d(x)) for x in args] + torch_args = [torch.as_tensor(np.atleast_2d(x)) for x in args] else: - torch_args = [torch.from_numpy(np.atleast_2d(x)).cuda() + torch_args = [torch.as_tensor(np.atleast_2d(x)).cuda() for x in args] y_hat = self.network(*torch_args, **kwargs) diff --git a/mushroom_rl/core/core.py b/mushroom_rl/core/core.py index 5e480fa16..22d6e20e3 100644 --- a/mushroom_rl/core/core.py +++ b/mushroom_rl/core/core.py @@ -1,6 +1,7 @@ from tqdm import tqdm from collections import defaultdict +from mushroom_rl.utils.record import VideoRecorder class Core(object): @@ -8,15 +9,14 @@ class Core(object): Implements the functions to run a generic algorithm. """ - def __init__(self, agent, mdp, callbacks_fit=None, callback_step=None): + def __init__(self, agent, mdp, callbacks_fit=None, callback_step=None, record_dictionary=None): """ Constructor. 
Args: agent (Agent): the agent moving according to a policy; mdp (Environment): the environment in which the agent moves; - callbacks_fit (list): list of callbacks to execute at the end of - each fit; + callbacks_fit (list): list of callbacks to execute at the end of each fit; callback_step (Callback): callback to execute after each step; """ @@ -36,14 +36,17 @@ def __init__(self, agent, mdp, callbacks_fit=None, callback_step=None): self._n_steps_per_fit = None self._n_episodes_per_fit = None + if record_dictionary is None: + record_dictionary = dict() + self._record = self._build_recorder_class(**record_dictionary) + def learn(self, n_steps=None, n_episodes=None, n_steps_per_fit=None, - n_episodes_per_fit=None, render=False, quiet=False): + n_episodes_per_fit=None, render=False, quiet=False, record=False): """ - This function moves the agent in the environment and fits the policy - using the collected samples. The agent can be moved for a given number - of steps or a given number of episodes and, independently from this - choice, the policy can be fitted after a given number of steps or a - given number of episodes. By default, the environment is reset. + This function moves the agent in the environment and fits the policy using the collected samples. + The agent can be moved for a given number of steps or a given number of episodes and, independently from this + choice, the policy can be fitted after a given number of steps or a given number of episodes. + The environment is reset at the beginning of the learning process. Args: n_steps (int, None): number of steps to move the agent; @@ -53,88 +56,82 @@ def learn(self, n_steps=None, n_episodes=None, n_steps_per_fit=None, n_episodes_per_fit (int, None): number of episodes between each fit of the policy; render (bool, False): whether to render the environment or not; - quiet (bool, False): whether to show the progress bar or not. + quiet (bool, False): whether to show the progress bar or not; + record (bool, False): whether to record a video of the environment or not. If True, also the render flag + should be set to True. """ assert (n_episodes_per_fit is not None and n_steps_per_fit is None)\ or (n_episodes_per_fit is None and n_steps_per_fit is not None) + assert (render and record) or (not record), "To record, the render flag must be set to true" + self._n_steps_per_fit = n_steps_per_fit self._n_episodes_per_fit = n_episodes_per_fit if n_steps_per_fit is not None: - fit_condition =\ - lambda: self._current_steps_counter >= self._n_steps_per_fit + fit_condition = lambda: self._current_steps_counter >= self._n_steps_per_fit else: - fit_condition = lambda: self._current_episodes_counter\ - >= self._n_episodes_per_fit + fit_condition = lambda: self._current_episodes_counter >= self._n_episodes_per_fit - self._run(n_steps, n_episodes, fit_condition, render, quiet, get_env_info=False) + self._run(n_steps, n_episodes, fit_condition, render, quiet, record, get_env_info=False) def evaluate(self, initial_states=None, n_steps=None, n_episodes=None, - render=False, quiet=False, get_env_info=False): + render=False, quiet=False, record=False, get_env_info=False): """ This function moves the agent in the environment using its policy. - The agent is moved for a provided number of steps, episodes, or from - a set of initial states for the whole episode. By default, the - environment is reset. + The agent is moved for a provided number of steps, episodes, or from a set of initial states for the whole + episode. 
The environment is reset at the beginning of the learning process. Args: - initial_states (np.ndarray, None): the starting states of each - episode; + initial_states (np.ndarray, None): the starting states of each episode; n_steps (int, None): number of steps to move the agent; n_episodes (int, None): number of episodes to move the agent; render (bool, False): whether to render the environment or not; quiet (bool, False): whether to show the progress bar or not; - get_env_info (bool, False): whether to return the environment - info list or not. + record (bool, False): whether to record a video of the environment or not. If True, also the render flag + should be set to True; + get_env_info (bool, False): whether to return the environment info list or not. Returns: The collected dataset and, optionally, an extra dataset of environment info, collected at each step. """ + assert (render and record) or (not record), "To record, the render flag must be set to true" + fit_condition = lambda: False - return self._run(n_steps, n_episodes, fit_condition, render, quiet, get_env_info, - initial_states) + return self._run(n_steps, n_episodes, fit_condition, render, quiet, record, get_env_info, initial_states) - def _run(self, n_steps, n_episodes, fit_condition, render, quiet, get_env_info, - initial_states=None): + def _run(self, n_steps, n_episodes, fit_condition, render, quiet, record, get_env_info, initial_states=None): assert n_episodes is not None and n_steps is None and initial_states is None\ or n_episodes is None and n_steps is not None and initial_states is None\ or n_episodes is None and n_steps is None and initial_states is not None - self._n_episodes = len( - initial_states) if initial_states is not None else n_episodes + self._n_episodes = len( initial_states) if initial_states is not None else n_episodes if n_steps is not None: - move_condition =\ - lambda: self._total_steps_counter < n_steps + move_condition = lambda: self._total_steps_counter < n_steps - steps_progress_bar = tqdm(total=n_steps, - dynamic_ncols=True, disable=quiet, - leave=False) + steps_progress_bar = tqdm(total=n_steps, dynamic_ncols=True, disable=quiet, leave=False) episodes_progress_bar = tqdm(disable=True) else: - move_condition =\ - lambda: self._total_episodes_counter < self._n_episodes + move_condition = lambda: self._total_episodes_counter < self._n_episodes steps_progress_bar = tqdm(disable=True) - episodes_progress_bar = tqdm(total=self._n_episodes, - dynamic_ncols=True, disable=quiet, - leave=False) + episodes_progress_bar = tqdm(total=self._n_episodes, dynamic_ncols=True, disable=quiet, leave=False) - dataset, dataset_info = self._run_impl(move_condition, fit_condition, steps_progress_bar, - episodes_progress_bar, render, initial_states) + dataset, dataset_info = self._run_impl(move_condition, fit_condition, steps_progress_bar, episodes_progress_bar, + render, record, initial_states) if get_env_info: return dataset, dataset_info else: return dataset - def _run_impl(self, move_condition, fit_condition, steps_progress_bar, - episodes_progress_bar, render, initial_states): + def _run_impl(self, move_condition, fit_condition, steps_progress_bar, episodes_progress_bar, render, record, + initial_states): self._total_episodes_counter = 0 self._total_steps_counter = 0 self._current_episodes_counter = 0 @@ -148,7 +145,7 @@ def _run_impl(self, move_condition, fit_condition, steps_progress_bar, if last: self.reset(initial_states) - sample, step_info = self._step(render) + sample, step_info = self._step(render, record) 
self.callback_step([sample]) @@ -182,12 +179,15 @@ def _run_impl(self, move_condition, fit_condition, steps_progress_bar, self.agent.stop() self.mdp.stop() + if record: + self._record.stop() + steps_progress_bar.close() episodes_progress_bar.close() return dataset, dataset_info - def _step(self, render): + def _step(self, render, record): """ Single step. @@ -195,9 +195,8 @@ def _step(self, render): render (bool): whether to render or not. Returns: - A tuple containing the previous state, the action sampled by the - agent, the reward obtained, the reached state, the absorbing flag - of the reached state and the last step flag. + A tuple containing the previous state, the action sampled by the agent, the reward obtained, the reached + state, the absorbing flag of the reached state and the last step flag. """ action = self.agent.draw_action(self._state) @@ -206,7 +205,10 @@ def _step(self, render): self._episode_steps += 1 if render: - self.mdp.render() + frame = self.mdp.render(record) + + if record: + self._record(frame) last = not( self._episode_steps < self.mdp.info.horizon and not absorbing) @@ -222,8 +224,7 @@ def reset(self, initial_states=None): Reset the state of the agent. """ - if initial_states is None\ - or self._total_episodes_counter == self._n_episodes: + if initial_states is None or self._total_episodes_counter == self._n_episodes: initial_state = None else: initial_state = initial_states[self._total_episodes_counter] @@ -249,3 +250,24 @@ def _preprocess(self, state): state = p(state) return state + + def _build_recorder_class(self, recorder_class=None, fps=None, **kwargs): + """ + Method to create a video recorder class. + + Args: + recorder_class (class): the class used to record the video. By default, we use the ``VideoRecorder`` class + from mushroom. The class must implement the ``__call__`` and ``stop`` methods. + + Returns: + The recorder object. + + """ + + if not recorder_class: + recorder_class = VideoRecorder + + if not fps: + fps = int(1 / self.mdp.info.dt) + + return recorder_class(fps=fps, **kwargs) diff --git a/mushroom_rl/core/environment.py b/mushroom_rl/core/environment.py index 12e497f84..7a2fa485b 100644 --- a/mushroom_rl/core/environment.py +++ b/mushroom_rl/core/environment.py @@ -9,7 +9,7 @@ class MDPInfo(Serializable): This class is used to store the information of the environment. """ - def __init__(self, observation_space, action_space, gamma, horizon): + def __init__(self, observation_space, action_space, gamma, horizon, dt=1e-1): """ Constructor. @@ -17,27 +17,29 @@ def __init__(self, observation_space, action_space, gamma, horizon): observation_space ([Box, Discrete]): the state space; action_space ([Box, Discrete]): the action space; gamma (float): the discount factor; - horizon (int): the horizon. + horizon (int): the horizon; + dt (float, 1e-1): the control timestep of the environment. """ self.observation_space = observation_space self.action_space = action_space self.gamma = gamma self.horizon = horizon + self.dt = dt self._add_save_attr( observation_space='mushroom', action_space='mushroom', gamma='primitive', - horizon='primitive' + horizon='primitive', + dt='primitive' ) @property def size(self): """ Returns: - The sum of the number of discrete states and discrete actions. Only - works for discrete spaces. + The sum of the number of discrete states and discrete actions. Only works for discrete spaces. 
""" return self.observation_space.size + self.action_space.size @@ -46,8 +48,7 @@ def size(self): def shape(self): """ Returns: - The concatenation of the shape tuple of the state and action - spaces. + The concatenation of the shape tuple of the state and action spaces. """ return self.observation_space.shape + self.action_space.shape @@ -86,10 +87,9 @@ def make(env_name, *args, **kwargs): """ Generate an environment given an environment name and parameters. The environment is created using the generate method, if available. Otherwise, the constructor is used. - The generate method has a simpler interface than the constructor, making it easier to generate - a standard version of the environment. If the environment name contains a '.' separator, the string - is splitted, the first element is used to select the environment and the other elements are passed as - positional parameters. + The generate method has a simpler interface than the constructor, making it easier to generate a standard + version of the environment. If the environment name contains a '.' separator, the string is splitted, the first + element is used to select the environment and the other elements are passed as positional parameters. Args: env_name (str): Name of the environment, @@ -118,8 +118,7 @@ def __init__(self, mdp_info): Constructor. Args: - mdp_info (MDPInfo): an object containing the info of the - environment. + mdp_info (MDPInfo): an object containing the info of the environment. """ self._mdp_info = mdp_info @@ -135,8 +134,7 @@ def seed(self, seed): if hasattr(self, 'env') and hasattr(self.env, 'seed'): self.env.seed(seed) else: - warnings.warn('This environment has no custom seed. ' - 'The call will have no effect. ' + warnings.warn('This environment has no custom seed. The call will have no effect. ' 'You can set the seed manually by setting numpy/torch seed') def reset(self, state=None): @@ -160,21 +158,28 @@ def step(self, action): action (np.ndarray): the action to execute. Returns: - The state reached by the agent executing ``action`` in its current - state, the reward obtained in the transition and a flag to signal - if the next state is absorbing. Also, an additional dictionary is + The state reached by the agent executing ``action`` in its current state, the reward obtained in the + transition and a flag to signal if the next state is absorbing. Also, an additional dictionary is returned (possibly empty). """ raise NotImplementedError - def render(self): + def render(self, record=False): + """ + Args: + record (bool, False): whether the visualized image should be returned or not. + + Returns: + The visualized image, or None if the record flag is set to false. + + """ raise NotImplementedError def stop(self): """ - Method used to stop an mdp. Useful when dealing with real world - environments, simulators, or when using openai-gym rendering + Method used to stop an mdp. 
Useful when dealing with real world environments, simulators, or when using + openai-gym rendering """ pass diff --git a/mushroom_rl/environments/__init__.py b/mushroom_rl/environments/__init__.py index 0b4c53cba..c966a67d6 100644 --- a/mushroom_rl/environments/__init__.py +++ b/mushroom_rl/environments/__init__.py @@ -43,7 +43,7 @@ try: MuJoCo = None - from .mujoco import MuJoCo + from .mujoco import MuJoCo, MultiMuJoCo from .mujoco_envs import * except ImportError: pass diff --git a/mushroom_rl/environments/atari.py b/mushroom_rl/environments/atari.py index 659e2d387..c11321ea9 100644 --- a/mushroom_rl/environments/atari.py +++ b/mushroom_rl/environments/atari.py @@ -91,7 +91,8 @@ def __init__(self, name, width=84, height=84, ends_at_life=False, low=0., high=255., shape=(history_length, self._img_size[1], self._img_size[0])) horizon = np.inf # the gym time limit is used. gamma = .99 - mdp_info = MDPInfo(observation_space, action_space, gamma, horizon) + dt = 1/60 + mdp_info = MDPInfo(observation_space, action_space, gamma, horizon, dt) super().__init__(mdp_info) @@ -136,8 +137,13 @@ def step(self, action): return LazyFrames(list(self._state), self._history_length), reward, absorbing, info - def render(self, mode='human'): - self.env.render(mode=mode) + def render(self, record=False): + self.env.render(mode='human') + + if record: + return self.env.render(mode='rgb_array') + else: + return None def stop(self): self.env.close() diff --git a/mushroom_rl/environments/car_on_hill.py b/mushroom_rl/environments/car_on_hill.py index bf822d92a..32eacdc02 100644 --- a/mushroom_rl/environments/car_on_hill.py +++ b/mushroom_rl/environments/car_on_hill.py @@ -27,13 +27,13 @@ def __init__(self, horizon=100, gamma=.95): high = np.array([self.max_pos, self.max_velocity]) self._g = 9.81 self._m = 1. - self._dt = .1 self._discrete_actions = [-4., 4.] 
# MDP properties + dt = .1 observation_space = spaces.Box(low=-high, high=high) action_space = spaces.Discrete(2) - mdp_info = MDPInfo(observation_space, action_space, gamma, horizon) + mdp_info = MDPInfo(observation_space, action_space, gamma, horizon, dt) # Visualization self._viewer = Viewer(1, 1) @@ -51,7 +51,7 @@ def reset(self, state=None): def step(self, action): action = self._discrete_actions[action[0]] sa = np.append(self._state, action) - new_state = odeint(self._dpds, sa, [0, self._dt]) + new_state = odeint(self._dpds, sa, [0, self.info.dt]) self._state = new_state[-1, :-1] @@ -69,7 +69,7 @@ def step(self, action): return self._state, reward, absorbing, {} - def render(self): + def render(self, record=False): # Slope self._viewer.function(0, 1, self._height) @@ -91,7 +91,14 @@ def render(self): angle = self._angle(x_car) self._viewer.polygon(c_car, angle, car_body, color=(32, 193, 54)) - self._viewer.display(self._dt) + frame = self._viewer.get_frame() if record else None + + self._viewer.display(self.info.dt) + + return frame + + def stop(self): + self._viewer.close() @staticmethod def _angle(x): diff --git a/mushroom_rl/environments/cart_pole.py b/mushroom_rl/environments/cart_pole.py index 467e251b3..fe15afe37 100644 --- a/mushroom_rl/environments/cart_pole.py +++ b/mushroom_rl/environments/cart_pole.py @@ -36,15 +36,15 @@ def __init__(self, m=2., M=8., l=.5, g=9.8, mu=1e-2, max_u=50., noise_u=10., self._g = g self._alpha = 1 / (self._m + self._M) self._mu = mu - self._dt = .1 self._max_u = max_u self._noise_u = noise_u high = np.array([np.inf, np.inf]) # MDP properties + dt = .1 observation_space = spaces.Box(low=-high, high=high) action_space = spaces.Discrete(3) - mdp_info = MDPInfo(observation_space, action_space, gamma, horizon) + mdp_info = MDPInfo(observation_space, action_space, gamma, horizon, dt) # Visualization self._viewer = Viewer(2.5 * l, 2.5 * l) @@ -76,8 +76,7 @@ def step(self, action): self._last_u = u u += np.random.uniform(-self._noise_u, self._noise_u) - new_state = odeint(self._dynamics, self._state, [0, self._dt], - (u,)) + new_state = odeint(self._dynamics, self._state, [0, self.info.dt], (u,)) self._state = np.array(new_state[-1]) self._state[0] = normalize_angle(self._state[0]) @@ -91,7 +90,7 @@ def step(self, action): return self._state, reward, absorbing, {} - def render(self, mode='human'): + def render(self, record=False): start = 1.25 * self._l * np.ones(2) end = 1.25 * self._l * np.ones(2) @@ -107,7 +106,11 @@ def render(self, mode='human'): self._viewer.force_arrow(start, direction, value, self._max_u, self._l / 5) - self._viewer.display(self._dt) + frame = self._viewer.get_frame() if record else None + + self._viewer.display(self.info.dt) + + return frame def stop(self): self._viewer.close() diff --git a/mushroom_rl/environments/dm_control_env.py b/mushroom_rl/environments/dm_control_env.py index 4fc510b99..030d4977d 100644 --- a/mushroom_rl/environments/dm_control_env.py +++ b/mushroom_rl/environments/dm_control_env.py @@ -8,7 +8,7 @@ from mushroom_rl.core import Environment, MDPInfo from mushroom_rl.utils.spaces import * -from mushroom_rl.utils.viewer import ImageViewer +from mushroom_rl.utils.viewer import CV2Viewer class DMControl(Environment): @@ -62,9 +62,11 @@ def __init__(self, domain_name, task_name, horizon=None, gamma=0.99, task_kwargs # MDP properties action_space = self._convert_action_space(self.env.action_spec()) observation_space = self._convert_observation_space(self.env.observation_spec()) - mdp_info = 
MDPInfo(observation_space, action_space, gamma, horizon) + mdp_info = MDPInfo(observation_space, action_space, gamma, horizon, dt) - self._viewer = ImageViewer((width_screen, height_screen), dt) + self._height_screen = height_screen + self._width_screen = width_screen + self._viewer = CV2Viewer("dm_control", dt, self._width_screen, self._height_screen) self._camera_id = camera_id super().__init__(mdp_info) @@ -88,12 +90,18 @@ def step(self, action): return self._state, reward, absorbing, {} - def render(self): - img = self.env.physics.render(self._viewer.size[1], - self._viewer.size[0], + def render(self, record=False): + img = self.env.physics.render(self._height_screen, + self._width_screen, self._camera_id) + self._viewer.display(img) + if record: + return img + else: + return None + def stop(self): self._viewer.close() @@ -109,12 +117,11 @@ def _convert_observation_space_vector(observation_space): return Box(low=-np.inf, high=np.inf, shape=(observation_shape,)) - @staticmethod def _convert_observation_space_pixels(observation_space): img_size = observation_space['pixels'].shape - return Box(low=0., high=255., - shape=(3, img_size[0], img_size[1])) + + return Box(low=0., high=255., shape=(3, img_size[0], img_size[1])) @staticmethod def _convert_action_space(action_space): diff --git a/mushroom_rl/environments/finite_mdp.py b/mushroom_rl/environments/finite_mdp.py index 71dff6bec..b49995c00 100644 --- a/mushroom_rl/environments/finite_mdp.py +++ b/mushroom_rl/environments/finite_mdp.py @@ -9,7 +9,7 @@ class FiniteMDP(Environment): Finite Markov Decision Process. """ - def __init__(self, p, rew, mu=None, gamma=.9, horizon=np.inf): + def __init__(self, p, rew, mu=None, gamma=.9, horizon=np.inf, dt=1e-1): """ Constructor. @@ -18,7 +18,8 @@ def __init__(self, p, rew, mu=None, gamma=.9, horizon=np.inf): rew (np.ndarray): reward matrix; mu (np.ndarray, None): initial state probability distribution; gamma (float, .9): discount factor; - horizon (int, np.inf): the horizon. + horizon (int, np.inf): the horizon; + dt (float, 1e-1): the control timestep of the environment. """ assert p.shape == rew.shape @@ -34,7 +35,7 @@ def __init__(self, p, rew, mu=None, gamma=.9, horizon=np.inf): action_space = spaces.Discrete(p.shape[1]) horizon = horizon gamma = gamma - mdp_info = MDPInfo(observation_space, action_space, gamma, horizon) + mdp_info = MDPInfo(observation_space, action_space, gamma, horizon, dt) super().__init__(mdp_info) diff --git a/mushroom_rl/environments/grid_world.py b/mushroom_rl/environments/grid_world.py index 01cfb45b0..360aead01 100644 --- a/mushroom_rl/environments/grid_world.py +++ b/mushroom_rl/environments/grid_world.py @@ -23,8 +23,7 @@ def __init__(self, mdp_info, height, width, start, goal): """ assert not np.array_equal(start, goal) - assert goal[0] < height and goal[1] < width,\ - 'Goal position not suitable for the grid world dimension.' + assert goal[0] < height and goal[1] < width, 'Goal position not suitable for the grid world dimension.' 
self._state = None self._height = height @@ -54,7 +53,7 @@ def step(self, action): return self._state, reward, absorbing, info - def render(self): + def render(self, record=False): for row in range(1, self._height): for col in range(1, self._width): self._viewer.line(np.array([col, 0]), @@ -76,8 +75,15 @@ def render(self): self._height - (.5 + state_grid[0])]) self._viewer.circle(state_center, .4, (0, 0, 255)) + frame = self._viewer.get_frame() if record else None + self._viewer.display(.1) + return frame + + def stop(self): + self._viewer.close() + def _step(self, state, action): raise NotImplementedError('AbstractGridWorld is an abstract class.') @@ -110,13 +116,24 @@ class GridWorld(AbstractGridWorld): Standard grid world. """ - def __init__(self, height, width, goal, start=(0, 0)): + def __init__(self, height, width, goal, start=(0, 0), dt=0.1): + """ + Constructor + + Args: + height (int): height of the grid; + width (int): width of the grid; + goal (tuple): 2D coordinates of the goal state; + start (tuple, (0, 0)): 2D coordinates of the starting state; + dt (float, 0.1): the control timestep of the environment. + + """ # MDP properties observation_space = spaces.Discrete(height * width) action_space = spaces.Discrete(4) horizon = 100 gamma = .9 - mdp_info = MDPInfo(observation_space, action_space, gamma, horizon) + mdp_info = MDPInfo(observation_space, action_space, gamma, horizon, dt) super().__init__(mdp_info, height, width, start, goal) @@ -139,13 +156,24 @@ class GridWorldVanHasselt(AbstractGridWorld): "Double Q-Learning". Hasselt H. V.. 2010. """ - def __init__(self, height=3, width=3, goal=(0, 2), start=(2, 0)): + def __init__(self, height=3, width=3, goal=(0, 2), start=(2, 0), dt=0.1): + """ + Constructor + + Args: + height (int, 3): height of the grid; + width (int, 3): width of the grid; + goal (tuple, (0, 2)): 2D coordinates of the goal state; + start (tuple, (2, 0)): 2D coordinates of the starting state; + dt (float, 0.1): the control timestep of the environment. 
+ + """ # MDP properties observation_space = spaces.Discrete(height * width) action_space = spaces.Discrete(4) horizon = np.inf gamma = .95 - mdp_info = MDPInfo(observation_space, action_space, gamma, horizon) + mdp_info = MDPInfo(observation_space, action_space, gamma, horizon, dt) super().__init__(mdp_info, height, width, start, goal) diff --git a/mushroom_rl/environments/gym_env.py b/mushroom_rl/environments/gym_env.py index db2e9bd90..66f3e12bf 100644 --- a/mushroom_rl/environments/gym_env.py +++ b/mushroom_rl/environments/gym_env.py @@ -53,8 +53,6 @@ def __init__(self, name, horizon=None, gamma=0.99, wrappers=None, wrappers_args= self.env = gym.make(name, **env_args) - self._render_dt = self.env.unwrapped.dt if hasattr(self.env.unwrapped, "dt") else 0.0 - if wrappers is not None: if wrappers_args is None: wrappers_args = [dict()] * len(wrappers) @@ -71,9 +69,10 @@ def __init__(self, name, horizon=None, gamma=0.99, wrappers=None, wrappers_args= gym_spaces.MultiDiscrete) assert not isinstance(self.env.action_space, gym_spaces.MultiDiscrete) + dt = self.env.unwrapped.dt if hasattr(self.env.unwrapped, "dt") else 0.1 action_space = self._convert_gym_space(self.env.action_space) observation_space = self._convert_gym_space(self.env.observation_space) - mdp_info = MDPInfo(observation_space, action_space, gamma, horizon) + mdp_info = MDPInfo(observation_space, action_space, gamma, horizon, dt) if isinstance(action_space, Discrete): self._convert_action = lambda a: a[0] @@ -97,11 +96,19 @@ def step(self, action): return np.atleast_1d(obs), reward, absorbing, info - def render(self, mode='human'): + def render(self, record=False): if self._first or self._not_pybullet: - self.env.render(mode=mode) + self.env.render(mode='human') + self._first = False - time.sleep(self._render_dt) + time.sleep(self.info.dt) + + if record: + return self.env.render(mode='rgb_array') + else: + return None + + return None def stop(self): try: diff --git a/mushroom_rl/environments/habitat_env.py b/mushroom_rl/environments/habitat_env.py index 48b48006d..685550517 100644 --- a/mushroom_rl/environments/habitat_env.py +++ b/mushroom_rl/environments/habitat_env.py @@ -233,17 +233,18 @@ def __init__(self, wrapper, config_file, base_config_file=None, horizon=None, ga self._img_size = env.observation_space.shape[0:2] # MDP properties + dt = 1/10 action_space = self.env.action_space observation_space = Box( low=0., high=255., shape=(3, self._img_size[1], self._img_size[0])) - mdp_info = MDPInfo(observation_space, action_space, gamma, horizon) + mdp_info = MDPInfo(observation_space, action_space, gamma, horizon, dt) if isinstance(action_space, Discrete): self._convert_action = lambda a: a[0] else: self._convert_action = lambda a: a - self._viewer = ImageViewer((self._img_size[1], self._img_size[0]), 1/10) + self._viewer = ImageViewer((self._img_size[1], self._img_size[0]), dt) Environment.__init__(self, mdp_info) @@ -260,16 +261,18 @@ def step(self, action): def stop(self): self._viewer.close() - def render(self, mode='rgb_array'): - if mode == "rgb_array": - frame = observations_to_image( - self.env._last_full_obs, self.env.unwrapped._env.get_metrics() - ) - else: - raise ValueError(f"Render mode {mode} not currently supported.") + def render(self, record=False): + frame = observations_to_image( + self.env._last_full_obs, self.env.unwrapped._env.get_metrics() + ) self._viewer.display(frame) + if record: + return frame + else: + return None + @staticmethod def _convert_observation(observation): return observation.transpose((2, 
0, 1)) diff --git a/mushroom_rl/environments/igibson_env.py b/mushroom_rl/environments/igibson_env.py index fb681afee..a9d7b80e6 100644 --- a/mushroom_rl/environments/igibson_env.py +++ b/mushroom_rl/environments/igibson_env.py @@ -95,17 +95,18 @@ def __init__(self, config_file, horizon=None, gamma=0.99, is_discrete=False, self._img_size = env.observation_space.shape[0:2] # MDP properties + dt = 1/60 action_space = self.env.action_space observation_space = Box( low=0., high=255., shape=(3, self._img_size[1], self._img_size[0])) - mdp_info = MDPInfo(observation_space, action_space, gamma, horizon) + mdp_info = MDPInfo(observation_space, action_space, gamma, horizon, dt) if isinstance(action_space, Discrete): self._convert_action = lambda a: a[0] else: self._convert_action = lambda a: a - self._viewer = ImageViewer((self._img_size[1], self._img_size[0]), 1/60) + self._viewer = ImageViewer((self._img_size[1], self._img_size[0]), dt) self._image = None Environment.__init__(self, mdp_info) @@ -126,9 +127,14 @@ def close(self): def stop(self): self._viewer.close() - def render(self, mode='human'): + def render(self, record=False): self._viewer.display(self._image) + if record: + return self._image + else: + return None + @staticmethod def _convert_observation(observation): return observation.transpose((2, 0, 1)) diff --git a/mushroom_rl/environments/inverted_pendulum.py b/mushroom_rl/environments/inverted_pendulum.py index a80023da8..3fc7ed1b9 100644 --- a/mushroom_rl/environments/inverted_pendulum.py +++ b/mushroom_rl/environments/inverted_pendulum.py @@ -38,15 +38,15 @@ def __init__(self, random_start=False, m=1., l=1., g=9.8, mu=1e-2, self._g = g self._mu = mu self._random = random_start - self._dt = .01 self._max_u = max_u self._max_omega = 5 / 2 * np.pi high = np.array([np.pi, self._max_omega]) # MDP properties + dt = .01 observation_space = spaces.Box(low=-high, high=high) action_space = spaces.Box(low=np.array([-max_u]), high=np.array([max_u])) - mdp_info = MDPInfo(observation_space, action_space, gamma, horizon) + mdp_info = MDPInfo(observation_space, action_space, gamma, horizon, dt) # Visualization self._viewer = Viewer(2.5 * l, 2.5 * l) @@ -73,7 +73,7 @@ def reset(self, state=None): def step(self, action): u = self._bound(action[0], -self._max_u, self._max_u) - new_state = odeint(self._dynamics, self._state, [0, self._dt], args=(u.item(),)) + new_state = odeint(self._dynamics, self._state, [0, self.info.dt], args=(u.item(),)) self._state = np.array(new_state[-1]) self._state[0] = normalize_angle(self._state[0]) @@ -85,7 +85,7 @@ def step(self, action): return self._state, reward, False, {} - def render(self, mode='human'): + def render(self, record=False): start = 1.25 * self._l * np.ones(2) end = 1.25 * self._l * np.ones(2) @@ -97,13 +97,16 @@ def render(self, mode='human'): self._viewer.circle(end, self._l / 20) self._viewer.torque_arrow(start, -self._last_u, self._max_u, self._l / 5) - self._viewer.display(self._dt) + frame = self._viewer.get_frame() if record else None + + self._viewer.display(self.info.dt) + + return frame def stop(self): self._viewer.close() def _dynamics(self, state, t, u): - print(t) theta = state[0] omega = self._bound(state[1], -self._max_omega, self._max_omega) diff --git a/mushroom_rl/environments/lqr.py b/mushroom_rl/environments/lqr.py index 43949338e..8e317d693 100644 --- a/mushroom_rl/environments/lqr.py +++ b/mushroom_rl/environments/lqr.py @@ -25,9 +25,8 @@ class LQR(Environment): Parisi S., Pirotta M., Smacchia N., Bascetta L., Restelli M.. 
2014 """ - def __init__(self, A, B, Q, R, max_pos=np.inf, max_action=np.inf, - random_init=False, episodic=False, gamma=0.9, horizon=50, - initial_state=None): + def __init__(self, A, B, Q, R, max_pos=np.inf, max_action=np.inf, random_init=False, episodic=False, gamma=0.9, + horizon=50, initial_state=None, dt=0.1): """ Constructor. @@ -42,7 +41,8 @@ def __init__(self, A, B, Q, R, max_pos=np.inf, max_action=np.inf, episodic (bool, False): end the episode when the state goes over the threshold; gamma (float, 0.9): discount factor; - horizon (int, 50): horizon of the mdp. + horizon (int, 50): horizon of the mdp; + dt (float, 0.1): the control timestep of the environment. """ self.A = A @@ -65,7 +65,9 @@ def __init__(self, A, B, Q, R, max_pos=np.inf, max_action=np.inf, observation_space = spaces.Box(low=low_x, high=high_x) action_space = spaces.Box(low=low_u, high=high_u) - mdp_info = MDPInfo(observation_space, action_space, gamma, horizon) + mdp_info = MDPInfo(observation_space, action_space, gamma, horizon, dt) + + self._state = None super().__init__(mdp_info) @@ -143,8 +145,7 @@ def step(self, action): reward = -self._max_pos ** 2 * 10 absorbing = True else: - self._state = self._bound(self._state, - self.info.observation_space.low, + self._state = self._bound(self._state, self.info.observation_space.low, self.info.observation_space.high) return self._state, reward, absorbing, {} diff --git a/mushroom_rl/environments/minigrid_env.py b/mushroom_rl/environments/minigrid_env.py index 0c0a55858..acbb9084e 100644 --- a/mushroom_rl/environments/minigrid_env.py +++ b/mushroom_rl/environments/minigrid_env.py @@ -15,6 +15,7 @@ from mushroom_rl.utils.spaces import Discrete, Box from mushroom_rl.utils.frames import LazyFrames, preprocess_frame + class MiniGrid(Gym): """ Interface for gym_minigrid environments. It makes it possible to @@ -68,7 +69,8 @@ def __init__(self, name, horizon=None, gamma=0.99, history_length=4, observation_space = Box( low=0., high=obs_high, shape=(history_length, self._img_size[1], self._img_size[0])) self.env.max_steps = horizon + 1 # Hack to ignore gym time limit (do not use np.inf, since MiniGrid returns r(t) = 1 - 0.9t/T) - mdp_info = MDPInfo(observation_space, action_space, gamma, horizon) + dt = 1/self.env.unwrapped.metadata["render_fps"] + mdp_info = MDPInfo(observation_space, action_space, gamma, horizon, dt) Environment.__init__(self, mdp_info) diff --git a/mushroom_rl/environments/mujoco.py b/mushroom_rl/environments/mujoco.py index 2e3b28072..cbc59802f 100644 --- a/mushroom_rl/environments/mujoco.py +++ b/mushroom_rl/environments/mujoco.py @@ -1,5 +1,6 @@ import mujoco import numpy as np +from dm_control import mjcf from mushroom_rl.core import Environment, MDPInfo from mushroom_rl.utils.spaces import Box from mushroom_rl.utils.mujoco import * @@ -10,15 +11,14 @@ class MuJoCo(Environment): Class to create a Mushroom environment using the MuJoCo simulator. """ - def __init__(self, file_name, actuation_spec, observation_spec, gamma, horizon, timestep=None, n_substeps=1, + def __init__(self, xml_file, actuation_spec, observation_spec, gamma, horizon, timestep=None, n_substeps=1, n_intermediate_steps=1, additional_data_spec=None, collision_groups=None, max_joint_vel=None, **viewer_params): """ Constructor. Args: - file_name (string): The path to the XML file with which the - environment should be created; + xml_file (str/xml handle): A string with a path to the xml or an Mujoco xml handle. 
actuation_spec (list): A list specifying the names of the joints which should be controllable by the agent. Can be left empty when all actuators should be used; @@ -56,11 +56,11 @@ def __init__(self, file_name, actuation_spec, observation_spec, gamma, horizon, The list has to define a maximum velocity for every occurrence of JOINT_VEL in the observation_spec. The velocity will not be limited in mujoco **viewer_params: other parameters to be passed to the viewer. - See MujocoGlfwViewer documentation for the available options. + See MujocoViewer documentation for the available options. """ # Create the simulation - self._model = mujoco.MjModel.from_xml_path(file_name) + self._model = self.load_model(xml_file) if timestep is not None: self._model.opt.timestep = timestep self._timestep = timestep @@ -77,23 +77,9 @@ def __init__(self, file_name, actuation_spec, observation_spec, gamma, horizon, # Read the actuation spec and build the mapping between actions and ids # as well as their limits - if len(actuation_spec) == 0: - self._action_indices = [i for i in range(0, len(self._data.actuator_force))] - else: - self._action_indices = [] - for name in actuation_spec: - self._action_indices.append(self._model.actuator(name).id) + self._action_indices = self.get_action_indices(self._model, self._data, actuation_spec) - low = [] - high = [] - for index in self._action_indices: - if self._model.actuator_ctrllimited[index]: - low.append(self._model.actuator_ctrlrange[index][0]) - high.append(self._model.actuator_ctrlrange[index][1]) - else: - low.append(-np.inf) - high.append(np.inf) - action_space = Box(np.array(low), np.array(high)) + action_space = self.get_action_space(self._action_indices, self._model) # Read the observation spec to build a mapping at every step. It is # ensured that the values appear in the order they are specified. @@ -112,18 +98,29 @@ def __init__(self, file_name, actuation_spec, observation_spec, gamma, horizon, self.collision_groups = {} if collision_groups is not None: for name, geom_names in collision_groups: - self.collision_groups[name] = {mujoco.mj_name2id(self._model, mujoco.mjtObj.mjOBJ_GEOM, geom_name) - for geom_name in geom_names} + col_group = list() + for geom_name in geom_names: + mj_id = mujoco.mj_name2id(self._model, mujoco.mjtObj.mjOBJ_GEOM, geom_name) + assert mj_id != -1, f"geom \"{geom_name}\" not found! Can't be used for collision-checking." + col_group.append(mj_id) + self.collision_groups[name] = set(col_group) # Finally, we create the MDP information and call the constructor of # the parent class - mdp_info = MDPInfo(observation_space, action_space, gamma, horizon) + mdp_info = MDPInfo(observation_space, action_space, gamma, horizon, self.dt) mdp_info = self._modify_mdp_info(mdp_info) # set the warning callback to stop the simulation when a mujoco warning occurs mujoco.set_mju_user_warning(self.user_warning_raise_exception) + # check whether the function compute_action was overridden or not. If yes, we want to compute + # the action at simulation frequency, if not we do it at control frequency. 
+ if type(self)._compute_action == MuJoCo._compute_action: + self._recompute_action_per_step = False + else: + self._recompute_action_per_step = True + super().__init__(mdp_info) def seed(self, seed): @@ -143,10 +140,13 @@ def step(self, action): self._step_init(cur_obs, action) + ctrl_action = None + for i in range(self._n_intermediate_steps): - ctrl_action = self._compute_action(cur_obs, action) - self._data.ctrl[self._action_indices] = ctrl_action + if self._recompute_action_per_step or ctrl_action is None: + ctrl_action = self._compute_action(cur_obs, action) + self._data.ctrl[self._action_indices] = ctrl_action self._simulation_pre_step() @@ -154,6 +154,10 @@ def step(self, action): self._simulation_post_step() + if self._recompute_action_per_step: + cur_obs = self._create_observation(self.obs_helper._build_obs(self._data)) + + if not self._recompute_action_per_step: cur_obs = self._create_observation(self.obs_helper._build_obs(self._data)) self._step_finalize() @@ -166,11 +170,11 @@ def step(self, action): return self._modify_observation(cur_obs), reward, absorbing, info - def render(self): + def render(self, record=False): if self._viewer is None: - self._viewer = MujocoGlfwViewer(self._model, self.dt, **self._viewer_params) + self._viewer = MujocoViewer(self._model, self.dt, record=record, **self._viewer_params) - self._viewer.render(self._data) + return self._viewer.render(self._data, record) def stop(self): if self._viewer is not None: @@ -436,7 +440,6 @@ def setup(self, obs): if obs is not None: self.obs_helper._modify_data(self._data, obs) - def get_all_observation_keys(self): """ A function that returns all observation keys defined in the observation specification. @@ -451,6 +454,55 @@ def get_all_observation_keys(self): def dt(self): return self._timestep * self._n_intermediate_steps * self._n_substeps + @staticmethod + def get_action_indices(model, data, actuation_spec): + """ + Returns the action indices given the MuJoCo model, data, and actuation_spec. + + Args: + model: MuJoCo model. + data: MuJoCo data structure. + actuation_spec (list): A list specifying the names of the joints + which should be controllable by the agent. Can be left empty + when all actuators should be used; + + Returns: + A list of actuator indices. + + """ + if len(actuation_spec) == 0: + action_indices = [i for i in range(0, len(data.actuator_force))] + else: + action_indices = [] + for name in actuation_spec: + action_indices.append(model.actuator(name).id) + return action_indices + + @staticmethod + def get_action_space(action_indices, model): + """ + Returns the action space bounding box given the action_indices and the model. + + Args: + action_indices (list): A list of actuator indices. + model: MuJoCo model. + + Returns: + A bounding box for the action space. + + """ + low = [] + high = [] + for index in action_indices: + if model.actuator_ctrllimited[index]: + low.append(model.actuator_ctrlrange[index][0]) + high.append(model.actuator_ctrlrange[index][1]) + else: + low.append(-np.inf) + high.append(np.inf) + action_space = Box(np.array(low), np.array(high)) + return action_space + @staticmethod def user_warning_raise_exception(warning): """ @@ -468,3 +520,224 @@ def user_warning_raise_exception(warning): raise RuntimeError(warning + 'Check for NaN in simulation.') else: raise RuntimeError('Got MuJoCo Warning: ' + warning) + + @staticmethod + def load_model(xml_file): + """ + Takes an xml_file and compiles and loads the model. 
+ + Args: + xml_file (str/xml handle): A string with a path to the xml or an Mujoco xml handle. + + Returns: + Mujoco model. + + """ + if type(xml_file) == mjcf.element.RootElement: + # load from xml handle + model = mujoco.MjModel.from_xml_string(xml=xml_file.to_xml_string(), + assets=xml_file.get_assets()) + elif type(xml_file) == str: + # load from path + model = mujoco.MjModel.from_xml_path(xml_file) + else: + raise ValueError(f"Unsupported type for xml_file {type(xml_file)}.") + + return model + + +class MultiMuJoCo(MuJoCo): + """ + Class to create N environments at the same time using the MuJoCo simulator. This class is not meant to run + N environments in parallel, but to load and create N environments, and randomly sample one of the + environment every episode. + + """ + + def __init__(self, xml_files, actuation_spec, observation_spec, gamma, horizon, timestep=None, + n_substeps=1, n_intermediate_steps=1, additional_data_spec=None, collision_groups=None, + max_joint_vel=None, random_env_reset=True, **viewer_params): + """ + Constructor. + + Args: + xml_files (str/xml handle): A list containing strings with a path to the xml or Mujoco xml handles; + actuation_spec (list): A list specifying the names of the joints + which should be controllable by the agent. Can be left empty + when all actuators should be used; + observation_spec (list): A list containing the names of data that + should be made available to the agent as an observation and + their type (ObservationType). They are combined with a key, + which is used to access the data. An entry in the list + is given by: (key, name, type); + gamma (float): The discounting factor of the environment; + horizon (int): The maximum horizon for the environment; + timestep (float): The timestep used by the MuJoCo + simulator. If None, the default timestep specified in the XML will be used; + n_substeps (int, 1): The number of substeps to use by the MuJoCo + simulator. An action given by the agent will be applied for + n_substeps before the agent receives the next observation and + can act accordingly; + n_intermediate_steps (int, 1): The number of steps between every action + taken by the agent. Similar to n_substeps but allows the user + to modify, control and access intermediate states. + additional_data_spec (list, None): A list containing the data fields of + interest, which should be read from or written to during + simulation. The entries are given as the following tuples: + (key, name, type) key is a string for later referencing in the + "read_data" and "write_data" methods. The name is the name of + the object in the XML specification and the type is the + ObservationType; + collision_groups (list, None): A list containing groups of geoms for + which collisions should be checked during simulation via + ``check_collision``. The entries are given as: + ``(key, geom_names)``, where key is a string for later + referencing in the "check_collision" method, and geom_names is + a list of geom names in the XML specification. + max_joint_vel (list, None): A list with the maximum joint velocities which are provided in the mdp_info. + The list has to define a maximum velocity for every occurrence of JOINT_VEL in the observation_spec. The + velocity will not be limited in mujoco. + random_env_reset (bool): If True, a random environment/model is chosen after each episode. If False, it is + sequentially iterated through the environment/model list. + **viewer_params: other parameters to be passed to the viewer. 
+ See MujocoViewer documentation for the available options. + + """ + # Create the simulation + self._random_env_reset = random_env_reset + self._models = [self.load_model(f) for f in xml_files] + + self._current_model_idx = 0 + self._model = self._models[self._current_model_idx] + if timestep is not None: + self._model.opt.timestep = timestep + self._timestep = timestep + else: + self._timestep = self._model.opt.timestep + + self._datas = [mujoco.MjData(m) for m in self._models] + self._data = self._datas[self._current_model_idx] + + self._n_intermediate_steps = n_intermediate_steps + self._n_substeps = n_substeps + self._viewer_params = viewer_params + self._viewer = None + self._obs = None + + # Read the actuation spec and build the mapping between actions and ids + # as well as their limits + self._action_indices = self.get_action_indices(self._model, self._data, actuation_spec) + + action_space = self.get_action_space(self._action_indices, self._model) + + # all env need to have the same action space, do sanity check + for m, d in zip(self._models, self._datas): + action_ind = self.get_action_indices(m, d, actuation_spec) + action_sp = self.get_action_space(action_ind, m) + if not np.array_equal(action_ind, self._action_indices) or \ + not np.array_equal(action_space.low, action_sp.low) or\ + not np.array_equal(action_space.high, action_sp.high): + raise ValueError("The provided environments differ in the their action spaces. " + "This is not allowed.") + + # Read the observation spec to build a mapping at every step. It is + # ensured that the values appear in the order they are specified. + self.obs_helpers = [ObservationHelper(observation_spec, self._model, self._data, + max_joint_velocity=max_joint_vel) + for m, d in zip(self._models, self._datas)] + self.obs_helper = self.obs_helpers[self._current_model_idx] + + observation_space = Box(*self.obs_helper.get_obs_limits()) + + # multi envs with different obs limits are now allowed, do sanity check + for oh in self.obs_helpers: + low, high = self.obs_helper.get_obs_limits() + if not np.array_equal(low, observation_space.low) or not np.array_equal(high, observation_space.high): + raise ValueError("The provided environments differ in the their observation limits. " + "This is not allowed.") + + # Pre-process the additional data to allow easier writing and reading + # to and from arrays in MuJoCo + self.additional_data = {} + if additional_data_spec is not None: + for key, name, ot in additional_data_spec: + self.additional_data[key] = (name, ot) + + # Pre-process the collision groups for "fast" detection of contacts + self.collision_groups = {} + if collision_groups is not None: + for name, geom_names in collision_groups: + col_group = list() + for geom_name in geom_names: + mj_id = mujoco.mj_name2id(self._model, mujoco.mjtObj.mjOBJ_GEOM, geom_name) + assert mj_id != -1, f"geom \"{geom_name}\" not found! Can't be used for collision-checking." + col_group.append(mj_id) + self.collision_groups[name] = set(col_group) + + # Finally, we create the MDP information and call the constructor of + # the parent class + mdp_info = MDPInfo(observation_space, action_space, gamma, horizon, self.dt) + + mdp_info = self._modify_mdp_info(mdp_info) + + # set the warning callback to stop the simulation when a mujoco warning occurs + mujoco.set_mju_user_warning(self.user_warning_raise_exception) + + # check whether the function compute_action was overridden or not. 
If yes, we want to compute + # the action at simulation frequency, if not we do it at control frequency. + if type(self)._compute_action == MuJoCo._compute_action: + self._recompute_action_per_step = False + else: + self._recompute_action_per_step = True + + # call grad-parent class, not MuJoCo + super(MuJoCo, self).__init__(mdp_info) + + def reset(self, obs=None): + mujoco.mj_resetData(self._model, self._data) + + if self._random_env_reset: + self._current_model_idx = np.random.randint(0, len(self._models)) + else: + self._current_model_idx = self._current_model_idx + 1 \ + if self._current_model_idx < len(self._models) - 1 else 0 + + self._model = self._models[self._current_model_idx] + self._data = self._datas[self._current_model_idx] + self.obs_helper = self.obs_helpers[self._current_model_idx] + self.setup(obs) + + if self._viewer is not None and self.more_than_one_env: + self._viewer.load_new_model(self._model) + + self._obs = self._create_observation(self.obs_helper._build_obs(self._data)) + return self._modify_observation(self._obs) + + @property + def more_than_one_env(self): + return len(self._models) > 1 + + @staticmethod + def _get_env_id_map(current_model_idx, n_models): + """ + Retuns a binary vector to identify environment. This can be passed to the observation space. + + Args: + current_model_idx (int): index of the current model. + n_models (int): total number of models. + + Returns: + ndarray containing a binary vector identifying the current environment. + + """ + n_models = np.maximum(n_models, 2) + bits_needed = 1+int(np.log((n_models-1))/np.log(2)) + id_mask = np.zeros(bits_needed) + bin_rep = np.binary_repr(current_model_idx)[::-1] + for i, b in enumerate(bin_rep): + idx = bits_needed - 1 - i # reverse idx + if int(b): + id_mask[idx] = 1.0 + else: + id_mask[idx] = 0.0 + return id_mask diff --git a/mushroom_rl/environments/mujoco_envs/air_hockey/base.py b/mushroom_rl/environments/mujoco_envs/air_hockey/base.py index e44edfad8..06de5564f 100644 --- a/mushroom_rl/environments/mujoco_envs/air_hockey/base.py +++ b/mushroom_rl/environments/mujoco_envs/air_hockey/base.py @@ -15,7 +15,8 @@ class AirHockeyBase(MuJoCo): """ def __init__(self, n_agents=1, env_noise=False, obs_noise=False, gamma=0.99, horizon=500, - timestep=1 / 240., n_substeps=1, n_intermediate_steps=1, **viewer_params): + timestep=1 / 240., n_substeps=1, n_intermediate_steps=1, default_camera_mode="top_static", + **viewer_params): """ Constructor. 
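As a small worked example of the _get_env_id_map helper above (the index values are arbitrary, and the import assumes the MuJoCo extras are installed), the current model index is packed into a fixed-length binary vector that subclasses can append to their observations:

from mushroom_rl.environments.mujoco import MultiMuJoCo

# With 5 models, 1 + int(log(4) / log(2)) = 3 bits are needed
print(MultiMuJoCo._get_env_id_map(0, 5))  # [0. 0. 0.]
print(MultiMuJoCo._get_env_id_map(3, 5))  # [0. 1. 1.]
print(MultiMuJoCo._get_env_id_map(4, 5))  # [1. 0. 0.]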
@@ -113,7 +114,7 @@ def __init__(self, n_agents=1, env_noise=False, obs_noise=False, gamma=0.99, hor max_joint_vel = np.concatenate([spec['max_joint_vel'] for spec in self.agents]) super().__init__(scene, action_spec, observation_spec, gamma, horizon, timestep, n_substeps, n_intermediate_steps, additional_data, collision_spec, max_joint_vel, - **viewer_params) + default_camera_mode=default_camera_mode, **viewer_params) # Get the transformations from table to robot coordinate system for i, agent_spec in enumerate(self.agents): diff --git a/mushroom_rl/environments/mujoco_envs/air_hockey/defend.py b/mushroom_rl/environments/mujoco_envs/air_hockey/defend.py index 2caacc739..d86e277ea 100644 --- a/mushroom_rl/environments/mujoco_envs/air_hockey/defend.py +++ b/mushroom_rl/environments/mujoco_envs/air_hockey/defend.py @@ -102,29 +102,3 @@ def is_absorbing(self, state): if (self.has_hit or self.has_bounce) and puck_pos_x > 0: return True return super().is_absorbing(state) - - -if __name__ == '__main__': - env = AirHockeyDefend(obs_noise=False, n_intermediate_steps=4, random_init=True) - - R = 0. - J = 0. - gamma = 1. - steps = 0 - env.reset() - while True: - action = np.zeros(3) - observation, reward, done, info = env.step(action) - env.render() - gamma *= env.info.gamma - J += gamma * reward - R += reward - steps += 1 - - if done or steps > env.info.horizon: - print("J: ", J, " R: ", R) - R = 0. - J = 0. - gamma = 1. - steps = 0 - env.reset() diff --git a/mushroom_rl/environments/mujoco_envs/air_hockey/hit.py b/mushroom_rl/environments/mujoco_envs/air_hockey/hit.py index 44614eafb..d1c55b029 100644 --- a/mushroom_rl/environments/mujoco_envs/air_hockey/hit.py +++ b/mushroom_rl/environments/mujoco_envs/air_hockey/hit.py @@ -102,30 +102,6 @@ def reward(self, state, action, next_state, absorbing): r = 2 * r_hit + 10 * r_goal r -= self.action_penalty * np.linalg.norm(action) - return r + return r -if __name__ == '__main__': - env = AirHockeyHit(env_noise=False, obs_noise=False, n_intermediate_steps=4, random_init=True, - init_robot_state="right") - - env.reset() - R = 0. - J = 0. - gamma = 1. - steps = 0 - while True: - action = np.random.randn(3) * 5 - observation, reward, done, info = env.step(action) - env.render() - gamma *= env.info.gamma - J += gamma * reward - R += reward - steps += 1 - if done or steps > env.info.horizon: - print("J: ", J, " R: ", R) - R = 0. - J = 0. - gamma = 1. - steps = 0 - env.reset() diff --git a/mushroom_rl/environments/mujoco_envs/air_hockey/prepare.py b/mushroom_rl/environments/mujoco_envs/air_hockey/prepare.py index f7ae803a4..70803bb1a 100644 --- a/mushroom_rl/environments/mujoco_envs/air_hockey/prepare.py +++ b/mushroom_rl/environments/mujoco_envs/air_hockey/prepare.py @@ -133,29 +133,3 @@ def is_absorbing(self, state): if puck_pos[0] > 0 or abs(puck_pos[1]) < 0.01: return True return False - - -if __name__ == '__main__': - env = AirHockeyPrepare(obs_noise=False, n_intermediate_steps=4, random_init=True, - sub_problem="bottom") - - R = 0. - J = 0. - gamma = 1. - steps = 0 - env.reset() - while True: - action = np.random.randn(3) * 5 - observation, reward, done, info = env.step(action) - env.render() - gamma *= env.info.gamma - J += gamma * reward - R += reward - steps += 4 - if done or steps > env.info.horizon * 2: - print("J: ", J, " R: ", R) - R = 0. - J = 0. - gamma = 1. 
- steps = 0 - env.reset() diff --git a/mushroom_rl/environments/mujoco_envs/air_hockey/repel.py b/mushroom_rl/environments/mujoco_envs/air_hockey/repel.py index a81569174..acd19b93b 100644 --- a/mushroom_rl/environments/mujoco_envs/air_hockey/repel.py +++ b/mushroom_rl/environments/mujoco_envs/air_hockey/repel.py @@ -101,29 +101,3 @@ def is_absorbing(self, state): if super().is_absorbing(state): return True return self.has_bounce - - -if __name__ == "__main__": - env = AirHockeyRepel(env_noise=False, obs_noise=False, n_intermediate_steps=4, - random_init=True) - - R = 0. - J = 0. - gamma = 1. - steps = 0 - env.reset() - while True: - action = np.zeros(3) - observation, reward, done, info = env.step(action) - env.render() - gamma *= env.info.gamma - J += gamma * reward - R += reward - steps += 1 - if done or steps > env.info.horizon: - print("J: ", J, " R: ", R) - R = 0. - J = 0. - gamma = 1. - steps = 0 - env.reset() diff --git a/mushroom_rl/environments/puddle_world.py b/mushroom_rl/environments/puddle_world.py index bc9738087..140623572 100644 --- a/mushroom_rl/environments/puddle_world.py +++ b/mushroom_rl/environments/puddle_world.py @@ -51,9 +51,10 @@ def __init__(self, start=None, goal=None, goal_threshold=.1, noise_step=.025, self._actions[i][i // 2] = thrust * (i % 2 * 2 - 1) # MDP properties + dt = 0.1 action_space = Discrete(5) observation_space = Box(0., 1., shape=(2,)) - mdp_info = MDPInfo(observation_space, action_space, gamma, horizon) + mdp_info = MDPInfo(observation_space, action_space, gamma, horizon, dt) # Visualization self._pixels = None @@ -86,7 +87,7 @@ def step(self, action): return self._state, reward, absorbing, {} - def render(self): + def render(self, record=False): if self._pixels is None: img_size = 100 pixels = np.zeros((img_size, img_size, 3)) @@ -114,7 +115,11 @@ def render(self): self._viewer.polygon(self._goal, 0, goal_area, color=(255, 0, 0), width=1) - self._viewer.display(0.1) + frame = self._viewer.get_frame() if record else None + + self._viewer.display(self.info.dt) + + return frame def stop(self): if self._viewer is not None: diff --git a/mushroom_rl/environments/pybullet.py b/mushroom_rl/environments/pybullet.py index ff4ad3bae..11931ae88 100644 --- a/mushroom_rl/environments/pybullet.py +++ b/mushroom_rl/environments/pybullet.py @@ -73,7 +73,7 @@ def __init__(self, files, actuation_spec, observation_spec, gamma, action_space = Box(*self._indexer.action_limits) observation_space = Box(*self._indexer.observation_limits) - mdp_info = MDPInfo(observation_space, action_space, gamma, horizon) + mdp_info = MDPInfo(observation_space, action_space, gamma, horizon, self.dt) # Let the child class modify the mdp_info data structure mdp_info = self._modify_mdp_info(mdp_info) @@ -95,8 +95,13 @@ def reset(self, state=None): return observation - def render(self): - self._viewer.display() + def render(self, record=False): + frame = self._viewer.display() + + if record: + return frame + else: + return None def stop(self): self._viewer.close() diff --git a/mushroom_rl/environments/pybullet_envs/air_hockey/defend.py b/mushroom_rl/environments/pybullet_envs/air_hockey/defend.py index 5dd0c5cb9..c11c9dfda 100644 --- a/mushroom_rl/environments/pybullet_envs/air_hockey/defend.py +++ b/mushroom_rl/environments/pybullet_envs/air_hockey/defend.py @@ -159,31 +159,3 @@ def _create_observation(self, state): obs = super(AirHockeyDefendBullet, self)._create_observation(state) return np.append(obs, [self.has_hit, self.has_bounce]) - -if __name__ == '__main__': - import time - - 
env = AirHockeyDefendBullet(debug_gui=True, obs_noise=False, obs_delay=False, n_intermediate_steps=4, random_init=True) - - R = 0. - J = 0. - gamma = 1. - steps = 0 - env.reset() - while True: - action = np.zeros(3) - observation, reward, done, info = env.step(action) - gamma *= env.info.gamma - J += gamma * reward - R += reward - steps += 1 - - - if done or steps > env.info.horizon: - print("J: ", J, " R: ", R) - R = 0. - J = 0. - gamma = 1. - steps = 0 - env.reset() - time.sleep(1 / 60.) diff --git a/mushroom_rl/environments/pybullet_envs/air_hockey/hit.py b/mushroom_rl/environments/pybullet_envs/air_hockey/hit.py index ec91fbda8..ecee285c3 100644 --- a/mushroom_rl/environments/pybullet_envs/air_hockey/hit.py +++ b/mushroom_rl/environments/pybullet_envs/air_hockey/hit.py @@ -161,32 +161,3 @@ def _create_observation(self, state): obs = super(AirHockeyHitBullet, self)._create_observation(state) return np.append(obs, [self.has_hit]) - -if __name__ == '__main__': - import time - - env = AirHockeyHitBullet(debug_gui=True, env_noise=False, obs_noise=False, obs_delay=False, n_intermediate_steps=4, - table_boundary_terminate=True, random_init=True, init_robot_state="right") - - env.reset() - R = 0. - J = 0. - gamma = 1. - steps = 0 - - while True: - # action = np.random.randn(3) * 5 - action = np.array([0] * 3) - observation, reward, done, info = env.step(action) - gamma *= env.info.gamma - J += gamma * reward - R += reward - steps += 1 - if done or steps > env.info.horizon: - print("J: ", J, " R: ", R) - R = 0. - J = 0. - gamma = 1. - steps = 0 - env.reset() - time.sleep(1 / 60.) diff --git a/mushroom_rl/environments/pybullet_envs/air_hockey/prepare.py b/mushroom_rl/environments/pybullet_envs/air_hockey/prepare.py index 8b33a11e0..9c5c96ea4 100644 --- a/mushroom_rl/environments/pybullet_envs/air_hockey/prepare.py +++ b/mushroom_rl/environments/pybullet_envs/air_hockey/prepare.py @@ -184,33 +184,3 @@ def _simulation_post_step(self): def _create_observation(self, state): obs = super(AirHockeyPrepareBullet, self)._create_observation(state) return np.append(obs, [self.has_hit]) - - -if __name__ == '__main__': - import time - - env = AirHockeyPrepareBullet(debug_gui=True, obs_noise=False, obs_delay=False, n_intermediate_steps=4, random_init=True, - init_state="bottom") - - R = 0. - J = 0. - gamma = 1. - steps = 0 - env.reset() - while True: - - # action = np.random.randn(3) * 5 - action = np.array([0, 0, 0]) - observation, reward, done, info = env.step(action) - gamma *= env.info.gamma - J += gamma * reward - R += reward - steps += 4 - if done or steps > env.info.horizon * 2: - print("J: ", J, " R: ", R) - R = 0. - J = 0. - gamma = 1. - steps = 0 - env.reset() - time.sleep(1 / 60.) \ No newline at end of file diff --git a/mushroom_rl/environments/pybullet_envs/air_hockey/repel.py b/mushroom_rl/environments/pybullet_envs/air_hockey/repel.py index 87b40b441..95233ac9b 100644 --- a/mushroom_rl/environments/pybullet_envs/air_hockey/repel.py +++ b/mushroom_rl/environments/pybullet_envs/air_hockey/repel.py @@ -152,31 +152,3 @@ def _simulation_post_step(self): def _create_observation(self, state): obs = super(AirHockeyRepelBullet, self)._create_observation(state) return np.append(obs, [self.has_hit]) - - -if __name__ == "__main__": - import time - - env = AirHockeyRepelBullet(debug_gui=True, env_noise=False, obs_noise=False, obs_delay=False, n_intermediate_steps=4, - random_init=True) - - R = 0. - J = 0. - gamma = 1. 
- steps = 0 - env.reset() - while True: - action = np.zeros(3) - observation, reward, done, info = env.step(action) - gamma *= env.info.gamma - J += gamma * reward - R += reward - steps += 1 - if done or steps > env.info.horizon: - print("J: ", J, " R: ", R) - R = 0. - J = 0. - gamma = 1. - steps = 0 - env.reset() - time.sleep(1 / 60.) diff --git a/mushroom_rl/environments/segway.py b/mushroom_rl/environments/segway.py index 24bb64116..f63958b4b 100644 --- a/mushroom_rl/environments/segway.py +++ b/mushroom_rl/environments/segway.py @@ -31,7 +31,6 @@ def __init__(self, random_start=False): self._Ir = 4.54e-4 * 2 self._l = 13.8e-2 self._r = 5.5e-2 - self._dt = 1e-2 self._g = 9.81 self._max_u = 5 @@ -40,11 +39,12 @@ def __init__(self, random_start=False): high = np.array([-np.pi / 2, 15, 75]) # MDP properties + dt = 1e-2 observation_space = spaces.Box(low=-high, high=high) action_space = spaces.Box(low=np.array([-self._max_u]), high=np.array([self._max_u])) horizon = 300 - mdp_info = MDPInfo(observation_space, action_space, gamma, horizon) + mdp_info = MDPInfo(observation_space, action_space, gamma, horizon, dt) # Visualization self._viewer = Viewer(5 * self._l, 5 * self._l) @@ -70,8 +70,7 @@ def reset(self, state=None): def step(self, action): u = self._bound(action[0], -self._max_u, self._max_u) - new_state = odeint(self._dynamics, self._state, [0, self._dt], - (u,)) + new_state = odeint(self._dynamics, self._state, [0, self.info.dt], (u,)) self._state = np.array(new_state[-1]) self._state[0] = normalize_angle(self._state[0]) @@ -116,11 +115,11 @@ def _dynamics(self, state, t, u): return dx - def render(self, mode='human'): + def render(self, record=False): start = 2.5 * self._l * np.ones(2) end = 2.5 * self._l * np.ones(2) - dx = -self._state[2] * self._r * self._dt + dx = -self._state[2] * self._r * self.info.dt self._last_x += dx @@ -132,15 +131,21 @@ def render(self, mode='human'): end[0] += -2 * self._l * np.sin(self._state[0]) + self._last_x end[1] += 2 * self._l * np.cos(self._state[0]) - if (start[0] > 5 * self._l and end[0] > 5 * self._l) \ - or (start[0] < 0 and end[0] < 0): + if (start[0] > 5 * self._l and end[0] > 5 * self._l) or (start[0] < 0 and end[0] < 0): start[0] = start[0] % 5 * self._l end[0] = end[0] % 5 * self._l self._viewer.line(start, end) self._viewer.circle(start, self._r) - self._viewer.display(self._dt) + frame = self._viewer.get_frame() if record else None + + self._viewer.display(self.info.dt) + + return frame + + def stop(self): + self._viewer.close() diff --git a/mushroom_rl/environments/ship_steering.py b/mushroom_rl/environments/ship_steering.py index ecc4a77a9..e2cc7c073 100644 --- a/mushroom_rl/environments/ship_steering.py +++ b/mushroom_rl/environments/ship_steering.py @@ -9,8 +9,7 @@ class ShipSteering(Environment): """ The Ship Steering environment as presented in: - "Hierarchical Policy Gradient Algorithms". Ghavamzadeh M. and Mahadevan S.. - 2013. + "Hierarchical Policy Gradient Algorithms". Ghavamzadeh M. and Mahadevan S.. 2013. """ def __init__(self, small=True, n_steps_action=3): @@ -30,7 +29,6 @@ def __init__(self, small=True, n_steps_action=3): self.omega_max = np.array([np.pi / 12.]) self._v = 3. self._T = 5. 
- self._dt = .2 self._gate_s = np.empty(2) self._gate_e = np.empty(2) self._gate_s[0] = 100 if small else 350 @@ -44,11 +42,12 @@ def __init__(self, small=True, n_steps_action=3): self.n_steps_action = n_steps_action # MDP properties + dt = .2 observation_space = spaces.Box(low=low, high=high) action_space = spaces.Box(low=-self.omega_max, high=self.omega_max) horizon = 5000 gamma = .99 - mdp_info = MDPInfo(observation_space, action_space, gamma, horizon) + mdp_info = MDPInfo(observation_space, action_space, gamma, horizon, dt) # Visualization self._viewer = Viewer(self.field_size, self.field_size, @@ -79,10 +78,10 @@ def step(self, action): for _ in range(self.n_steps_action): state = new_state new_state = np.empty(4) - new_state[0] = state[0] + self._v * np.cos(state[2]) * self._dt - new_state[1] = state[1] + self._v * np.sin(state[2]) * self._dt - new_state[2] = normalize_angle(state[2] + state[3] * self._dt) - new_state[3] = state[3] + (r - state[3]) * self._dt / self._T + new_state[0] = state[0] + self._v * np.cos(state[2]) * self.info.dt + new_state[1] = state[1] + self._v * np.sin(state[2]) * self.info.dt + new_state[2] = normalize_angle(state[2] + state[3] * self.info.dt) + new_state[3] = state[3] + (r - state[3]) * self.info.dt / self._T if new_state[0] > self.field_size \ or new_state[1] > self.field_size \ @@ -107,7 +106,7 @@ def step(self, action): return self._state, reward, absorbing, {} - def render(self, mode='human'): + def render(self, record=False): self._viewer.line(self._gate_s, self._gate_e, width=3) @@ -121,7 +120,11 @@ def render(self, mode='human'): self._viewer.polygon(self._state[:2], self._state[2], boat, color=(32, 193, 54)) - self._viewer.display(self._dt) + frame = self._viewer.get_frame() if record else None + + self._viewer.display(self.info.dt) + + return frame def stop(self): self._viewer.close() diff --git a/mushroom_rl/policy/torch_policy.py b/mushroom_rl/policy/torch_policy.py index e1710b777..3a1a6f32b 100644 --- a/mushroom_rl/policy/torch_policy.py +++ b/mushroom_rl/policy/torch_policy.py @@ -6,7 +6,7 @@ from mushroom_rl.policy import Policy from mushroom_rl.approximators import Regressor from mushroom_rl.approximators.parametric import TorchApproximator -from mushroom_rl.utils.torch import to_float_tensor +from mushroom_rl.utils.torch import to_float_tensor, CategoricalWrapper from mushroom_rl.utils.parameters import to_parameter from itertools import chain @@ -230,11 +230,12 @@ def entropy_t(self, state=None): return self._action_dim / 2 * np.log(2 * np.pi * np.e) + torch.sum(self._log_sigma) def distribution_t(self, state): - mu, sigma = self.get_mean_and_covariance(state) - return torch.distributions.MultivariateNormal(loc=mu, covariance_matrix=sigma) + mu, chol_sigma = self.get_mean_and_chol(state) + return torch.distributions.MultivariateNormal(loc=mu, scale_tril=chol_sigma, validate_args=False) - def get_mean_and_covariance(self, state): - return self._mu(state, **self._predict_params, output_tensor=True), torch.diag(torch.exp(2 * self._log_sigma)) + def get_mean_and_chol(self, state): + assert torch.all(torch.exp(self._log_sigma) > 0) + return self._mu(state, **self._predict_params, output_tensor=True), torch.diag(torch.exp(self._log_sigma)) def set_weights(self, weights): log_sigma_data = torch.from_numpy(weights[-self._action_dim:]) @@ -259,13 +260,6 @@ class BoltzmannTorchPolicy(TorchPolicy): Torch policy implementing a Boltzmann policy. 
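The change above from get_mean_and_covariance to get_mean_and_chol only swaps the parameterization handed to torch, not the resulting distribution: passing the Cholesky factor through scale_tril skips the factorization of the covariance on every call. A short sanity check with arbitrary values:

import torch

mu = torch.zeros(3)
log_sigma = torch.tensor([0.1, -0.2, 0.3])

dist_cov = torch.distributions.MultivariateNormal(
    loc=mu, covariance_matrix=torch.diag(torch.exp(2 * log_sigma)))
dist_chol = torch.distributions.MultivariateNormal(
    loc=mu, scale_tril=torch.diag(torch.exp(log_sigma)))

x = torch.randn(3)
print(dist_cov.log_prob(x), dist_chol.log_prob(x))  # equal up to numerical error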
""" - class CategoricalWrapper(torch.distributions.Categorical): - def __init__(self, logits): - super().__init__(logits=logits) - - def log_prob(self, value): - return super().log_prob(value.squeeze()) - def __init__(self, network, input_shape, output_shape, beta, use_cuda=False, **params): """ Constructor. @@ -314,7 +308,7 @@ def entropy_t(self, state): def distribution_t(self, state): logits = self._logits(state, **self._predict_params, output_tensor=True) * self._beta(state.numpy()) - return BoltzmannTorchPolicy.CategoricalWrapper(logits) + return CategoricalWrapper(logits) def set_weights(self, weights): self._logits.set_weights(weights) diff --git a/mushroom_rl/utils/angles.py b/mushroom_rl/utils/angles.py index 974a25ccc..16fa606c8 100644 --- a/mushroom_rl/utils/angles.py +++ b/mushroom_rl/utils/angles.py @@ -118,3 +118,31 @@ def euler_to_quat(euler): return R.from_euler('xyz', euler).as_quat()[[3, 0, 1, 2]] else: return R.from_euler('xyz', euler.T).as_quat()[:, [3, 0, 1, 2]].T + + +def mat_to_euler(mat): + """ + Convert a rotation matrix to euler angles. + + Args: + mat (np.ndarray): a 3d rotation matrix. + + Returns: + The euler angles [x, y, z] representation of the quaternion + + """ + return R.from_matrix(mat).as_euler('xyz') + + +def euler_to_mat(euler): + """ + Convert euler angles into a a rotation matrix. + + Args: + euler (np.ndarray): euler angles [x, y, z] to be converted. + + Returns: + The rotation matrix representation of the euler angles + + """ + return R.from_euler('xyz', euler).as_matrix() diff --git a/mushroom_rl/utils/mujoco/__init__.py b/mushroom_rl/utils/mujoco/__init__.py index 4fca999d5..0b5917b5b 100644 --- a/mushroom_rl/utils/mujoco/__init__.py +++ b/mushroom_rl/utils/mujoco/__init__.py @@ -1,3 +1,3 @@ -from .viewer import MujocoGlfwViewer +from .viewer import MujocoViewer from .observation_helper import ObservationHelper, ObservationType from .kinematics import forward_kinematics diff --git a/mushroom_rl/utils/mujoco/viewer.py b/mushroom_rl/utils/mujoco/viewer.py index ecef6f8b8..fcdda9de2 100644 --- a/mushroom_rl/utils/mujoco/viewer.py +++ b/mushroom_rl/utils/mujoco/viewer.py @@ -1,19 +1,76 @@ +import os import glfw import mujoco import time - +import collections +from itertools import cycle import numpy as np -class MujocoGlfwViewer: +def _import_egl(width, height): + from mujoco.egl import GLContext + + return GLContext(width, height) + + +def _import_glfw(width, height): + from mujoco.glfw import GLContext + + return GLContext(width, height) + + +def _import_osmesa(width, height): + from mujoco.osmesa import GLContext + + return GLContext(width, height) + + +_ALL_RENDERERS = collections.OrderedDict( + [ + ("glfw", _import_glfw), + ("egl", _import_egl), + ("osmesa", _import_osmesa), + ] +) + + +class MujocoViewer: """ - Class that creates a Glfw viewer for mujoco environments. - Controls: - Space: Pause / Unpause simulation - c: Turn contact force and constraint visualisation on / off - t: Make models transparent + Class that creates a viewer for mujoco environments. + """ - def __init__(self, model, dt, width=1920, height=1080, start_paused=False, custom_render_callback=None): + + def __init__(self, model, dt, width=1920, height=1080, start_paused=False, + custom_render_callback=None, record=False, camera_params=None, + default_camera_mode="static", hide_menu_on_startup=None, + geom_group_visualization_on_startup=None, headless=False): + """ + Constructor. + + Args: + model: Mujoco model. 
+ dt (float): Timestep of the environment, (not the simulation). + width (int): Width of the viewer window. + height (int): Height of the viewer window. + start_paused (bool): If True, the rendering is paused in the beginning of the simulation. + custom_render_callback (func): Custom render callback function, which is supposed to be called + during rendering. + record (bool): If true, frames are returned during rendering. + camera_params (dict): Dictionary of dictionaries including custom parameterization of the three cameras. + Checkout the function get_default_camera_params() to know what parameters are expected. Is some camera + type specification or parameter is missing, the default one is used. + hide_menu_on_startup (bool): If True, the menu is hidden on startup. + geom_group_visualization_on_startup (int/list): int or list defining which geom group_ids should be + visualized on startup. If None, all are visualized. + headless (bool): If True, render will be done in headless mode. + + """ + + if hide_menu_on_startup is None and headless: + hide_menu_on_startup = True + elif hide_menu_on_startup is None and not headless: + hide_menu_on_startup = False + self.button_left = False self.button_right = False self.button_middle = False @@ -23,40 +80,101 @@ def __init__(self, model, dt, width=1920, height=1080, start_paused=False, custo self.frames = 0 self.start_time = time.time() - glfw.init() - glfw.window_hint(glfw.COCOA_RETINA_FRAMEBUFFER, 0) + self._headless = headless + self._model = model + self._font_scale = 100 + + if headless: + # use the OpenGL render that is available on the machine + self._opengl_context = self.setup_opengl_backend_headless(width, height) + self._opengl_context.make_current() + self._width, self._height = self.update_headless_size(width, height) + else: + # use glfw + self._width, self._height = width, height + glfw.init() + glfw.window_hint(glfw.COCOA_RETINA_FRAMEBUFFER, 0) + self._window = glfw.create_window(width=self._width, height=self._height, + title="MuJoCo", monitor=None, share=None) + glfw.make_context_current(self._window) + glfw.set_mouse_button_callback(self._window, self.mouse_button) + glfw.set_cursor_pos_callback(self._window, self.mouse_move) + glfw.set_key_callback(self._window, self.keyboard) + glfw.set_scroll_callback(self._window, self.scroll) + + self._set_mujoco_buffers() + + if record and not headless: + # dont allow to change the window size to have equal frame size during recording + glfw.window_hint(glfw.RESIZABLE, False) + + self._viewport = mujoco.MjrRect(0, 0, self._width, self._height) self._loop_count = 0 self._time_per_render = 1 / 60. 
+ self._run_speed_factor = 1.0 self._paused = start_paused - self._window = glfw.create_window(width=width, height=height, title="MuJoCo", monitor=None, share=None) - glfw.make_context_current(self._window) - # Disable v_sync, so swap_buffers does not block # glfw.swap_interval(0) - glfw.set_mouse_button_callback(self._window, self.mouse_button) - glfw.set_cursor_pos_callback(self._window, self.mouse_move) - glfw.set_key_callback(self._window, self.keyboard) - glfw.set_scroll_callback(self._window, self.scroll) - - self._model = model - - self._scene = mujoco.MjvScene(model, 1000) + self._scene = mujoco.MjvScene(self._model, 1000) self._scene_option = mujoco.MjvOption() - self._camera = mujoco.MjvCamera() mujoco.mjv_defaultFreeCamera(model, self._camera) + if camera_params is None: + self._camera_params = self.get_default_camera_params() + else: + self._camera_params = self._assert_camera_params(camera_params) + self._all_camera_modes = ("static", "follow", "top_static") + self._camera_mode_iter = cycle(self._all_camera_modes) + self._camera_mode = None + self._camera_mode_target = next(self._camera_mode_iter) + assert default_camera_mode in self._all_camera_modes + while self._camera_mode_target != default_camera_mode: + self._camera_mode_target = next(self._camera_mode_iter) + self._set_camera() + + self.custom_render_callback = custom_render_callback - self._viewport = mujoco.MjrRect(0, 0, width, height) - self._context = mujoco.MjrContext(model, mujoco.mjtFontScale(100)) + self._overlay = {} + self._hide_menu = hide_menu_on_startup - self.rgb_buffer = np.empty((width, height, 3), dtype=np.uint8) + if geom_group_visualization_on_startup is not None: + assert type(geom_group_visualization_on_startup) == list or type(geom_group_visualization_on_startup) == int + if type(geom_group_visualization_on_startup) is not list: + geom_group_visualization_on_startup = [geom_group_visualization_on_startup] + for group_id, _ in enumerate(self._scene_option.geomgroup): + if group_id not in geom_group_visualization_on_startup: + self._scene_option.geomgroup[group_id] = False + + def load_new_model(self, model): + """ + Loads a new model to the viewer, and resets the scene and context. + This is used in MultiMujoco environments. + + Args: + model: Mujoco model. + + """ + + self._model = model + self._scene = mujoco.MjvScene(model, 1000) + self._context = mujoco.MjrContext(model, mujoco.mjtFontScale(self._font_scale)) - self.custom_render_callback = custom_render_callback def mouse_button(self, window, button, act, mods): + """ + Mouse button callback for glfw. + + Args: + window: glfw window. + button: glfw button id. + act: glfw action. + mods: glfw mods. + + """ + self.button_left = glfw.get_mouse_button(self._window, glfw.MOUSE_BUTTON_LEFT) == glfw.PRESS self.button_right = glfw.get_mouse_button(self._window, glfw.MOUSE_BUTTON_RIGHT) == glfw.PRESS self.button_middle = glfw.get_mouse_button(self._window, glfw.MOUSE_BUTTON_MIDDLE) == glfw.PRESS @@ -64,6 +182,16 @@ def mouse_button(self, window, button, act, mods): self.last_x, self.last_y = glfw.get_cursor_pos(self._window) def mouse_move(self, window, x_pos, y_pos): + """ + Mouse mode callback for glfw. + + Args: + window: glfw window. + x_pos: Current mouse x position. + y_pos: Current mouse y position. 
+ + """ + if not self.button_left and not self.button_right and not self.button_middle: return @@ -87,6 +215,18 @@ def mouse_move(self, window, x_pos, y_pos): mujoco.mjv_moveCamera(self._model, action, dx / width, dy / height, self._scene, self._camera) def keyboard(self, window, key, scancode, act, mods): + """ + Keyboard callback for glfw. + + Args: + window: glfw window. + key: glfw key event. + scancode: glfw scancode. + act: glfw action. + mods: glfw mods. + + """ + if act != glfw.RELEASE: return @@ -103,50 +243,388 @@ def keyboard(self, window, key, scancode, act, mods): self._scene_option.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = not self._scene_option.flags[ mujoco.mjtVisFlag.mjVIS_TRANSPARENT] + if key == glfw.KEY_0: + self._scene_option.geomgroup[0] = not self._scene_option.geomgroup[0] + + if key == glfw.KEY_1: + self._scene_option.geomgroup[1] = not self._scene_option.geomgroup[1] + + if key == glfw.KEY_2: + self._scene_option.geomgroup[2] = not self._scene_option.geomgroup[2] + + if key == glfw.KEY_3: + self._scene_option.geomgroup[3] = not self._scene_option.geomgroup[3] + + if key == glfw.KEY_4: + self._scene_option.geomgroup[4] = not self._scene_option.geomgroup[4] + + if key == glfw.KEY_5: + self._scene_option.geomgroup[5] = not self._scene_option.geomgroup[5] + + if key == glfw.KEY_6: + self._scene_option.geomgroup[6] = not self._scene_option.geomgroup[6] + + if key == glfw.KEY_7: + self._scene_option.geomgroup[7] = not self._scene_option.geomgroup[7] + + if key == glfw.KEY_8: + self._scene_option.geomgroup[8] = not self._scene_option.geomgroup[8] + + if key == glfw.KEY_9: + self._scene_option.geomgroup[9] = not self._scene_option.geomgroup[9] + + if key == glfw.KEY_TAB: + self._camera_mode_target = next(self._camera_mode_iter) + + if key == glfw.KEY_S: + self._run_speed_factor /= 2.0 + + if key == glfw.KEY_F: + self._run_speed_factor *= 2.0 + + if key == glfw.KEY_E: + self._scene_option.frame = not self._scene_option.frame + + if key == glfw.KEY_H: + if self._hide_menu: + self._hide_menu = False + else: + self._hide_menu = True + def scroll(self, window, x_offset, y_offset): + """ + Scrolling callback for glfw. + + Args: + window: glfw window. + x_offset: x scrolling offset. + y_offset: y scrolling offset. + + """ + mujoco.mjv_moveCamera(self._model, mujoco.mjtMouse.mjMOUSE_ZOOM, 0, 0.05 * y_offset, self._scene, self._camera) - def render(self, data): + def _set_mujoco_buffers(self): + self._context = mujoco.MjrContext(self._model, mujoco.mjtFontScale(self._font_scale)) + if self._headless: + mujoco.mjr_setBuffer(mujoco.mjtFramebuffer.mjFB_OFFSCREEN, self._context) + if self._context.currentBuffer != mujoco.mjtFramebuffer.mjFB_OFFSCREEN: + raise RuntimeError("Offscreen rendering not supported") + else: + mujoco.mjr_setBuffer(mujoco.mjtFramebuffer.mjFB_WINDOW, self._context) + if self._context.currentBuffer != mujoco.mjtFramebuffer.mjFB_WINDOW: + raise RuntimeError("Window rendering not supported") + + def update_headless_size(self, width, height): + _context = mujoco.MjrContext(self._model, mujoco.mjtFontScale(self._font_scale)) + if width > _context.offWidth or height > _context.offHeight: + width = max(width, self._model.vis.global_.offwidth) + height = max(height, self._model.vis.global_.offheight) + + if width != _context.offWidth or height != _context.offHeight: + self._model.vis.global_.offwidth = width + self._model.vis.global_.offheight = height + + return width, height + + def render(self, data, record): + """ + Main rendering function. 
+ + Args: + data: Mujoco data structure. + record (bool): If true, frames are returned during rendering. + + Returns: + If record is True, frames are returned during rendering, else None. + + """ def render_inner_loop(self): + + if not self._headless: + self._create_overlay() + render_start = time.time() mujoco.mjv_updateScene(self._model, data, self._scene_option, None, self._camera, mujoco.mjtCatBit.mjCAT_ALL, self._scene) - self._viewport.width, self._viewport.height = glfw.get_window_size(self._window) + if not self._headless: + self._viewport.width, self._viewport.height = glfw.get_window_size(self._window) mujoco.mjr_render(self._viewport, self._scene, self._context) + for gridpos, [t1, t2] in self._overlay.items(): + + if self._hide_menu: + continue + + mujoco.mjr_overlay( + mujoco.mjtFont.mjFONT_SHADOW, + gridpos, + self._viewport, + t1, + t2, + self._context) + if self.custom_render_callback is not None: self.custom_render_callback(self._viewport, self._context) - glfw.swap_buffers(self._window) - glfw.poll_events() + if not self._headless: + glfw.swap_buffers(self._window) + glfw.poll_events() + if glfw.window_should_close(self._window): + self.stop() + exit(0) self.frames += 1 - - if glfw.window_should_close(self._window): - self.stop() - exit(0) - + self._overlay.clear() self._time_per_render = 0.9 * self._time_per_render + 0.1 * (time.time() - render_start) - """ - if return_img: - mujoco.mjr_readPixels(self.rgb_buffer, None, self._viewport, self._context) - return self.rgb_buffer - """ if self._paused: while self._paused: render_inner_loop(self) - self._loop_count += self.dt / self._time_per_render + if record: + self._loop_count = 1 + else: + self._loop_count += self.dt / (self._time_per_render * self._run_speed_factor) while self._loop_count > 0: render_inner_loop(self) + self._set_camera() self._loop_count -= 1 + if record: + return self.read_pixels() + + def read_pixels(self, depth=False): + """ + Reads the pixels from the glfw viewer. + + Args: + depth (bool): If True, depth map is also returned. + + Returns: + If depth is True, tuple of np.arrays (rgb and depth), else just a single + np.array for the rgb image. + + """ + + if self._headless: + shape = (self._width, self._height) + else: + shape = glfw.get_framebuffer_size(self._window) + + if depth: + rgb_img = np.zeros((shape[1], shape[0], 3), dtype=np.uint8) + depth_img = np.zeros((shape[1], shape[0], 1), dtype=np.float32) + mujoco.mjr_readPixels(rgb_img, depth_img, self._viewport, self._context) + return (np.flipud(rgb_img), np.flipud(depth_img)) + else: + img = np.zeros((shape[1], shape[0], 3), dtype=np.uint8) + mujoco.mjr_readPixels(img, None, self._viewport, self._context) + return np.flipud(img) + def stop(self): - glfw.destroy_window(self._window) + """ + Destroys the glfw image. + + """ + if not self._headless: + glfw.destroy_window(self._window) + + def _create_overlay(self): + """ + This function creates and adds all overlays used in the viewer. 
+ + """ + + topleft = mujoco.mjtGridPos.mjGRID_TOPLEFT + topright = mujoco.mjtGridPos.mjGRID_TOPRIGHT + bottomleft = mujoco.mjtGridPos.mjGRID_BOTTOMLEFT + bottomright = mujoco.mjtGridPos.mjGRID_BOTTOMRIGHT + + def add_overlay(gridpos, text1, text2="", make_new_line=True): + if gridpos not in self._overlay: + self._overlay[gridpos] = ["", ""] + if make_new_line: + self._overlay[gridpos][0] += text1 + "\n" + self._overlay[gridpos][1] += text2 + "\n" + else: + self._overlay[gridpos][0] += text1 + self._overlay[gridpos][1] += text2 + + add_overlay( + bottomright, + "Framerate:", + str(int(1/self._time_per_render * self._run_speed_factor)), make_new_line=False) + + add_overlay( + topleft, + "Press SPACE to pause.") + + add_overlay( + topleft, + "Press H to hide the menu.") + + add_overlay( + topleft, + "Press TAB to switch cameras.") + + add_overlay( + topleft, + "Press T to make the model transparent.") + + add_overlay( + topleft, + "Press E to toggle reference frames.") + + add_overlay( + topleft, + "Press 0-9 to disable/enable geom group visualization.") + + visualize_contact = "On" if self._scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] else "Off" + add_overlay( + topleft, + "Contact force visualization (Press C):", visualize_contact) + + add_overlay( + topleft, + "Camera mode:", + self._camera_mode) + + add_overlay( + topleft, + "Run speed = %.3f x real time" % + self._run_speed_factor, + "[S]lower, [F]aster", make_new_line=False) + + def _set_camera(self): + """ + Sets the camera mode to the current camera mode target. Allowed camera + modes are "follow" in which the model is tracked, "static" that is a static + camera at the default camera positon, and "top_static" that is a static + camera on top of the model. + + """ + + if self._camera_mode_target == "follow": + if self._camera_mode != "follow": + self._camera.fixedcamid = -1 + self._camera.type = mujoco.mjtCamera.mjCAMERA_TRACKING + self._camera.trackbodyid = 0 + self._set_camera_properties(self._camera_mode_target) + elif self._camera_mode_target == "static": + if self._camera_mode != "static": + self._camera.fixedcamid = 0 + self._camera.type = mujoco.mjtCamera.mjCAMERA_FREE + self._camera.trackbodyid = -1 + self._set_camera_properties(self._camera_mode_target) + elif self._camera_mode_target == "top_static": + if self._camera_mode != "top_static": + self._camera.fixedcamid = 0 + self._camera.type = mujoco.mjtCamera.mjCAMERA_FREE + self._camera.trackbodyid = -1 + self._set_camera_properties(self._camera_mode_target) + + def _set_camera_properties(self, mode): + """ + Sets the camera properties "distance", "elevation", and "azimuth" + as well as the camera mode based on the provided mode. + + Args: + mode (str): Camera mode. (either "follow", "static", or "top_static") + + """ + + cam_params = self._camera_params[mode] + self._camera.distance = cam_params["distance"] + self._camera.elevation = cam_params["elevation"] + self._camera.azimuth = cam_params["azimuth"] + if "lookat" in cam_params: + self._camera.lookat = np.array(cam_params["lookat"]) + self._camera_mode = mode + + def _assert_camera_params(self, camera_params): + """ + Asserts if the provided camera parameters are valid or not. Also, if + properties of some camera types are not specified, the default parameters + are used. + + Args: + camera_params (dict): Dictionary of dictionaries containig parameters for each camera type. + + Returns: + Dictionary of dictionaries with parameters for each camera type. 
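To illustrate the camera handling above: each of the three camera modes can be re-parameterized partially, and whatever is left unspecified falls back to the defaults. A minimal sketch with a made-up inline model, assuming a machine with a display; the same camera_params and default_camera_mode keyword arguments can be forwarded to any MuJoCo-based environment through its **viewer_params:

import mujoco
from mushroom_rl.utils.mujoco import MujocoViewer

model = mujoco.MjModel.from_xml_string(
    "<mujoco><worldbody><geom type='sphere' size='0.1'/></worldbody></mujoco>")
data = mujoco.MjData(model)

# Override only the "follow" camera; "static" and "top_static" keep their defaults
camera_params = dict(follow=dict(distance=2.0, elevation=-15.0, azimuth=90.0))
viewer = MujocoViewer(model, dt=0.01, width=640, height=480,
                      camera_params=camera_params, default_camera_mode="follow")
viewer.render(data, record=False)
viewer.stop()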
+ + """ + + default_camera_params = self.get_default_camera_params() + + # check if the provided camera types and parameters are valid + for cam_type in camera_params.keys(): + assert cam_type in default_camera_params.keys(), f"Camera type \"{cam_type}\" is unknown. Allowed " \ + f"camera types are {list(default_camera_params.keys())}." + for param in camera_params[cam_type].keys(): + assert param in default_camera_params[cam_type].keys(), f"Parameter \"{param}\" of camera type " \ + f"\"{cam_type}\" is unknown. Allowed " \ + f"parameters are" \ + f" {list(default_camera_params[cam_type].keys())}" + + # add default parameters if not specified + for cam_type in default_camera_params.keys(): + if cam_type not in camera_params.keys(): + camera_params[cam_type] = default_camera_params[cam_type] + else: + for param in default_camera_params[cam_type].keys(): + if param not in camera_params[cam_type].keys(): + camera_params[cam_type][param] = default_camera_params[cam_type][param] + + return camera_params + + @staticmethod + def get_default_camera_params(): + """ + Getter for default camera paramterization. + + Returns: + Dictionary of dictionaries with default parameters for each camera type. + + """ + + return dict(static=dict(distance=15.0, elevation=-45.0, azimuth=90.0, lookat=np.array([0.0, 0.0, 0.0])), + follow=dict(distance=3.5, elevation=0.0, azimuth=90.0), + top_static=dict(distance=5.0, elevation=-90.0, azimuth=90.0, lookat=np.array([0.0, 0.0, 0.0]))) + + + def setup_opengl_backend_headless(self, width, height): + + backend = os.environ.get("MUJOCO_GL") + if backend is not None: + try: + opengl_context = _ALL_RENDERERS[backend](width, height) + except KeyError: + raise RuntimeError( + "Environment variable {} must be one of {!r}: got {!r}.".format( + "MUJOCO_GL", _ALL_RENDERERS.keys(), backend + ) + ) + else: + # iterate through all OpenGL backends to see which one is available + for name, _ in _ALL_RENDERERS.items(): + try: + opengl_context = _ALL_RENDERERS[name](width, height) + backend = name + break + except: # noqa:E722 + pass + if backend is None: + raise RuntimeError( + "No OpenGL backend could be imported. Attempting to create a " + "rendering context will result in a RuntimeError." + ) + + return opengl_context diff --git a/mushroom_rl/utils/optimizers.py b/mushroom_rl/utils/optimizers.py index 657f09fc4..596acfd56 100644 --- a/mushroom_rl/utils/optimizers.py +++ b/mushroom_rl/utils/optimizers.py @@ -1,5 +1,4 @@ import numpy as np -import numpy_ml as npml from mushroom_rl.core import Serializable from mushroom_rl.utils.parameters import Parameter @@ -124,98 +123,49 @@ class AdamOptimizer(Optimizer): This class implements the Adam optimizer. """ - def __init__(self, lr=0.001, decay1=0.9, decay2=0.999, maximize=True): + def __init__(self, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-7, maximize=True): """ Constructor. Args: lr ([float, Parameter], 0.001): the learning rate; - decay1 (float, 0.9): Adam beta1 parameter; - decay2 (float, 0.999): Adam beta2 parameter; + beta1 (float, 0.9): Adam beta1 parameter; + beta2 (float, 0.999): Adam beta2 parameter; maximize (bool, True): by default Optimizers do a gradient ascent step. Set to False for gradient descent. 
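The headless code path above can be exercised without opening a window; the OpenGL backend is taken from the MUJOCO_GL environment variable when it is set, otherwise the first importable backend is used. A rough sketch with a made-up inline model:

import mujoco
from mushroom_rl.utils.mujoco import MujocoViewer

model = mujoco.MjModel.from_xml_string(
    "<mujoco><worldbody><geom type='plane' size='1 1 0.1'/></worldbody></mujoco>")
data = mujoco.MjData(model)

# e.g. export MUJOCO_GL=egl to force a specific backend
viewer = MujocoViewer(model, dt=0.01, width=320, height=240, headless=True)
frame = viewer.render(data, record=True)  # uint8 RGB array of shape (240, 320, 3)
viewer.stop()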
""" super().__init__(lr, maximize) # lr_scheduler must be set to None, as we have our own scheduler - self._optimizer = npml.neural_nets.optimizers.Adam( - lr=self._lr.initial_value, - decay1=decay1, - decay2=decay2, - lr_scheduler=None - ) - - self._add_save_attr(_optimizer='pickle') - - def __call__(self, params, grads): - if self._maximize: - # -1*grads because numpy_ml does gradient descent by default, not ascent - grads *= -1 - # Fix the numpy_ml optimizer lr to the one we computed - self._optimizer.lr_scheduler.lr = self._lr() - return self._optimizer.update(params, grads, 'theta') - - -class AdaGradOptimizer(Optimizer): - """ - This class implements the AdaGrad optimizer. - - """ - def __init__(self, lr=0.001, maximize=True): - """ - Constructor. - - Args: - lr ([float, Parameter], 0.001): the learning rate; - maximize (bool, True): by default Optimizers do a gradient ascent step. Set to False for gradient descent. - - """ - super().__init__(lr, maximize) - # lr_scheduler must be set to None, as we have our own scheduler - self._optimizer = npml.neural_nets.optimizers.AdaGrad( - lr=self._lr.initial_value, - lr_scheduler=None - ) + self._m = None + self._v = None + self._beta1 = beta1 + self._beta2 = beta2 + self._eps = eps + self._t = 0 - self._add_save_attr(_optimizer='pickle') + self._add_save_attr(_m='numpy', + _v='numpy', + _beta1='primitive', + _beta2='primitive', + _t='primitive' + ) def __call__(self, params, grads): - if self._maximize: - # -1*grads because numpy_ml does gradient descent by default, not ascent + if not self._maximize: grads *= -1 - # Fix the numpy_ml optimizer lr to the one we computed - self._optimizer.lr_scheduler.lr = self._lr() - return self._optimizer.update(params, grads, 'theta') + if self._m is None: + self._t = 0 + self._m = np.zeros_like(params) + self._v = np.zeros_like(params) -class RMSPropOptimizer(Optimizer): - """ - This class implements the RMSProp optimizer. + self._t += 1 + self._m = self._beta1 * self._m + (1 - self._beta1) * grads + self._v = self._beta2 * self._v + (1 - self._beta2) * grads ** 2 - """ - def __init__(self, lr=0.001, decay=0.9, maximize=True): - """ - Constructor. + m_hat = self._m / (1 - self._beta1 ** self._t) + v_hat = self._v / (1 - self._beta2 ** self._t) - Args: - lr ([float, Parameter], 0.001): the learning rate; - decay (float, 0.9): rate of decay for the moving average; - maximize (bool, True): by default Optimizers do a gradient ascent step. Set to False for gradient descent. 
+ update = self._lr() * m_hat / (np.sqrt(v_hat) + self._eps) - """ - super().__init__(lr, maximize) - # lr_scheduler must be set to None, as we have our own scheduler - self._optimizer = npml.neural_nets.optimizers.RMSProp( - lr=self._lr.initial_value, - decay=decay, - lr_scheduler=None - ) - - self._add_save_attr(_optimizer='pickle') - - def __call__(self, params, grads): - if self._maximize: - # -1*grads because numpy_ml does gradient descent by default, not ascent - grads *= -1 - # Fix the numpy_ml optimizer lr to the one we computed - self._optimizer.lr_scheduler.lr = self._lr() - return self._optimizer.update(params, grads, 'theta') + return params + update diff --git a/mushroom_rl/utils/plots/__init__.py b/mushroom_rl/utils/plots/__init__.py index 7a684cd12..512b99973 100644 --- a/mushroom_rl/utils/plots/__init__.py +++ b/mushroom_rl/utils/plots/__init__.py @@ -16,8 +16,5 @@ __all__ += ['Actions', 'LenOfEpisodeTraining', 'Observations', 'RewardPerEpisode', 'RewardPerStep'] - from ._implementations import common_buffers - __all__.append('common_buffers') - except ImportError: pass diff --git a/mushroom_rl/utils/plots/window.py b/mushroom_rl/utils/plots/window.py index 8669982ff..b05375789 100644 --- a/mushroom_rl/utils/plots/window.py +++ b/mushroom_rl/utils/plots/window.py @@ -1,14 +1,13 @@ import time -from PyQt5.QtGui import QBrush, QColor -from PyQt5.QtWidgets import QTreeWidgetItem +from PyQt5.QtGui import QGuiApplication, QBrush, QColor +from PyQt5.QtWidgets import QTreeWidget, QTreeWidgetItem, QSplitter from PyQt5 import QtCore import pyqtgraph as pg -from pyqtgraph.Qt import QtGui -class Window(QtGui.QSplitter): +class Window(QSplitter): """ This class is used creating windows for plotting. @@ -48,7 +47,7 @@ def __init__(self, plot_list, track_if_deactivated=False, super().__init__(QtCore.Qt.Horizontal) - self._activation_widget = QtGui.QTreeWidget() + self._activation_widget = QTreeWidget() self._activation_widget.setHeaderLabels(["Plots"]) self._activation_widget.itemClicked.connect(self.clicked) @@ -75,7 +74,7 @@ def __init__(self, plot_list, track_if_deactivated=False, plot_instance, plot_instance.plot_data_items_list[i] ] - self._GraphicsWindow = pg.GraphicsWindow(title=title) + self._GraphicsWindow = pg.GraphicsLayoutWidget(title=title) self.addWidget(self._activation_widget) self.addWidget(self._GraphicsWindow) @@ -120,7 +119,7 @@ def refresh(self): for plot_instance in self.plot_list: plot_instance.refresh() - QtGui.QGuiApplication.processEvents() + QGuiApplication.processEvents() def activate(self, item): """ diff --git a/mushroom_rl/utils/pybullet/viewer.py b/mushroom_rl/utils/pybullet/viewer.py index d66ac72f7..ffaa9755c 100644 --- a/mushroom_rl/utils/pybullet/viewer.py +++ b/mushroom_rl/utils/pybullet/viewer.py @@ -21,6 +21,8 @@ def display(self): img = self._get_image() super().display(img) + return img + def _get_image(self): view_matrix = self._client.computeViewMatrixFromYawPitchRoll(cameraTargetPosition=self._origin, distance=self._distance, diff --git a/mushroom_rl/utils/record.py b/mushroom_rl/utils/record.py new file mode 100644 index 000000000..aa5761ba0 --- /dev/null +++ b/mushroom_rl/utils/record.py @@ -0,0 +1,69 @@ +import os +import cv2 +import datetime +from pathlib import Path + + +class VideoRecorder(object): + """ + Simple video record that creates a video from a stream of images. + """ + + def __init__(self, path="./mushroom_rl_recordings", tag=None, video_name=None, fps=60): + """ + Constructor. 
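With the numpy_ml dependency gone, the update rule above is easy to sanity-check in isolation. A small sketch on an arbitrary quadratic objective; note that the optimizer performs gradient ascent unless maximize=False is passed:

import numpy as np
from mushroom_rl.utils.optimizers import AdamOptimizer

optimizer = AdamOptimizer(lr=0.1)  # gradient ascent by default
theta = np.zeros(1)

for _ in range(500):
    grad = -2.0 * (theta - 3.0)    # gradient of the objective -(theta - 3)^2
    theta = optimizer(theta, grad)

print(theta)                       # approaches the maximizer [3.]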
+ + Args: + path: Path at which videos will be stored. + tag: Name of the directory at path in which the video will be stored. If None, a timestamp will be created. + video_name: Name of the video without extension. Default is "recording". + fps: Frame rate of the video. + """ + + if tag is None: + date_time = datetime.datetime.now() + tag = date_time.strftime("%d-%m-%Y_%H-%M-%S") + + self._path = Path(path) + self._path = self._path / tag + + self._video_name = video_name if video_name else "recording" + self._counter = 0 + + self._fps = fps + + self._video_writer = None + + def __call__(self, frame): + """ + Args: + frame (np.ndarray): Frame to be added to the video (H, W, RGB) + """ + assert frame is not None + + if self._video_writer is None: + height, width = frame.shape[:2] + self._create_video_writer(height, width) + + self._video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) + + def _create_video_writer(self, height, width): + + name = self._video_name + if self._counter > 0: + name += f"-{self._counter}.mp4" + else: + name += ".mp4" + + self._path.mkdir(parents=True, exist_ok=True) + + path = self._path / name + + self._video_writer = cv2.VideoWriter(str(path), cv2.VideoWriter_fourcc('m', 'p', '4', 'v'), + self._fps, (width, height)) + + def stop(self): + cv2.destroyAllWindows() + self._video_writer.release() + self._video_writer = None + self._counter += 1 diff --git a/mushroom_rl/utils/torch.py b/mushroom_rl/utils/torch.py index d586c67dc..41b6caeea 100644 --- a/mushroom_rl/utils/torch.py +++ b/mushroom_rl/utils/torch.py @@ -126,9 +126,40 @@ def to_int_tensor(x, use_cuda=False): def update_optimizer_parameters(optimizer, new_parameters): - for p_old, p_new in zip(optimizer.param_groups[0]['params'], new_parameters): - data = optimizer.state[p_old] - del optimizer.state[p_old] - optimizer.state[p_new] = data + if len(optimizer.state) > 0: + for p_old, p_new in zip(optimizer.param_groups[0]['params'], new_parameters): + data = optimizer.state[p_old] + del optimizer.state[p_old] + optimizer.state[p_new] = data - optimizer.param_groups[0]['params'] = new_parameters \ No newline at end of file + optimizer.param_groups[0]['params'] = new_parameters + + +class CategoricalWrapper(torch.distributions.Categorical): + """ + Wrapper for the Torch Categorical distribution. + + Needed to convert a vector of mushroom discrete action in an input with the proper shape of the original + distribution implemented in torch + + """ + def __init__(self, logits): + super().__init__(logits=logits) + + def log_prob(self, value): + return super().log_prob(value.squeeze()) + + +class DiagonalMultivariateGaussian(torch.distributions.Normal): + """ + Wrapper for the Torch Normal distribution, implementing a diagonal distribution. + + It behaves as the MultivariateNormal distribution, but avoids the computation of the full covariance matrix, + optimizing the computation time, particulalrly when a high dimensional vector is sampled. + + """ + def __init__(self, loc, scale): + super().__init__(loc=loc, scale=scale) + + def log_prob(self, value): + return torch.sum(super().log_prob(value), -1) \ No newline at end of file diff --git a/mushroom_rl/utils/viewer.py b/mushroom_rl/utils/viewer.py index c5c231456..7eb7091c1 100644 --- a/mushroom_rl/utils/viewer.py +++ b/mushroom_rl/utils/viewer.py @@ -5,6 +5,7 @@ import pygame import time import numpy as np +import cv2 class ImageViewer: @@ -78,8 +79,8 @@ def __init__(self, env_width, env_height, width=500, height=500, Constructor. 
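Putting the new VideoRecorder together with the record flag of render, here is a sketch using the Segway environment touched by this patch; the frame rate is derived from the control timestep:

import numpy as np
from mushroom_rl.environments.segway import Segway
from mushroom_rl.utils.record import VideoRecorder

env = Segway()
recorder = VideoRecorder(tag="segway_demo", fps=int(1 / env.info.dt))

env.reset()
for _ in range(200):
    _, _, absorbing, _ = env.step(np.zeros(1))
    recorder(env.render(record=True))  # render returns an RGB frame when record=True
    if absorbing:
        env.reset()

recorder.stop()
env.stop()

With the defaults above, the video is written to ./mushroom_rl_recordings/segway_demo/recording.mp4.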
Args: - env_width (int): The x dimension limit of the desired environment; - env_height (int): The y dimension limit of the desired environment; + env_width (float): The x dimension limit of the desired environment; + env_height (float): The y dimension limit of the desired environment; width (int, 500): width of the environment window; height (int, 500): height of the environment window; background (tuple, (0, 0, 0)): background color of the screen. @@ -314,6 +315,21 @@ def function(self, x_s, x_e, f, n_points=100, width=1, color=(255, 255, 255)): points = [self._transform([a, b]) for a, b in zip(x,y)] pygame.draw.lines(self.screen, color, False, points, width) + @staticmethod + def get_frame(): + """ + Getter. + + Returns: + The current Pygame surface as an RGB array. + + """ + surf = pygame.display.get_surface() + pygame_frame = pygame.surfarray.array3d(surf) + frame = pygame_frame.swapaxes(0, 1) + + return frame + def display(self, s): """ Display current frame and initialize the next frame to the background @@ -344,3 +360,66 @@ def _transform(self, p): def _rotate(p, theta): return np.array([np.cos(theta) * p[0] - np.sin(theta) * p[1], np.sin(theta) * p[0] + np.cos(theta) * p[1]]) + + +class CV2Viewer: + + """ + Simple viewer to display rendered images using cv2. + + """ + + def __init__(self, window_name, dt, width, height): + self._window_name = window_name + self._dt = dt + self._created_viewer = False + self._width = width + self._height = height + + def display(self, img): + """ + Displays an image. + + Args: + img (np.array): Image to display + + """ + + # display image the first time + if not self._created_viewer: + # Removes toolbar and status bar + cv2.namedWindow(self._window_name, flags=cv2.WINDOW_GUI_NORMAL) + cv2.resizeWindow(self._window_name, self._width, self._height) + cv2.imshow(self._window_name, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) + self._wait() + self._created_viewer = True + + # if the window is not closed yet, display another image + elif not self._window_was_closed(): + cv2.imshow(self._window_name, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) + self._wait() + + # window was closed, interrupt simulation + else: + exit() + + def _wait(self): + """ + Wait for the specified amount of time. Time is supposed to be in milliseconds. + + """ + wait_time = int(self._dt * 1000) + cv2.waitKey(wait_time) + + def _window_was_closed(self): + """ + Check if a window was closed. + + Returns: + True if the window was closed. 
+ + """ + return cv2.getWindowProperty(self._window_name, cv2.WND_PROP_VISIBLE) == 0 + + def close(self): + cv2.destroyWindow(self._window_name) diff --git a/requirements.txt b/requirements.txt index 2b52c94c7..519f4f848 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,10 @@ numpy -numpy_ml scipy scikit-learn matplotlib joblib tqdm pygame -opencv-python +opencv-python>=4.7 torch pytest diff --git a/setup.py b/setup.py index cc49d99b8..10fd52869 100644 --- a/setup.py +++ b/setup.py @@ -40,7 +40,7 @@ def glob_data_files(data_package, data_type=None): all_deps = [] for group_name in extras: - if group_name not in ['plots']: + if group_name not in ['plots','box2d']: all_deps += extras[group_name] extras['all'] = all_deps diff --git a/tests/algorithms/helper/utils.py b/tests/algorithms/helper/utils.py index 30b490d43..2fed7444a 100644 --- a/tests/algorithms/helper/utils.py +++ b/tests/algorithms/helper/utils.py @@ -1,8 +1,3 @@ -# Uncomment to run tests locally -# import sys -# import os -# sys.path = [os.getcwd()] + sys.path - import torch import numpy as np from sklearn.ensemble import ExtraTreesRegressor @@ -21,14 +16,14 @@ from mushroom_rl.policy.noise_policy import OrnsteinUhlenbeckPolicy from mushroom_rl.features._implementations.tiles_features import TilesFeatures from mushroom_rl.utils.parameters import Parameter, LinearParameter -from mushroom_rl.utils.optimizers import AdaptiveOptimizer, SGDOptimizer, AdamOptimizer, AdaGradOptimizer, \ - RMSPropOptimizer +from mushroom_rl.utils.optimizers import AdaptiveOptimizer, SGDOptimizer, AdamOptimizer from mushroom_rl.distributions.gaussian import GaussianDiagonalDistribution from mushroom_rl.utils.table import Table from mushroom_rl.utils.spaces import Discrete from mushroom_rl.features._implementations.functional_features import FunctionalFeatures from mushroom_rl.features._implementations.basis_features import BasisFeatures + class TestUtils: @classmethod @@ -49,11 +44,13 @@ def assert_eq(cls, this, that): this = this.model that = that.model for i in range(0, len(this)): - if cls._check_type(this[i], that[i], list) or cls._check_type(this[i], that[i], Ensemble) or cls._check_type(this[i], that[i], ExtraTreesRegressor): + if cls._check_type(this[i], that[i], list) or cls._check_type(this[i], that[i], Ensemble) \ + or cls._check_type(this[i], that[i], ExtraTreesRegressor): cls.assert_eq(this[i], that[i]) else: assert cls.eq_weights(this[i], that[i]) - elif cls._check_subtype(this, that, TorchPolicy) or cls._check_type(this, that, SACPolicy) or cls._check_subtype(this, that, ParametricPolicy): + elif cls._check_subtype(this, that, TorchPolicy) or cls._check_type(this, that, SACPolicy) \ + or cls._check_subtype(this, that, ParametricPolicy): assert cls.eq_weights(this, that) elif cls._check_subtype(this, that, TDPolicy): cls.assert_eq(this.get_q(), that.get_q()) @@ -81,10 +78,6 @@ def assert_eq(cls, this, that): assert cls.eq_sgd_optimizer(this, that) elif cls._check_type(this, that, AdamOptimizer): assert cls.eq_adam_optimizer(this, that) - elif cls._check_type(this, that, AdaGradOptimizer): - assert cls.eq_adagrad_optimizer(this, that) - elif cls._check_type(this, that, RMSPropOptimizer): - assert cls.eq_rmsprop_optimizer(this, that) elif cls._check_type(this, that, GaussianDiagonalDistribution): assert cls.eq_gaussian_diagonal_dist(this, that) elif cls._check_type(this, that, Table): @@ -314,24 +307,6 @@ def eq_adam_optimizer(cls, this, that): res = cls._eq_numpy(this._eps, that._eps) return res - @classmethod - def 
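Viewer.get_frame() returns the current pygame surface as an (H, W, 3) RGB array, which is exactly the format the VideoRecorder expects, so recording a pygame-based rollout reduces to grabbing the frame between drawing and display(). A rough sketch of that loop follows; the circle() call, sizes, and timings are only illustrative stand-ins for whatever the environment actually draws.

import numpy as np
from mushroom_rl.utils.viewer import Viewer
from mushroom_rl.utils.record import VideoRecorder

viewer = Viewer(5., 5.)                    # 5x5 environment drawn in the default 500x500 window
recorder = VideoRecorder(video_name="toy_rollout", fps=10)

state = np.array([2.5, 2.5])
for _ in range(50):
    viewer.circle(state, 0.1, color=(255, 0, 0))   # draw the current scene
    frame = viewer.get_frame()                     # grab it before display() resets the surface
    viewer.display(0.1)                            # show it for 0.1 s (matches fps=10)
    recorder(frame)

recorder.stop()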
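CV2Viewer is a small alternative viewer for environments that already produce RGB images: display() shows the frame in an OpenCV window, blocks for dt seconds via cv2.waitKey, and terminates the program once the user closes the window. A minimal sketch, with a synthetic gradient standing in for rendered frames:

import numpy as np
from mushroom_rl.utils.viewer import CV2Viewer

# dt of 0.05 s per frame translates to cv2.waitKey(50) between images
viewer = CV2Viewer("demo", dt=0.05, width=320, height=240)

for t in range(100):
    # Any (H, W, 3) uint8 RGB image works; here a changing solid color
    # stands in for a rendered frame.
    img = np.zeros((240, 320, 3), dtype=np.uint8)
    img[:, :, 0] = (t * 2) % 255
    viewer.display(img)

viewer.close()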
eq_adagrad_optimizer(cls, this, that): - """ - Compare two AdagradParameterOptimizer objects for equality - """ - - res = cls._eq_numpy(this._eps, that._eps) - return res - - @classmethod - def eq_rmsprop_optimizer(cls, this, that): - """ - Compare two RMSPropParameterOptimizer objects for equality - """ - - res = cls._eq_numpy(this._eps, that._eps) - return res - @classmethod def eq_gaussian_diagonal_dist(cls, this, that): """