From 2b847734fb143844273f40e597ee7c2560718ca6 Mon Sep 17 00:00:00 2001 From: Bjarne-55 <73470930+Bjarne-55@users.noreply.github.com> Date: Tue, 15 Oct 2024 22:23:06 +0200 Subject: [PATCH] Update ExtraInfo - use arraybackend.flatten when no mask and pack_padded_sequence with mask - combine key_mapping and shape_mapping without union operator --- mushroom_rl/core/extra_info.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/mushroom_rl/core/extra_info.py b/mushroom_rl/core/extra_info.py index bcfcea4a..1b1faf75 100644 --- a/mushroom_rl/core/extra_info.py +++ b/mushroom_rl/core/extra_info.py @@ -113,10 +113,16 @@ def flatten(self, mask=None): info._structured_storage = {} for key in self.data: - info.data[key] = info._array_backend.pack_padded_sequence(self.data[key], mask) + if mask is None: + info.data[key] = info._array_backend.flatten(self.data[key]) + else: + info.data[key] = info._array_backend.pack_padded_sequence(self.data[key], mask) for key in self._structured_storage: - info._structured_storage[key] = info._array_backend.pack_padded_sequence(self._structured_storage[key], mask) + if mask is None: + info._structured_storage[key] = info._array_backend.flatten(self._structured_storage[key]) + else: + info._structured_storage[key] = info._array_backend.pack_padded_sequence(self._structured_storage[key], mask) return info @@ -135,8 +141,13 @@ def __add__(self, other): info._structured_storage = self._concatenate_dictionary(self._structured_storage, other._structured_storage, self._array_backend, other._array_backend) info.data = self._concatenate_dictionary(self.data, other.data, self._array_backend, other._array_backend) - info._key_mapping = self._key_mapping | other._key_mapping - info._shape_mapping = self._shape_mapping | other._shape_mapping + #combine key_mapping + info._key_mapping = self._key_mapping.copy() + info._key_mapping.update(other._key_mapping) + + #combine shape_mapping + info._shape_mapping = self._shape_mapping.copy() + info._shape_mapping.update(other._shape_mapping) return info