From 5e74492aad44a118125b24aed55201b20ebf4e29 Mon Sep 17 00:00:00 2001 From: boris-il-forte Date: Sat, 28 Oct 2023 18:49:42 +0200 Subject: [PATCH] Started implementation of vectorized core - refactoring of core logic to support vectorized environments - implemented vectorized core, needs to be tested and debugged - renamed parallel environments into MultiprocessEnvironment, completed implementation (to be tested) - refactoring of core.mdp into core.env --- mushroom_rl/core/_impl/__init__.py | 3 +- mushroom_rl/core/_impl/core_logic.py | 45 +++- .../core/_impl/vectorized_core_logic.py | 56 ++++ mushroom_rl/core/core.py | 33 +-- mushroom_rl/core/environment.py | 10 +- ...ronment.py => multiprocess_environment.py} | 57 ++-- mushroom_rl/core/vectorized_core.py | 245 ++++++++++++++++++ mushroom_rl/core/vectorized_env.py | 71 ++++- mushroom_rl/environments/lqr.py | 4 +- mushroom_rl/environments/ship_steering.py | 2 +- 10 files changed, 468 insertions(+), 58 deletions(-) create mode 100644 mushroom_rl/core/_impl/vectorized_core_logic.py rename mushroom_rl/core/{parallel_environment.py => multiprocess_environment.py} (67%) create mode 100644 mushroom_rl/core/vectorized_core.py diff --git a/mushroom_rl/core/_impl/__init__.py b/mushroom_rl/core/_impl/__init__.py index b43e6afa..9eeb5fbd 100644 --- a/mushroom_rl/core/_impl/__init__.py +++ b/mushroom_rl/core/_impl/__init__.py @@ -2,4 +2,5 @@ from .torch_dataset import TorchDataset from .list_dataset import ListDataset from .type_conversions import DataConversion, NumpyConversion, TorchConversion, ListConversion -from .core_logic import CoreLogic \ No newline at end of file +from .core_logic import CoreLogic +from .vectorized_core_logic import VectorizedCoreLogic diff --git a/mushroom_rl/core/_impl/core_logic.py b/mushroom_rl/core/_impl/core_logic.py index babb41e2..16131916 100644 --- a/mushroom_rl/core/_impl/core_logic.py +++ b/mushroom_rl/core/_impl/core_logic.py @@ -6,11 +6,12 @@ def __init__(self): self.fit_required = None self.move_required = None - self._total_episodes_counter = 0 - self._total_steps_counter = 0 - self._current_episodes_counter = 0 - self._current_steps_counter = 0 + self._total_episodes_counter = None + self._total_steps_counter = None + self._current_episodes_counter = None + self._current_steps_counter = None + self._n_steps = None self._n_episodes = None self._n_steps_per_fit = None self._n_episodes_per_fit = None @@ -18,7 +19,7 @@ def __init__(self): self._steps_progress_bar = None self._episodes_progress_bar = None - def initialize_fit(self, n_steps_per_fit, n_episodes_per_fit): + def initialize_learn(self, n_steps_per_fit, n_episodes_per_fit): assert (n_episodes_per_fit is not None and n_steps_per_fit is None) \ or (n_episodes_per_fit is None and n_steps_per_fit is not None) @@ -26,9 +27,9 @@ def initialize_fit(self, n_steps_per_fit, n_episodes_per_fit): self._n_episodes_per_fit = n_episodes_per_fit if n_steps_per_fit is not None: - self.fit_required = lambda: self._current_steps_counter >= self._n_steps_per_fit + self.fit_required = self._fit_steps_condition else: - self.fit_required = lambda: self._current_episodes_counter >= self._n_episodes_per_fit + self.fit_required = self._fit_episodes_condition def initialize_evaluate(self): self.fit_required = lambda: False @@ -38,23 +39,21 @@ def initialize_run(self, n_steps, n_episodes, initial_states, quiet): or n_episodes is None and n_steps is not None and initial_states is None\ or n_episodes is None and n_steps is None and initial_states is not None + self._n_steps = 
n_steps self._n_episodes = len(initial_states) if initial_states is not None else n_episodes if n_steps is not None: - self.move_required = lambda: self._total_steps_counter < n_steps + self.move_required = self._move_steps_condition self._steps_progress_bar = tqdm(total=n_steps, dynamic_ncols=True, disable=quiet, leave=False) self._episodes_progress_bar = tqdm(disable=True) else: - self.move_required = lambda: self._total_episodes_counter < self._n_episodes + self.move_required = self._move_episodes_condition self._steps_progress_bar = tqdm(disable=True) self._episodes_progress_bar = tqdm(total=self._n_episodes, dynamic_ncols=True, disable=quiet, leave=False) - self._total_episodes_counter = 0 - self._total_steps_counter = 0 - self._current_episodes_counter = 0 - self._current_steps_counter = 0 + self._reset_counters() def get_initial_state(self, initial_states): if initial_states is None or self._total_episodes_counter == self._n_episodes: @@ -78,4 +77,22 @@ def after_fit(self): def terminate_run(self): self._steps_progress_bar.close() - self._episodes_progress_bar.close() \ No newline at end of file + self._episodes_progress_bar.close() + + def _reset_counters(self): + self._total_episodes_counter = 0 + self._total_steps_counter = 0 + self._current_episodes_counter = 0 + self._current_steps_counter = 0 + + def _move_steps_condition(self): + return self._total_steps_counter < self._n_steps + + def _fit_steps_condition(self): + return self._current_steps_counter >= self._n_steps_per_fit + + def _move_episodes_condition(self): + return self._total_episodes_counter < self._n_episodes + + def _fit_episodes_condition(self): + return self._current_episodes_counter >= self._n_episodes_per_fit diff --git a/mushroom_rl/core/_impl/vectorized_core_logic.py b/mushroom_rl/core/_impl/vectorized_core_logic.py new file mode 100644 index 00000000..82cd27dd --- /dev/null +++ b/mushroom_rl/core/_impl/vectorized_core_logic.py @@ -0,0 +1,56 @@ +import numpy as np + +from .core_logic import CoreLogic + + +class VectorizedCoreLogic(CoreLogic): + def __init__(self, n_envs): + self._n_envs = n_envs + + super().__init__() + + def get_action_mask(self): + action_mask = np.ones(self._n_envs, dtype=bool) + + if self._n_episodes is not None: + if self._n_episodes_per_fit is not None: + action_mask = self._current_episodes_counter != self._n_episodes_per_fit + else: + action_mask = self._current_episodes_counter != self._n_episodes + + return action_mask + + def get_initial_state(self, initial_states): + + if initial_states is None or np.all(self._total_episodes_counter == self._n_episodes): + initial_state = None + else: + initial_state = initial_states[self._total_episodes_counter] # FIXME + + return initial_state + + def after_step(self, last): + self._total_steps_counter += self._n_envs + self._current_steps_counter += self._n_envs + self._steps_progress_bar.update(self._n_envs) + + completed = last.sum() + self._total_episodes_counter += completed + self._current_episodes_counter += completed + self._episodes_progress_bar.update(completed) + + def after_fit(self): + self._current_episodes_counter = np.zeros(self._n_envs, dtype=int) + self._current_steps_counter = 0 + + def _reset_counters(self): + self._total_episodes_counter = np.zeros(self._n_envs, dtype=int) + self._current_episodes_counter = np.zeros(self._n_envs, dtype=int) + self._total_steps_counter = 0 + self._current_steps_counter = 0 + + def _move_episodes_condition(self): + return np.sum(self._total_episodes_counter) < self._n_episodes + + def 
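
A minimal, self-contained sketch of the bookkeeping that VectorizedCoreLogic above introduces: episode counters become per-environment NumPy arrays, the move condition sums them, and an action mask selects the environments that may still act. All variable names and values below are illustrative, not part of the patch.

import numpy as np

n_envs = 4
n_episodes = 6                                        # total episodes requested for the run

total_episodes = np.zeros(n_envs, dtype=int)          # per-environment episode counters
last = np.array([True, False, False, True])           # environments whose episode just ended

total_episodes += last                                # after_step-style update
move_required = np.sum(total_episodes) < n_episodes   # mirrors _move_episodes_condition
action_mask = total_episodes != n_episodes            # environments still allowed to act

print(move_required, action_mask)
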
_fit_episodes_condition(self): + return np.sum(self._current_episodes_counter) >= self._n_episodes_per_fit diff --git a/mushroom_rl/core/core.py b/mushroom_rl/core/core.py index 1dc5154a..6cbca35b 100644 --- a/mushroom_rl/core/core.py +++ b/mushroom_rl/core/core.py @@ -9,19 +9,22 @@ class Core(object): Implements the functions to run a generic algorithm. """ - def __init__(self, agent, mdp, callbacks_fit=None, callback_step=None, record_dictionary=None): + def __init__(self, agent, env, callbacks_fit=None, callback_step=None, record_dictionary=None): """ Constructor. Args: agent (Agent): the agent moving according to a policy; - mdp (Environment): the environment in which the agent moves; + env (Environment): the environment in which the agent moves; callbacks_fit (list): list of callbacks to execute at the end of each fit; callback_step (Callback): callback to execute after each step; + record_dictionary (dict, None): a dictionary of parameters for the recording, must containt the + recorder_class, fps, and optionally other keyword arguments to be passed to build the recorder class. + By default, the VideoRecorder class is used and the environment action frequency as frames per second. """ self.agent = agent - self.mdp = mdp + self.env = env self.callbacks_fit = callbacks_fit if callbacks_fit is not None else list() self.callback_step = callback_step if callback_step is not None else lambda x: None @@ -36,8 +39,8 @@ def __init__(self, agent, mdp, callbacks_fit=None, callback_step=None, record_di record_dictionary = dict() self._record = self._build_recorder_class(**record_dictionary) - def learn(self, n_steps=None, n_episodes=None, n_steps_per_fit=None, - n_episodes_per_fit=None, render=False, quiet=False, record=False): + def learn(self, n_steps=None, n_episodes=None, n_steps_per_fit=None, n_episodes_per_fit=None, + render=False, record=False, quiet=False): """ This function moves the agent in the environment and fits the policy using the collected samples. The agent can be moved for a given number of steps or a given number of episodes and, independently from this @@ -52,15 +55,15 @@ def learn(self, n_steps=None, n_episodes=None, n_steps_per_fit=None, n_episodes_per_fit (int, None): number of episodes between each fit of the policy; render (bool, False): whether to render the environment or not; - quiet (bool, False): whether to show the progress bar or not; record (bool, False): whether to record a video of the environment or not. If True, also the render flag should be set to True. + quiet (bool, False): whether to show the progress bar or not. 
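
The scheduling described in the learn() docstring above (fit the policy every n_steps_per_fit steps, or every n_episodes_per_fit episodes) can be summarised with a self-contained toy. The function below is purely illustrative, covers only the step-based criterion, and is not the library code.

def toy_learn(n_steps, n_steps_per_fit):
    current_steps, fits = 0, 0
    for _ in range(n_steps):
        current_steps += 1                     # stand-in for collecting one transition
        if current_steps >= n_steps_per_fit:   # the step-based fit condition
            fits += 1                          # stand-in for agent.fit(dataset)
            current_steps = 0                  # the current counter restarts after a fit
    return fits

assert toy_learn(n_steps=100, n_steps_per_fit=10) == 10
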
""" assert (render and record) or (not record), "To record, the render flag must be set to true" - self._core_logic.initialize_fit(n_steps_per_fit, n_episodes_per_fit) + self._core_logic.initialize_learn(n_steps_per_fit, n_episodes_per_fit) - dataset = Dataset(self.mdp.info, self.agent.info, n_steps_per_fit, n_episodes_per_fit) + dataset = Dataset(self.env.info, self.agent.info, n_steps_per_fit, n_episodes_per_fit) self._run(dataset, n_steps, n_episodes, render, quiet, record) @@ -88,7 +91,7 @@ def evaluate(self, initial_states=None, n_steps=None, n_episodes=None, render=Fa self._core_logic.initialize_evaluate() n_episodes_dataset = len(initial_states) if initial_states is not None else n_episodes - dataset = Dataset(self.mdp.info, self.agent.info, n_steps, n_episodes_dataset) + dataset = Dataset(self.env.info, self.agent.info, n_steps, n_episodes_dataset) return self._run(dataset, n_steps, n_episodes, render, quiet, record, initial_states) @@ -121,7 +124,7 @@ def _run(self, dataset, n_steps, n_episodes, render, quiet, record, initial_stat last = sample[5] self.agent.stop() - self.mdp.stop() + self.env.stop() self._end(record) @@ -140,17 +143,17 @@ def _step(self, render, record): """ action, policy_next_state = self.agent.draw_action(self._state, self._policy_state) - next_state, reward, absorbing, step_info = self.mdp.step(action) + next_state, reward, absorbing, step_info = self.env.step(action) if render: - frame = self.mdp.render(record) + frame = self.env.render(record) if record: self._record(frame) self._episode_steps += 1 - last = self._episode_steps >= self.mdp.info.horizon or absorbing + last = self._episode_steps >= self.env.info.horizon or absorbing state = self._state policy_state = self._policy_state @@ -167,7 +170,7 @@ def _reset(self, initial_states): """ initial_state = self._core_logic.get_initial_state(initial_states) - state, episode_info = self.mdp.reset(initial_state) + state, episode_info = self.env.reset(initial_state) self._policy_state, self._current_theta = self.agent.episode_start(episode_info) self._state = self._preprocess(state) self.agent.next_action = None @@ -218,6 +221,6 @@ def _build_recorder_class(self, recorder_class=None, fps=None, **kwargs): recorder_class = VideoRecorder if not fps: - fps = int(1 / self.mdp.info.dt) + fps = int(1 / self.env.info.dt) return recorder_class(fps=fps, **kwargs) diff --git a/mushroom_rl/core/environment.py b/mushroom_rl/core/environment.py index 1f83ce58..e67a76fb 100644 --- a/mushroom_rl/core/environment.py +++ b/mushroom_rl/core/environment.py @@ -59,7 +59,7 @@ def shape(self): class Environment(object): """ - Basic interface used by any mushroom environment. + Basic interface used by any MushroomRL environment. """ @@ -142,13 +142,13 @@ def seed(self, seed): def reset(self, state=None): """ - Reset the current state. + Reset the environment to the initial state. Args: state (np.ndarray, None): the state to set to the current state. Returns: - The current state and a dictionary containing the info for the episode. + The initial state and a dictionary containing the info for the episode. """ raise NotImplementedError @@ -170,6 +170,8 @@ def step(self, action): def render(self, record=False): """ + Render the environment to screen. + Args: record (bool, False): whether the visualized image should be returned or not. @@ -181,7 +183,7 @@ def render(self, record=False): def stop(self): """ - Method used to stop an mdp. Useful when dealing with real world environments, simulators, or when using + Method used to stop an env. 
Useful when dealing with real world environments, simulators, or when using openai-gym rendering """ diff --git a/mushroom_rl/core/parallel_environment.py b/mushroom_rl/core/multiprocess_environment.py similarity index 67% rename from mushroom_rl/core/parallel_environment.py rename to mushroom_rl/core/multiprocess_environment.py index a1bc90d3..ca190e5e 100644 --- a/mushroom_rl/core/parallel_environment.py +++ b/mushroom_rl/core/multiprocess_environment.py @@ -1,10 +1,12 @@ from multiprocessing import Pipe from multiprocessing import Process +import numpy as np + from .vectorized_env import VectorizedEnvironment -def _parallel_env_worker(remote, env_class, use_generator, args, kwargs): +def _env_worker(remote, env_class, use_generator, args, kwargs): if use_generator: env = env_class.generate(*args, **kwargs) @@ -24,17 +26,19 @@ def _parallel_env_worker(remote, env_class, use_generator, args, kwargs): remote.send(res) elif cmd in 'stop': env.stop() + remote.send(None) elif cmd == 'info': remote.send(env.info) elif cmd == 'seed': env.seed(int(data)) + remote.send(None) else: raise NotImplementedError() finally: remote.close() -class ParallelEnvironment(VectorizedEnvironment): +class MultiprocessEnvironment(VectorizedEnvironment): """ Basic interface to run in parallel multiple copies of the same environment. This class assumes that the environments are homogeneus, i.e. have the same type and MDP info. @@ -55,9 +59,11 @@ def __init__(self, env_class, *args, n_envs=-1, use_generator=False, **kwargs): assert n_envs > 1 self._remotes, self._work_remotes = zip(*[Pipe() for _ in range(n_envs)]) - self._processes = [Process(target=_parallel_env_worker, - args=(work_remote, env_class, use_generator, args, kwargs)) - for work_remote in self._work_remotes] + self._processes = list() + + for work_remote in self._work_remotes: + worker_process = Process(target=_env_worker, args=(work_remote, env_class, use_generator, args, kwargs)) + self._processes.append(worker_process) for p in self._processes: p.start() @@ -72,24 +78,45 @@ def step_all(self, env_mask, action): if env_mask[i]: remote.send(('step', action[i, :])) - results = [] + states = list() + step_infos = list() for i, remote in enumerate(self._remotes): if env_mask[i]: - results.extend(remote.recv()) + state, step_info = remote.recv() + states.append(remote.recv()) + step_infos.append(step_info) - return zip(*results) # FIXME!!! + return np.array(states), step_infos def reset_all(self, env_mask, state=None): - for i in range(self._n_envs): - state_i = state[i, :] if state is not None else None - self._remotes[i].send(('reset', state_i)) + for i, remote in enumerate(self._remotes): + if env_mask[i]: + state_i = state[i, :] if state is not None else None + remote.send(('reset', state_i)) + + states = list() + episode_infos = list() + for i, remote in enumerate(self._remotes): + if env_mask[i]: + state, episode_info = remote.recv() + states.append(state) + episode_infos.append(episode_info) + + return np.array(states), episode_infos + + def render_all(self, env_mask, record=False): + for i, remote in enumerate(self._remotes): + if env_mask[i]: + remote.send(('render', record)) + + frames = list() - results = [] for i, remote in enumerate(self._remotes): if env_mask[i]: - results.extend(remote.recv()) + frame = remote.recv() + frames.append(frame) - return zip(*results) # FIXME!!! 
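
The worker above answers string commands received over a Pipe. The self-contained sketch below reproduces that request/response pattern with a trivial counter instead of a real environment; the command names and payloads are illustrative, not the MushroomRL worker itself.

from multiprocessing import Pipe, Process

def worker(remote):
    counter = 0
    try:
        while True:
            cmd, data = remote.recv()          # each request is a (command, payload) pair
            if cmd == 'step':
                counter += data
                remote.send(counter)
            elif cmd == 'reset':
                counter = 0
                remote.send(counter)
            elif cmd == 'close':
                break
    finally:
        remote.close()

if __name__ == '__main__':
    parent, child = Pipe()
    p = Process(target=worker, args=(child,))
    p.start()
    parent.send(('step', 3))
    print(parent.recv())                       # 3
    parent.send(('close', None))
    p.join()
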
+ return np.array(frames) def seed(self, seed): for remote in self._remotes: @@ -124,4 +151,4 @@ def generate(env, *args, n_envs=-1, **kwargs): """ use_generator = hasattr(env, 'generate') - return ParallelEnvironment(env, *args, n_envs=n_envs, use_generator=use_generator, **kwargs) \ No newline at end of file + return MultiprocessEnvironment(env, *args, n_envs=n_envs, use_generator=use_generator, **kwargs) \ No newline at end of file diff --git a/mushroom_rl/core/vectorized_core.py b/mushroom_rl/core/vectorized_core.py new file mode 100644 index 00000000..87e9020c --- /dev/null +++ b/mushroom_rl/core/vectorized_core.py @@ -0,0 +1,245 @@ +import numpy as np + +from mushroom_rl.core.dataset import Dataset +from mushroom_rl.utils.record import VideoRecorder + +from ._impl import VectorizedCoreLogic + + +class VectorCore(object): + """ + Implements the functions to run a generic algorithm. + + """ + + def __init__(self, agent, env, callbacks_fit=None, callback_step=None, record_dictionary=None): + """ + Constructor. + + Args: + agent (Agent): the agent moving according to a policy; + env (VectorEnvironment): the environment in which the agent moves; + callbacks_fit (list): list of callbacks to execute at the end of each fit; + callback_step (Callback): callback to execute after each step; + record_dictionary (dict, None): a dictionary of parameters for the recording, must containt the + recorder_class, fps, and optionally other keyword arguments to be passed to build the recorder class. + By default, the VideoRecorder class is used and the environment action frequency as frames per second. + + """ + self.agent = agent + self.env = env + self.callbacks_fit = callbacks_fit if callbacks_fit is not None else list() + self.callback_step = callback_step if callback_step is not None else lambda x: None + + self._state = None + self._policy_state = None + self._current_theta = None + self._episode_steps = None + + self._core_logic = VectorizedCoreLogic(self.env.number) + + if record_dictionary is None: + record_dictionary = dict() + self._record = [self._build_recorder_class(**record_dictionary) for _ in self.env.number] + + def learn(self, n_steps=None, n_episodes=None, n_steps_per_fit=None, n_episodes_per_fit=None, + render=False, record=False, quiet=False): + """ + This function moves the agent in the environment and fits the policy using the collected samples. + The agent can be moved for a given number of steps or a given number of episodes and, independently from this + choice, the policy can be fitted after a given number of steps or a given number of episodes. + The environment is reset at the beginning of the learning process. + + Args: + n_steps (int, None): number of steps to move the agent; + n_episodes (int, None): number of episodes to move the agent; + n_steps_per_fit (int, None): number of steps between each fit of the + policy; + n_episodes_per_fit (int, None): number of episodes between each fit + of the policy; + render (bool, False): whether to render the environment or not; + record (bool, False): whether to record a video of the environment or not. If True, also the render flag + should be set to True. + quiet (bool, False): whether to show the progress bar or not. 
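
The *_all methods of MultiprocessEnvironment above all follow the same masked fan-out/gather pattern: a command is sent only to the environments selected by env_mask, and the replies are collected and stacked into arrays. A small self-contained sketch of the gather half, with made-up reply data standing in for what the remotes would send back:

import numpy as np

env_mask = np.array([True, False, True])
replies = [(np.array([0.0, 1.0]), {'env': 0}),     # pretend (state, step_info) replies
           (np.array([2.0, 3.0]), {'env': 1}),
           (np.array([4.0, 5.0]), {'env': 2})]

states, step_infos = [], []
for i, (state, step_info) in enumerate(replies):
    if env_mask[i]:
        states.append(state)
        step_infos.append(step_info)

states = np.array(states)                          # shape (n_active_envs, state_dim) -> (2, 2)
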
+ + """ + assert (render and record) or (not record), "To record, the render flag must be set to true" + self._core_logic.initialize_learn(n_steps_per_fit, n_episodes_per_fit) + + datasets = [Dataset(self.env.info, self.agent.info, n_steps_per_fit, n_episodes_per_fit) + for _ in self.env.number] + + self._run(datasets, n_steps, n_episodes, render, quiet, record) + + def evaluate(self, initial_states=None, n_steps=None, n_episodes=None, render=False, quiet=False, record=False): + """ + This function moves the agent in the environment using its policy. + The agent is moved for a provided number of steps, episodes, or from a set of initial states for the whole + episode. The environment is reset at the beginning of the learning process. + + Args: + initial_states (np.ndarray, None): the starting states of each episode; + n_steps (int, None): number of steps to move the agent; + n_episodes (int, None): number of episodes to move the agent; + render (bool, False): whether to render the environment or not; + quiet (bool, False): whether to show the progress bar or not; + record (bool, False): whether to record a video of the environment or not. If True, also the render flag + should be set to True. + + Returns: + The collected dataset. + + """ + assert (render and record) or (not record), "To record, the render flag must be set to true" + + self._core_logic.initialize_evaluate() + + n_episodes_dataset = len(initial_states) if initial_states is not None else n_episodes + datasets = [Dataset(self.env.info, self.agent.info, n_steps, n_episodes_dataset) for _ in self.env.number] + + return self._run(datasets, n_steps, n_episodes, render, quiet, record, initial_states) + + def _run(self, datasets, n_steps, n_episodes, render, quiet, record, initial_states=None): + self._core_logic.initialize_run(n_steps, n_episodes, initial_states, quiet) + + last = None + while self._core_logic.move_required(): + action_mask = self._core_logic.get_action_mask() + last = np.logical_and(last, action_mask) + + if np.any(last): + self._reset(initial_states, last) + sample, step_info = self._step(render, record, action_mask) + + self.callback_step(sample) + + self._core_logic.after_step(np.logical_and(sample[5], action_mask)) + + samples = list(zip(*sample)) + for i in range(self.env.number): + if action_mask[i]: + datasets[i].append(samples[i]) + + if self._core_logic.fit_required(): + fit_dataset = self._aggregate(datasets) + self.agent.fit(fit_dataset) + self._core_logic.after_fit() + + for c in self.callbacks_fit: + c(datasets) + + for dataset in datasets: + dataset.clear() + + last = sample[5] + + self.agent.stop() + self.env.stop() + + self._end(record) + + return self._aggregate(datasets) + + def _step(self, render, record, action_mask): + """ + Single step. + + Args: + render (bool): whether to render or not. + + Returns: + A tuple containing the previous states, the actions sampled by the + agent, the rewards obtained, the reached states, the absorbing flags + of the reached states and the last step flags. 
+ + """ + + action, policy_next_state = self.agent.draw_action(self._states[action_mask], self._policy_state[action_mask]) + + next_state, rewards, absorbing, step_info = self.env.step_all(action, action_mask) + + self._episode_steps += 1 + + if render: + self.env.render_all(action_mask, record=record) + + last = np.logical_or(absorbing, self._episode_steps >= self.env.info.horizon) + + state = self._state + policy_state = self._policy_state + next_state = self._preprocess(next_state) + self._state = next_state + self._policy_state = policy_next_state + + return (state, action, rewards, next_state, absorbing, last, policy_state, policy_next_state), step_info + + def _reset(self, initial_states, mask): + """ + Reset the states of the agent. + + """ + initial_state = self._core_logic.get_initial_state(initial_states) + # self.agent.episode_start(mask) FIXME + self.agent.episode_start() + + self._states = self._preprocess(self.env.reset_all(initial_state, mask)) + self.agent.next_action = None + self._episode_steps = np.multiply(self._episode_steps, np.logical_not(mask)) + + def _end(self, record): + self._state = None + self._policy_state = None + self._current_theta = None + self._episode_steps = None + + if record: + for record in self._record: + record.stop() + + self._core_logic.terminate_run() + + def _preprocess(self, states): + """ + Method to apply state preprocessors. + + Args: + states (Iterable of np.ndarray): the states to be preprocessed. + + Returns: + The preprocessed states. + + """ + for p in self.agent.preprocessors: + states = p(states) + + return states + + @staticmethod + def _aggregate(datasets): + aggregated_dataset = datasets[0] + + for dataset in datasets[1:]: + aggregated_dataset += dataset + + return aggregated_dataset + + def _build_recorder_class(self, recorder_class=None, fps=None, **kwargs): + """ + Method to create a video recorder class. + + Args: + recorder_class (class): the class used to record the video. By default, we use the ``VideoRecorder`` class + from mushroom. The class must implement the ``__call__`` and ``stop`` methods. + + Returns: + The recorder object. + + """ + + if not recorder_class: + recorder_class = VideoRecorder + + if not fps: + fps = int(1 / self.env.info.dt) + + return recorder_class(fps=fps, **kwargs) diff --git a/mushroom_rl/core/vectorized_env.py b/mushroom_rl/core/vectorized_env.py index ad5bf38a..d074c6f9 100644 --- a/mushroom_rl/core/vectorized_env.py +++ b/mushroom_rl/core/vectorized_env.py @@ -5,25 +5,84 @@ class VectorizedEnvironment(Environment): """ - Class to create a Mushroom environment using the PyBullet simulator. + Basic interface used by any MushroomRL vectorized environment. """ def __init__(self, mdp_info, n_envs): self._n_envs = n_envs + self._default_env = 0 + super().__init__(mdp_info) def reset(self, state=None): - env_mask = np.zeros(dtype=bool) - env_mask[0] = True + env_mask = np.zeros(self._n_envs, dtype=bool) + env_mask[self._default_env] = True return self.reset_all(env_mask, state) def step(self, action): - env_mask = np.zeros(dtype=bool) - env_mask[0] = True + env_mask = np.zeros(self._n_envs, dtype=bool) + env_mask[self._default_env] = True return self.step_all(env_mask, action) + def render(self, record=False): + env_mask = np.zeros(self._n_envs, dtype=bool) + env_mask[self._default_env] = True + + return self.render_all(env_mask, record=record) + + def reset_all(self, env_mask, state=None): + """ + Reset all the specified environments to the initial state. 
+ + Args: + env_mask: mask specifying which environments needs reset. + state: set of initial states to impose to the environment. + + Returns: + The initial states of all environments and a listy of episode info dictionaries + + """ + raise NotImplementedError + def step_all(self, env_mask, action): + """ + Move all the specified agents from their current state according to the actions. + + Args: + env_mask: mask specifying which environments needs reset. + action: set of actions to execute. + + Returns: + The initial states of all environments and a listy of step info dictionaries + + """ raise NotImplementedError - def reset_all(self, env_mask, state=None): + def render_all(self, env_mask, record=False): + """ + Render all the specified environments to screen. + + Args: + record (bool, False): whether the visualized images should be returned or not. + + Returns: + The visualized images, or None if the record flag is set to false. + + """ raise NotImplementedError + + def set_default_env(self, id): + """ + Select the id of the default environment that will be executed with the default env interface. + + Args: + id (int): the number of the selected environment + + """ + assert id < self._n_envs, "The selected ID is higher than the available ones" + + self._default_env = id + + @property + def number(self): + return self._n_envs diff --git a/mushroom_rl/environments/lqr.py b/mushroom_rl/environments/lqr.py index 8b7fb404..2ebe4bcd 100644 --- a/mushroom_rl/environments/lqr.py +++ b/mushroom_rl/environments/lqr.py @@ -41,7 +41,7 @@ def __init__(self, A, B, Q, R, max_pos=np.inf, max_action=np.inf, random_init=F episodic (bool, False): end the episode when the state goes over the threshold; gamma (float, 0.9): discount factor; - horizon (int, 50): horizon of the mdp; + horizon (int, 50): horizon of the env; dt (float, 0.1): the control timestep of the environment. """ @@ -92,7 +92,7 @@ def generate(dimensions=None, s_dim=None, a_dim=None, max_pos=np.inf, max_action episodic (bool, False): end the episode when the state goes over the threshold; gamma (float, .9): discount factor; - horizon (int, 50): horizon of the mdp. + horizon (int, 50): horizon of the env. """ assert dimensions != None or (s_dim != None and a_dim != None) diff --git a/mushroom_rl/environments/ship_steering.py b/mushroom_rl/environments/ship_steering.py index 4502f4c8..69b6e556 100644 --- a/mushroom_rl/environments/ship_steering.py +++ b/mushroom_rl/environments/ship_steering.py @@ -19,7 +19,7 @@ def __init__(self, small=True, n_steps_action=3): Args: small (bool, True): whether to use a small state space or not. n_steps_action (int, 3): number of integration intervals for each - step of the mdp. + step of the env. """ # MDP parameters
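
To make the delegation in VectorizedEnvironment.reset()/step()/render() concrete, here is a self-contained toy class that mimics the same pattern: a one-hot mask over the default environment is built and routed to the *_all methods. It deliberately simplifies the return values and is not the library class.

import numpy as np

class ToyVectorized:
    def __init__(self, n_envs):
        self._n_envs = n_envs
        self._default_env = 0
        self._states = np.zeros(n_envs)

    def reset_all(self, env_mask, state=None):
        self._states[env_mask] = 0.0
        return self._states[env_mask], [{}] * int(env_mask.sum())

    def step_all(self, env_mask, action):
        self._states[env_mask] += action
        return self._states[env_mask], [{}] * int(env_mask.sum())

    def _default_mask(self):
        env_mask = np.zeros(self._n_envs, dtype=bool)
        env_mask[self._default_env] = True
        return env_mask

    def reset(self, state=None):
        return self.reset_all(self._default_mask(), state)

    def step(self, action):
        return self.step_all(self._default_mask(), action)

envs = ToyVectorized(n_envs=4)
states, episode_infos = envs.reset()
states, step_infos = envs.step(1.0)       # only the default environment moves
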