Skip to content

Commit

Permalink
Fixed dataset from_array
Browse files Browse the repository at this point in the history
- added missing policy state in the new dataset
  • Loading branch information
boris-il-forte committed Nov 26, 2024
1 parent 5c3834d commit 5395905
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions mushroom_rl/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,11 +212,14 @@ def from_array(cls, states, actions, rewards, next_states, absorbings, lasts,

dataset._array_backend = ArrayBackend.get_array_backend(backend)
if backend == 'numpy':
dataset._data = NumpyDataset.from_array(states, actions, rewards, next_states, absorbings, lasts)
dataset._data = NumpyDataset.from_array(states, actions, rewards, next_states, absorbings, lasts,
policy_state, policy_next_state)
elif backend == 'torch':
dataset._data = TorchDataset.from_array(states, actions, rewards, next_states, absorbings, lasts)
dataset._data = TorchDataset.from_array(states, actions, rewards, next_states, absorbings, lasts,
policy_state, policy_next_state)
else:
dataset._data = ListDataset.from_array(states, actions, rewards, next_states, absorbings, lasts)
dataset._data = ListDataset.from_array(states, actions, rewards, next_states, absorbings, lasts,
policy_state, policy_next_state)

state_shape = states.shape[1:]
action_shape = actions.shape[1:]
Expand Down

0 comments on commit 5395905

Please sign in to comment.