Started implementation of vectorized core
- refactoring of core logic to support vectorized environments
- implemented vectorized core; still needs to be tested and debugged
- renamed parallel environments to MultiprocessEnvironment and completed the implementation (to be tested)
- refactoring of core.mdp into core.env (see the usage sketch below)
boris-il-forte committed Oct 28, 2023
1 parent cad5c75 commit 5e74492
Showing 10 changed files with 468 additions and 58 deletions.
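The diffs below rename the mdp handle of Core to env and move record before quiet in learn(). A minimal usage sketch of the renamed API follows; the agent/environment builders are hypothetical placeholders and not part of this commit, only the Core calls mirror the diff:

from mushroom_rl.core import Core

# Hypothetical helpers standing in for any MushroomRL Agent/Environment pair.
agent = build_agent()        # placeholder, not part of this commit
env = build_environment()    # placeholder, not part of this commit

# The second constructor argument is now called env (previously mdp)
# and is stored on the core as self.env.
core = Core(agent, env)

# learn() keeps its parameters; record is now documented before quiet.
core.learn(n_steps=10000, n_steps_per_fit=1, render=False, quiet=True)

# evaluate() builds its Dataset from self.env.info after the rename.
dataset = core.evaluate(n_episodes=10, render=False)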
3 changes: 2 additions & 1 deletion mushroom_rl/core/_impl/__init__.py
@@ -2,4 +2,5 @@
from .torch_dataset import TorchDataset
from .list_dataset import ListDataset
from .type_conversions import DataConversion, NumpyConversion, TorchConversion, ListConversion
from .core_logic import CoreLogic
from .core_logic import CoreLogic
from .vectorized_core_logic import VectorizedCoreLogic
45 changes: 31 additions & 14 deletions mushroom_rl/core/_impl/core_logic.py
@@ -6,29 +6,30 @@ def __init__(self):
self.fit_required = None
self.move_required = None

self._total_episodes_counter = 0
self._total_steps_counter = 0
self._current_episodes_counter = 0
self._current_steps_counter = 0
self._total_episodes_counter = None
self._total_steps_counter = None
self._current_episodes_counter = None
self._current_steps_counter = None

self._n_steps = None
self._n_episodes = None
self._n_steps_per_fit = None
self._n_episodes_per_fit = None

self._steps_progress_bar = None
self._episodes_progress_bar = None

def initialize_fit(self, n_steps_per_fit, n_episodes_per_fit):
def initialize_learn(self, n_steps_per_fit, n_episodes_per_fit):
assert (n_episodes_per_fit is not None and n_steps_per_fit is None) \
or (n_episodes_per_fit is None and n_steps_per_fit is not None)

self._n_steps_per_fit = n_steps_per_fit
self._n_episodes_per_fit = n_episodes_per_fit

if n_steps_per_fit is not None:
self.fit_required = lambda: self._current_steps_counter >= self._n_steps_per_fit
self.fit_required = self._fit_steps_condition
else:
self.fit_required = lambda: self._current_episodes_counter >= self._n_episodes_per_fit
self.fit_required = self._fit_episodes_condition

def initialize_evaluate(self):
self.fit_required = lambda: False
@@ -38,23 +39,21 @@ def initialize_run(self, n_steps, n_episodes, initial_states, quiet):
or n_episodes is None and n_steps is not None and initial_states is None\
or n_episodes is None and n_steps is None and initial_states is not None

self._n_steps = n_steps
self._n_episodes = len(initial_states) if initial_states is not None else n_episodes

if n_steps is not None:
self.move_required = lambda: self._total_steps_counter < n_steps
self.move_required = self._move_steps_condition

self._steps_progress_bar = tqdm(total=n_steps, dynamic_ncols=True, disable=quiet, leave=False)
self._episodes_progress_bar = tqdm(disable=True)
else:
self.move_required = lambda: self._total_episodes_counter < self._n_episodes
self.move_required = self._move_episodes_condition

self._steps_progress_bar = tqdm(disable=True)
self._episodes_progress_bar = tqdm(total=self._n_episodes, dynamic_ncols=True, disable=quiet, leave=False)

self._total_episodes_counter = 0
self._total_steps_counter = 0
self._current_episodes_counter = 0
self._current_steps_counter = 0
self._reset_counters()

def get_initial_state(self, initial_states):
if initial_states is None or self._total_episodes_counter == self._n_episodes:
@@ -78,4 +77,22 @@ def after_fit(self):

def terminate_run(self):
self._steps_progress_bar.close()
self._episodes_progress_bar.close()
self._episodes_progress_bar.close()

def _reset_counters(self):
self._total_episodes_counter = 0
self._total_steps_counter = 0
self._current_episodes_counter = 0
self._current_steps_counter = 0

def _move_steps_condition(self):
return self._total_steps_counter < self._n_steps

def _fit_steps_condition(self):
return self._current_steps_counter >= self._n_steps_per_fit

def _move_episodes_condition(self):
return self._total_episodes_counter < self._n_episodes

def _fit_episodes_condition(self):
return self._current_episodes_counter >= self._n_episodes_per_fit
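The lambdas previously assigned to fit_required and move_required are replaced here by named bound methods (_fit_steps_condition, _move_episodes_condition, and so on). The commit does not state the reason; one plausible motivation, offered only as an assumption, is that bound methods can be pickled while locally defined lambdas cannot, which becomes relevant once the core logic has to work alongside MultiprocessEnvironment. A self-contained illustration of that difference:

import pickle


class WithLambda:
    def __init__(self, threshold):
        self._threshold = threshold
        self._counter = 0
        # A lambda defined inside __init__ cannot be pickled.
        self.fit_required = lambda: self._counter >= self._threshold


class WithBoundMethod:
    def __init__(self, threshold):
        self._threshold = threshold
        self._counter = 0
        # A bound method is pickled by reference to the instance and the method name.
        self.fit_required = self._fit_condition

    def _fit_condition(self):
        return self._counter >= self._threshold


try:
    pickle.dumps(WithLambda(10))
except (pickle.PicklingError, AttributeError) as exc:
    print("lambda attribute is not picklable:", exc)

print("bound-method version pickles to", len(pickle.dumps(WithBoundMethod(10))), "bytes")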
56 changes: 56 additions & 0 deletions mushroom_rl/core/_impl/vectorized_core_logic.py
@@ -0,0 +1,56 @@
import numpy as np

from .core_logic import CoreLogic


class VectorizedCoreLogic(CoreLogic):
def __init__(self, n_envs):
self._n_envs = n_envs

super().__init__()

def get_action_mask(self):
action_mask = np.ones(self._n_envs, dtype=bool)

if self._n_episodes is not None:
if self._n_episodes_per_fit is not None:
action_mask = self._current_episodes_counter != self._n_episodes_per_fit
else:
action_mask = self._current_episodes_counter != self._n_episodes

return action_mask

def get_initial_state(self, initial_states):

if initial_states is None or np.all(self._total_episodes_counter == self._n_episodes):
initial_state = None
else:
initial_state = initial_states[self._total_episodes_counter] # FIXME

return initial_state

def after_step(self, last):
self._total_steps_counter += self._n_envs
self._current_steps_counter += self._n_envs
self._steps_progress_bar.update(self._n_envs)

completed = last.sum()
self._total_episodes_counter += completed
self._current_episodes_counter += completed
self._episodes_progress_bar.update(completed)

def after_fit(self):
self._current_episodes_counter = np.zeros(self._n_envs, dtype=int)
self._current_steps_counter = 0

def _reset_counters(self):
self._total_episodes_counter = np.zeros(self._n_envs, dtype=int)
self._current_episodes_counter = np.zeros(self._n_envs, dtype=int)
self._total_steps_counter = 0
self._current_steps_counter = 0

def _move_episodes_condition(self):
return np.sum(self._total_episodes_counter) < self._n_episodes

def _fit_episodes_condition(self):
return np.sum(self._current_episodes_counter) >= self._n_episodes_per_fit
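VectorizedCoreLogic turns the scalar episode counters of CoreLogic into per-environment NumPy arrays. The following standalone sketch (plain NumPy, no MushroomRL import; names mirror the private counters above) shows what a single vectorized step does to those counters:

import numpy as np

n_envs = 4
n_episodes_per_fit = 2

# _reset_counters(): per-environment episode counters, scalar step counters.
total_episodes_counter = np.zeros(n_envs, dtype=int)
current_episodes_counter = np.zeros(n_envs, dtype=int)
total_steps_counter = 0
current_steps_counter = 0

# One vectorized step in which environments 1 and 3 terminate an episode.
last = np.array([False, True, False, True])

# after_step(last): step counters advance by n_envs,
# episode counters by the scalar last.sum(), added to every entry of the arrays.
total_steps_counter += n_envs
current_steps_counter += n_envs
completed = last.sum()
total_episodes_counter += completed
current_episodes_counter += completed

# _fit_episodes_condition(): fit once the summed episode counter reaches the threshold.
fit_required = np.sum(current_episodes_counter) >= n_episodes_per_fit

# get_action_mask(): environments whose counter equals n_episodes_per_fit are masked out.
action_mask = current_episodes_counter != n_episodes_per_fit

print(fit_required, action_mask)   # True [False False False False]

Note that after_step adds the scalar completed count to every entry of the per-environment arrays, exactly as in the committed code above; the commit message flags this logic as still to be tested and debugged.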
33 changes: 18 additions & 15 deletions mushroom_rl/core/core.py
@@ -9,19 +9,22 @@ class Core(object):
Implements the functions to run a generic algorithm.
"""
def __init__(self, agent, mdp, callbacks_fit=None, callback_step=None, record_dictionary=None):
def __init__(self, agent, env, callbacks_fit=None, callback_step=None, record_dictionary=None):
"""
Constructor.
Args:
agent (Agent): the agent moving according to a policy;
mdp (Environment): the environment in which the agent moves;
env (Environment): the environment in which the agent moves;
callbacks_fit (list): list of callbacks to execute at the end of each fit;
callback_step (Callback): callback to execute after each step;
record_dictionary (dict, None): a dictionary of parameters for the recording, must contain the
recorder_class, fps, and optionally other keyword arguments to be passed to build the recorder class.
By default, the VideoRecorder class is used and the environment action frequency as frames per second.
"""
self.agent = agent
self.mdp = mdp
self.env = env
self.callbacks_fit = callbacks_fit if callbacks_fit is not None else list()
self.callback_step = callback_step if callback_step is not None else lambda x: None

@@ -36,8 +39,8 @@ def __init__(self, agent, mdp, callbacks_fit=None, callback_step=None, record_di
record_dictionary = dict()
self._record = self._build_recorder_class(**record_dictionary)

def learn(self, n_steps=None, n_episodes=None, n_steps_per_fit=None,
n_episodes_per_fit=None, render=False, quiet=False, record=False):
def learn(self, n_steps=None, n_episodes=None, n_steps_per_fit=None, n_episodes_per_fit=None,
render=False, record=False, quiet=False):
"""
This function moves the agent in the environment and fits the policy using the collected samples.
The agent can be moved for a given number of steps or a given number of episodes and, independently from this
@@ -52,15 +55,15 @@ def learn(self, n_steps=None, n_episodes=None, n_steps_per_fit=None,
n_episodes_per_fit (int, None): number of episodes between each fit
of the policy;
render (bool, False): whether to render the environment or not;
quiet (bool, False): whether to show the progress bar or not;
record (bool, False): whether to record a video of the environment or not. If True, also the render flag
should be set to True.
quiet (bool, False): whether to show the progress bar or not.
"""
assert (render and record) or (not record), "To record, the render flag must be set to true"
self._core_logic.initialize_fit(n_steps_per_fit, n_episodes_per_fit)
self._core_logic.initialize_learn(n_steps_per_fit, n_episodes_per_fit)

dataset = Dataset(self.mdp.info, self.agent.info, n_steps_per_fit, n_episodes_per_fit)
dataset = Dataset(self.env.info, self.agent.info, n_steps_per_fit, n_episodes_per_fit)

self._run(dataset, n_steps, n_episodes, render, quiet, record)

@@ -88,7 +91,7 @@ def evaluate(self, initial_states=None, n_steps=None, n_episodes=None, render=Fa
self._core_logic.initialize_evaluate()

n_episodes_dataset = len(initial_states) if initial_states is not None else n_episodes
dataset = Dataset(self.mdp.info, self.agent.info, n_steps, n_episodes_dataset)
dataset = Dataset(self.env.info, self.agent.info, n_steps, n_episodes_dataset)

return self._run(dataset, n_steps, n_episodes, render, quiet, record, initial_states)

@@ -121,7 +124,7 @@ def _run(self, dataset, n_steps, n_episodes, render, quiet, record, initial_stat
last = sample[5]

self.agent.stop()
self.mdp.stop()
self.env.stop()

self._end(record)

@@ -140,17 +143,17 @@ def _step(self, render, record):
"""
action, policy_next_state = self.agent.draw_action(self._state, self._policy_state)
next_state, reward, absorbing, step_info = self.mdp.step(action)
next_state, reward, absorbing, step_info = self.env.step(action)

if render:
frame = self.mdp.render(record)
frame = self.env.render(record)

if record:
self._record(frame)

self._episode_steps += 1

last = self._episode_steps >= self.mdp.info.horizon or absorbing
last = self._episode_steps >= self.env.info.horizon or absorbing

state = self._state
policy_state = self._policy_state
@@ -167,7 +170,7 @@ def _reset(self, initial_states):
"""
initial_state = self._core_logic.get_initial_state(initial_states)

state, episode_info = self.mdp.reset(initial_state)
state, episode_info = self.env.reset(initial_state)
self._policy_state, self._current_theta = self.agent.episode_start(episode_info)
self._state = self._preprocess(state)
self.agent.next_action = None
@@ -218,6 +221,6 @@ def _build_recorder_class(self, recorder_class=None, fps=None, **kwargs):
recorder_class = VideoRecorder

if not fps:
fps = int(1 / self.mdp.info.dt)
fps = int(1 / self.env.info.dt)

return recorder_class(fps=fps, **kwargs)
10 changes: 6 additions & 4 deletions mushroom_rl/core/environment.py
@@ -59,7 +59,7 @@ def shape(self):

class Environment(object):
"""
Basic interface used by any mushroom environment.
Basic interface used by any MushroomRL environment.
"""

@@ -142,13 +142,13 @@ def seed(self, seed):

def reset(self, state=None):
"""
Reset the current state.
Reset the environment to the initial state.
Args:
state (np.ndarray, None): the state to set to the current state.
Returns:
The current state and a dictionary containing the info for the episode.
The initial state and a dictionary containing the info for the episode.
"""
raise NotImplementedError
@@ -170,6 +170,8 @@ def step(self, action):

def render(self, record=False):
"""
Render the environment to screen.
Args:
record (bool, False): whether the visualized image should be returned or not.
@@ -181,7 +183,7 @@ def render(self, record=False):

def stop(self):
"""
Method used to stop an mdp. Useful when dealing with real world environments, simulators, or when using
Method used to stop an env. Useful when dealing with real world environments, simulators, or when using
openai-gym rendering
"""
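As a companion to the reset/step/render/stop contract documented above, a minimal custom environment could look like the sketch below. The MDPInfo and Box constructors and their import paths are assumptions about the surrounding MushroomRL API and are not part of this diff; the return values of reset and step follow the docstrings above and the way Core._step and Core._reset consume them.

import numpy as np

from mushroom_rl.core import Environment, MDPInfo   # assumed import path
from mushroom_rl.utils.spaces import Box             # assumed import path


class RandomWalk1D(Environment):
    """Toy 1D environment used only to illustrate the interface."""

    def __init__(self):
        observation_space = Box(low=np.array([-10.0]), high=np.array([10.0]))
        action_space = Box(low=np.array([-1.0]), high=np.array([1.0]))
        mdp_info = MDPInfo(observation_space, action_space, gamma=0.99, horizon=100)
        super().__init__(mdp_info)
        self._state = None

    def reset(self, state=None):
        # Return the initial state and the episode info dictionary.
        self._state = np.zeros(1) if state is None else np.asarray(state, dtype=float)
        return self._state, {}

    def step(self, action):
        # Return next_state, reward, absorbing flag and step info, as Core._step expects.
        self._state = self._state + np.clip(action, -1.0, 1.0)
        reward = -float(np.abs(self._state)[0])
        absorbing = bool(np.abs(self._state)[0] >= 10.0)
        return self._state, reward, absorbing, {}

    def render(self, record=False):
        # This toy visualization produces no frame to return.
        print("state:", self._state)
        return None

    def stop(self):
        # Nothing to tear down in this toy environment.
        pass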