Skip to content

Commit

Permalink
Update ExtraInfo
Browse files Browse the repository at this point in the history
- Apply mask permanently in  flatten
- Remove mask from parse
  • Loading branch information
Bjarne-55 committed Oct 10, 2024
1 parent 1e6764a commit 4e92594
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 13 deletions.
14 changes: 2 additions & 12 deletions mushroom_rl/core/extra_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,12 @@ def append(self, info):

self._storage.append(info)

def parse(self, to=None, mask=None):
def parse(self, to=None):
"""
Parse the stored information into an flat dictionary of arrays
Args:
to (str): the backend to be used for the returned arrays, 'torch' or 'numpy'.
mask
Returns:
dict: Flat dictionary containing an array for every property of the step information
Expand Down Expand Up @@ -93,11 +92,6 @@ def parse(self, to=None, mask=None):
self._structured_storage = {key: value for key, value in output.items()}
self._storage = []
self._array_backend = target_backend

#apply mask on arrays in output
if mask is not None:
for key, value in output.items():
output[key] = value[mask]

self.data = output

Expand All @@ -122,11 +116,7 @@ def flatten(self, mask=None):
info.data[key] = info._array_backend.pack_padded_sequence(self.data[key], mask)

for key in self._structured_storage:
info._structured_storage[key] = info._array_backend.pack_padded_sequence(self._structured_storage[key])

if mask is not None:
for key, value in self.data.items():
self.data[key] = value[mask]
info._structured_storage[key] = info._array_backend.pack_padded_sequence(self._structured_storage[key], mask)

return info

Expand Down
47 changes: 46 additions & 1 deletion tests/core/test_extra_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,4 +405,49 @@ def test_clear():
info.append(data2)
info.parse()
info.clear()
assert(not info)
assert(not info)

def test_flatten_with_mask():
info = ExtraInfo(5, 'numpy')
data1 = {
'prop1': np.arange(100, 105),
'prop2': np.arange(200, 205)
}
data2 = {
'prop1': np.arange(110, 115),
'prop2': np.arange(210, 215)
}
info.append(data1)
info.append(data2)
mask = np.array([True, True, False, False, False, True, False, False, True, False])
info = info.flatten(mask)

assert(len(info) == 2)

assert("prop1" in info)
assert("prop2" in info)

assert(isinstance(info["prop1"], np.ndarray))
assert(isinstance(info["prop2"], np.ndarray))

assert(info["prop1"].ndim == 1 and info["prop1"].shape[0] == 4)
assert(info["prop2"].ndim == 1 and info["prop2"].shape[0] == 4)

assert np.array_equal(np.array([100, 110, 112, 104]), info["prop1"])
assert np.array_equal(np.array([200, 210, 212, 204]), info["prop2"])

#Test if mask is permantly applied
info.parse()
assert(len(info) == 2)

assert("prop1" in info)
assert("prop2" in info)

assert(isinstance(info["prop1"], np.ndarray))
assert(isinstance(info["prop2"], np.ndarray))

assert(info["prop1"].ndim == 1 and info["prop1"].shape[0] == 4)
assert(info["prop2"].ndim == 1 and info["prop2"].shape[0] == 4)

assert np.array_equal(np.array([100, 110, 112, 104]), info["prop1"])
assert np.array_equal(np.array([200, 210, 212, 204]), info["prop2"])

0 comments on commit 4e92594

Please sign in to comment.