diff --git a/mushroom_rl/core/extra_info.py b/mushroom_rl/core/extra_info.py index bcfcea4a..1b1faf75 100644 --- a/mushroom_rl/core/extra_info.py +++ b/mushroom_rl/core/extra_info.py @@ -113,10 +113,16 @@ def flatten(self, mask=None): info._structured_storage = {} for key in self.data: - info.data[key] = info._array_backend.pack_padded_sequence(self.data[key], mask) + if mask is None: + info.data[key] = info._array_backend.flatten(self.data[key]) + else: + info.data[key] = info._array_backend.pack_padded_sequence(self.data[key], mask) for key in self._structured_storage: - info._structured_storage[key] = info._array_backend.pack_padded_sequence(self._structured_storage[key], mask) + if mask is None: + info._structured_storage[key] = info._array_backend.flatten(self._structured_storage[key]) + else: + info._structured_storage[key] = info._array_backend.pack_padded_sequence(self._structured_storage[key], mask) return info @@ -135,8 +141,13 @@ def __add__(self, other): info._structured_storage = self._concatenate_dictionary(self._structured_storage, other._structured_storage, self._array_backend, other._array_backend) info.data = self._concatenate_dictionary(self.data, other.data, self._array_backend, other._array_backend) - info._key_mapping = self._key_mapping | other._key_mapping - info._shape_mapping = self._shape_mapping | other._shape_mapping + #combine key_mapping + info._key_mapping = self._key_mapping.copy() + info._key_mapping.update(other._key_mapping) + + #combine shape_mapping + info._shape_mapping = self._shape_mapping.copy() + info._shape_mapping.update(other._shape_mapping) return info