From 4e9259462fd16ac82bdbea432702a666f737ade5 Mon Sep 17 00:00:00 2001 From: Bjarne-55 <73470930+Bjarne-55@users.noreply.github.com> Date: Thu, 10 Oct 2024 16:20:41 +0200 Subject: [PATCH] Update ExtraInfo - Apply mask permanently in flatten - Remove mask from parse --- mushroom_rl/core/extra_info.py | 14 ++-------- tests/core/test_extra_info.py | 47 +++++++++++++++++++++++++++++++++- 2 files changed, 48 insertions(+), 13 deletions(-) diff --git a/mushroom_rl/core/extra_info.py b/mushroom_rl/core/extra_info.py index f9160403..8e0be0af 100644 --- a/mushroom_rl/core/extra_info.py +++ b/mushroom_rl/core/extra_info.py @@ -39,13 +39,12 @@ def append(self, info): self._storage.append(info) - def parse(self, to=None, mask=None): + def parse(self, to=None): """ Parse the stored information into an flat dictionary of arrays Args: to (str): the backend to be used for the returned arrays, 'torch' or 'numpy'. - mask Returns: dict: Flat dictionary containing an array for every property of the step information @@ -93,11 +92,6 @@ def parse(self, to=None, mask=None): self._structured_storage = {key: value for key, value in output.items()} self._storage = [] self._array_backend = target_backend - - #apply mask on arrays in output - if mask is not None: - for key, value in output.items(): - output[key] = value[mask] self.data = output @@ -122,11 +116,7 @@ def flatten(self, mask=None): info.data[key] = info._array_backend.pack_padded_sequence(self.data[key], mask) for key in self._structured_storage: - info._structured_storage[key] = info._array_backend.pack_padded_sequence(self._structured_storage[key]) - - if mask is not None: - for key, value in self.data.items(): - self.data[key] = value[mask] + info._structured_storage[key] = info._array_backend.pack_padded_sequence(self._structured_storage[key], mask) return info diff --git a/tests/core/test_extra_info.py b/tests/core/test_extra_info.py index 1a7fd50e..51c22b25 100644 --- a/tests/core/test_extra_info.py +++ b/tests/core/test_extra_info.py @@ -405,4 +405,49 @@ def test_clear(): info.append(data2) info.parse() info.clear() - assert(not info) \ No newline at end of file + assert(not info) + +def test_flatten_with_mask(): + info = ExtraInfo(5, 'numpy') + data1 = { + 'prop1': np.arange(100, 105), + 'prop2': np.arange(200, 205) + } + data2 = { + 'prop1': np.arange(110, 115), + 'prop2': np.arange(210, 215) + } + info.append(data1) + info.append(data2) + mask = np.array([True, True, False, False, False, True, False, False, True, False]) + info = info.flatten(mask) + + assert(len(info) == 2) + + assert("prop1" in info) + assert("prop2" in info) + + assert(isinstance(info["prop1"], np.ndarray)) + assert(isinstance(info["prop2"], np.ndarray)) + + assert(info["prop1"].ndim == 1 and info["prop1"].shape[0] == 4) + assert(info["prop2"].ndim == 1 and info["prop2"].shape[0] == 4) + + assert np.array_equal(np.array([100, 110, 112, 104]), info["prop1"]) + assert np.array_equal(np.array([200, 210, 212, 204]), info["prop2"]) + + #Test if mask is permantly applied + info.parse() + assert(len(info) == 2) + + assert("prop1" in info) + assert("prop2" in info) + + assert(isinstance(info["prop1"], np.ndarray)) + assert(isinstance(info["prop2"], np.ndarray)) + + assert(info["prop1"].ndim == 1 and info["prop1"].shape[0] == 4) + assert(info["prop2"].ndim == 1 and info["prop2"].shape[0] == 4) + + assert np.array_equal(np.array([100, 110, 112, 104]), info["prop1"]) + assert np.array_equal(np.array([200, 210, 212, 204]), info["prop2"])