diff --git a/mushroom_rl/core/dataset.py b/mushroom_rl/core/dataset.py index c11fd03b..e0016736 100644 --- a/mushroom_rl/core/dataset.py +++ b/mushroom_rl/core/dataset.py @@ -212,11 +212,14 @@ def from_array(cls, states, actions, rewards, next_states, absorbings, lasts, dataset._array_backend = ArrayBackend.get_array_backend(backend) if backend == 'numpy': - dataset._data = NumpyDataset.from_array(states, actions, rewards, next_states, absorbings, lasts) + dataset._data = NumpyDataset.from_array(states, actions, rewards, next_states, absorbings, lasts, + policy_state, policy_next_state) elif backend == 'torch': - dataset._data = TorchDataset.from_array(states, actions, rewards, next_states, absorbings, lasts) + dataset._data = TorchDataset.from_array(states, actions, rewards, next_states, absorbings, lasts, + policy_state, policy_next_state) else: - dataset._data = ListDataset.from_array(states, actions, rewards, next_states, absorbings, lasts) + dataset._data = ListDataset.from_array(states, actions, rewards, next_states, absorbings, lasts, + policy_state, policy_next_state) state_shape = states.shape[1:] action_shape = actions.shape[1:]