Skip to content

Commit

Permalink
Fixed packed sequences for numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
boris-il-forte committed Dec 6, 2023
1 parent a5dfa3a commit e5de4fb
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions mushroom_rl/core/_impl/array_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ def pack_padded_sequence(array, lengths):
shape = array.shape

new_shape = (shape[0] * shape[1],) + shape[2:]
mask = (np.arange(len(array))[:, None] < lengths[None, :]).flatten()
return array.reshape(new_shape)[mask]
mask = (np.arange(len(array))[:, None] < lengths[None, :]).flatten(order='F')
return array.reshape(new_shape, order='F')[mask]


class TorchBackend(ArrayBackend):
Expand Down

0 comments on commit e5de4fb

Please sign in to comment.