From 5395905ace5323fc4b2566cd0eb8b43b9d637696 Mon Sep 17 00:00:00 2001 From: boris-il-forte Date: Tue, 26 Nov 2024 15:44:14 +0100 Subject: [PATCH] Fixed dataset from_array - added missing policy state in the new dataset --- mushroom_rl/core/dataset.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/mushroom_rl/core/dataset.py b/mushroom_rl/core/dataset.py index c11fd03b..e0016736 100644 --- a/mushroom_rl/core/dataset.py +++ b/mushroom_rl/core/dataset.py @@ -212,11 +212,14 @@ def from_array(cls, states, actions, rewards, next_states, absorbings, lasts, dataset._array_backend = ArrayBackend.get_array_backend(backend) if backend == 'numpy': - dataset._data = NumpyDataset.from_array(states, actions, rewards, next_states, absorbings, lasts) + dataset._data = NumpyDataset.from_array(states, actions, rewards, next_states, absorbings, lasts, + policy_state, policy_next_state) elif backend == 'torch': - dataset._data = TorchDataset.from_array(states, actions, rewards, next_states, absorbings, lasts) + dataset._data = TorchDataset.from_array(states, actions, rewards, next_states, absorbings, lasts, + policy_state, policy_next_state) else: - dataset._data = ListDataset.from_array(states, actions, rewards, next_states, absorbings, lasts) + dataset._data = ListDataset.from_array(states, actions, rewards, next_states, absorbings, lasts, + policy_state, policy_next_state) state_shape = states.shape[1:] action_shape = actions.shape[1:]