Skip to content

Commit

Permalink
Updated preprocessor interface.
Browse files Browse the repository at this point in the history
  • Loading branch information
robfiras committed Jan 18, 2024
1 parent 1eaa5e3 commit 1d2f04a
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 22 deletions.
1 change: 1 addition & 0 deletions mushroom_rl/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ def _preprocess(self, state):
"""
for p in self.agent.preprocessors:
p.update(state)
state = p(state)

return state
Expand Down
1 change: 1 addition & 0 deletions mushroom_rl/core/vectorized_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ def _preprocess(self, states):
"""
for p in self.agent.preprocessors:
p.update(states)
states = p(states)

return states
Expand Down
57 changes: 35 additions & 22 deletions mushroom_rl/rl_utils/preprocessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,38 @@
from mushroom_rl.rl_utils.running_stats import RunningStandardization


class StandardizationPreprocessor(Serializable):
class Preprocessor(Serializable):
    """
    Base interface for observation preprocessors.

    Concrete preprocessors must implement ``__call__``. Overriding ``update``
    is optional: the default implementation does nothing.

    """
    def update(self, obs):
        """
        Update the internal state of the preprocessor using the current
        observations. The base implementation is a no-op.

        Args:
            obs (Array): observations used to update the internal state.

        """
        # TODO: Support vectorized environment and batch update.

    def __call__(self, obs):
        """
        Preprocess the observations.

        Args:
            obs (Array): observations to be preprocessed.

        Returns:
            Preprocessed observations.

        """
        # TODO: Support vectorized environment and batch preprocessing.
        raise NotImplementedError


class StandardizationPreprocessor(Preprocessor):
"""
Preprocess observations from the environment using a running
standardization.
Expand Down Expand Up @@ -33,29 +64,21 @@ def __init__(self, mdp_info, clip_obs=10., alpha=1e-32):
)

def __call__(self, obs):
    """
    Normalize the observation using the current running statistics.

    The running statistics are deliberately NOT updated here. They are
    refreshed by ``update``, which the core calls right before invoking
    the preprocessor; updating them here as well would count every
    observation twice.

    Args:
        obs (np.ndarray): observation to be normalized.

    Returns:
        Normalized observation array with the same shape, clipped to
        the interval ``[-clip_obs, clip_obs]``.

    """
    assert obs.shape == self._obs_shape, \
        "Values given to running_norm have incorrect shape " \
        "(obs shape: {}, expected shape: {})" \
        .format(obs.shape, self._obs_shape)

    # Standardize with the running mean/std, then bound the result.
    norm_obs = np.clip(
        (obs - self._obs_runstand.mean) / self._obs_runstand.std,
        -self._clip_obs, self._clip_obs
    )

    return norm_obs

def update(self, obs):
    """
    Feed the given observation to the running standardization statistics.

    Args:
        obs: observation forwarded to the ``update_stats`` method of the
            running standardization object.

    """
    running_stats = self._obs_runstand
    running_stats.update_stats(obs)


class MinMaxPreprocessor(StandardizationPreprocessor):
"""
Expand Down Expand Up @@ -104,16 +127,6 @@ def __init__(self, mdp_info, clip_obs=10., alpha=1e-32):
)

def __call__(self, obs):
"""
Call function to normalize the observation.
Args:
obs (np.ndarray): observation to be normalized.
Returns:
Normalized observation array with the same shape.
"""
orig_obs = obs.copy()

if self._run_norm_obs:
Expand Down

0 comments on commit 1d2f04a

Please sign in to comment.