Skip to content

Commit

Permalink
Update ExtraInfo
Browse files Browse the repository at this point in the history
- use arraybackend.flatten when no mask and pack_padded_sequence with mask
- combine key_mapping and shape_mapping without union operator
  • Loading branch information
Bjarne-55 committed Oct 15, 2024
1 parent 9fbabf9 commit 2b84773
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions mushroom_rl/core/extra_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,16 @@ def flatten(self, mask=None):
info._structured_storage = {}

for key in self.data:
info.data[key] = info._array_backend.pack_padded_sequence(self.data[key], mask)
if mask is None:
info.data[key] = info._array_backend.flatten(self.data[key])
else:
info.data[key] = info._array_backend.pack_padded_sequence(self.data[key], mask)

for key in self._structured_storage:
info._structured_storage[key] = info._array_backend.pack_padded_sequence(self._structured_storage[key], mask)
if mask is None:
info._structured_storage[key] = info._array_backend.flatten(self._structured_storage[key])
else:
info._structured_storage[key] = info._array_backend.pack_padded_sequence(self._structured_storage[key], mask)

return info

Expand All @@ -135,8 +141,13 @@ def __add__(self, other):
info._structured_storage = self._concatenate_dictionary(self._structured_storage, other._structured_storage, self._array_backend, other._array_backend)
info.data = self._concatenate_dictionary(self.data, other.data, self._array_backend, other._array_backend)

info._key_mapping = self._key_mapping | other._key_mapping
info._shape_mapping = self._shape_mapping | other._shape_mapping
#combine key_mapping
info._key_mapping = self._key_mapping.copy()
info._key_mapping.update(other._key_mapping)

#combine shape_mapping
info._shape_mapping = self._shape_mapping.copy()
info._shape_mapping.update(other._shape_mapping)

return info

Expand Down

0 comments on commit 2b84773

Please sign in to comment.