diff --git a/mushroom_rl/environments/gym_env.py b/mushroom_rl/environments/gym_env.py index 191d8c84..1d9a43be 100644 --- a/mushroom_rl/environments/gym_env.py +++ b/mushroom_rl/environments/gym_env.py @@ -81,7 +81,7 @@ def __init__(self, name, horizon=None, gamma=0.99, wrappers=None, wrappers_args= def reset(self, state=None): if state is None: - return np.atleast_1d(self.env.reset()), {} + return np.atleast_1d(self.env.reset()[0]), {} else: self.env.reset() self.env.state = state @@ -90,7 +90,7 @@ def reset(self, state=None): def step(self, action): action = self._convert_action(action) - obs, reward, absorbing, info = self.env.step(action) + obs, reward, absorbing, _, info = self.env.step(action) return np.atleast_1d(obs), reward, absorbing, info