diff --git a/examples/isaac_example.py b/examples/isaac_example.py index 5e87af8d..5f9728af 100644 --- a/examples/isaac_example.py +++ b/examples/isaac_example.py @@ -79,8 +79,8 @@ def experiment(cfg_dict, headless, alg, n_epochs, n_steps, n_steps_per_fit, n_ep dataset = core.evaluate(n_episodes=n_episodes_test, render=False) - J = torch.mean(torch.stack(dataset.discounted_return)) - R = torch.mean(torch.stack(dataset.undiscounted_return)) + J = torch.mean(dataset.discounted_return) + R = torch.mean(dataset.undiscounted_return) E = agent.policy.entropy() logger.epoch_info(0, J=J, R=R, entropy=E) @@ -89,8 +89,8 @@ def experiment(cfg_dict, headless, alg, n_epochs, n_steps, n_steps_per_fit, n_ep core.learn(n_steps=n_steps, n_steps_per_fit=n_steps_per_fit) dataset = core.evaluate(n_episodes=n_episodes_test, render=False) - J = torch.mean(torch.stack(dataset.discounted_return)) - R = torch.mean(torch.stack(dataset.undiscounted_return)) + J = torch.mean(dataset.discounted_return) + R = torch.mean(dataset.undiscounted_return) E = agent.policy.entropy() logger.epoch_info(it+1, J=J, R=R, entropy=E) diff --git a/mushroom_rl/core/__init__.py b/mushroom_rl/core/__init__.py index 814fc9f7..7f1e51d5 100644 --- a/mushroom_rl/core/__init__.py +++ b/mushroom_rl/core/__init__.py @@ -6,6 +6,8 @@ from .serialization import Serializable from .logger import Logger +from .extra_info import ExtraInfo + from .vectorized_core import VectorCore from .vectorized_env import VectorizedEnvironment from .multiprocess_environment import MultiprocessEnvironment @@ -13,4 +15,4 @@ import mushroom_rl.environments __all__ = ['ArrayBackend', 'Core', 'DatasetInfo', 'Dataset', 'Environment', 'MDPInfo', 'Agent', 'AgentInfo', - 'Serializable', 'Logger', 'VectorCore', 'VectorizedEnvironment', 'MultiprocessEnvironment'] + 'Serializable', 'Logger', 'ExtraInfo', 'VectorCore', 'VectorizedEnvironment', 'MultiprocessEnvironment'] diff --git a/mushroom_rl/core/array_backend.py b/mushroom_rl/core/array_backend.py index c3b6afdf..706357c4 100644 --- a/mushroom_rl/core/array_backend.py +++ b/mushroom_rl/core/array_backend.py @@ -147,6 +147,26 @@ def from_list(array): @staticmethod def pack_padded_sequence(array, mask): raise NotImplementedError + + @staticmethod + def flatten(array): + raise NotImplementedError + + @staticmethod + def empty(shape, device=None): + raise NotImplementedError + + @staticmethod + def none(): + raise NotImplementedError + + @staticmethod + def shape(array): + raise NotImplementedError + + @staticmethod + def full(shape, value): + raise NotImplementedError class NumpyBackend(ArrayBackend): @@ -253,6 +273,28 @@ def pack_padded_sequence(array, mask): new_shape = (shape[0] * shape[1],) + shape[2:] return array.reshape(new_shape, order='F')[mask.flatten(order='F')] + + @staticmethod + def flatten(array): + shape = array.shape + new_shape = (shape[0] * shape[1],) + shape[2:] + return array.reshape(new_shape, order='F') + + @staticmethod + def empty(shape, device=None): + return np.empty(shape) + + @staticmethod + def none(): + return np.nan + + @staticmethod + def shape(array): + return array.shape + + @staticmethod + def full(shape, value): + return np.full(shape, value) class TorchBackend(ArrayBackend): @@ -364,9 +406,31 @@ def pack_padded_sequence(array, mask): shape = array.shape new_shape = (shape[0]*shape[1], ) + shape[2:] - + return array.transpose(0, 1).reshape(new_shape)[mask.transpose(0, 1).flatten()] + @staticmethod + def flatten(array): + shape = array.shape + new_shape = (shape[0]*shape[1], ) + shape[2:] + return array.transpose(0, 1).reshape(new_shape) + + @staticmethod + def empty(shape, device=None): + device = TorchUtils.get_device() if device is None else device + return torch.empty(shape, device=device) + + @staticmethod + def none(): + return torch.nan + + @staticmethod + def shape(array): + return array.shape + + @staticmethod + def full(shape, value): + return torch.full(shape, value) class ListBackend(ArrayBackend): @@ -421,3 +485,23 @@ def from_list(array): @staticmethod def pack_padded_sequence(array, mask): return NumpyBackend.pack_padded_sequence(array, np.array(mask)) + + @staticmethod + def flatten(array): + return NumpyBackend.flatten(array) + + @staticmethod + def empty(shape, device=None): + return np.empty(shape) + + @staticmethod + def none(): + return None + + @staticmethod + def shape(array): + return np.array(array).shape + + @staticmethod + def full(shape, value): + return np.full(shape, value) \ No newline at end of file diff --git a/mushroom_rl/core/core.py b/mushroom_rl/core/core.py index 3dffe448..3478029c 100644 --- a/mushroom_rl/core/core.py +++ b/mushroom_rl/core/core.py @@ -128,6 +128,8 @@ def _run(self, dataset, n_steps, n_episodes, render, quiet, record, initial_stat self._end(record) + dataset.info.parse() + dataset.episode_info.parse() return dataset def _step(self, render, record): diff --git a/mushroom_rl/core/dataset.py b/mushroom_rl/core/dataset.py index ab903b0b..c11fd03b 100644 --- a/mushroom_rl/core/dataset.py +++ b/mushroom_rl/core/dataset.py @@ -6,6 +6,7 @@ from mushroom_rl.core.serialization import Serializable from .array_backend import ArrayBackend +from .extra_info import ExtraInfo from ._impl import * @@ -103,8 +104,8 @@ def __init__(self, dataset_info, n_steps=None, n_episodes=None): else: policy_state_shape = None - self._info = defaultdict(list) - self._episode_info = defaultdict(list) + self._info = ExtraInfo(dataset_info.n_envs, dataset_info.backend, dataset_info.device) + self._episode_info = ExtraInfo(dataset_info.n_envs, dataset_info.backend, dataset_info.device) self._theta_list = list() if dataset_info.backend == 'numpy': @@ -195,12 +196,12 @@ def from_array(cls, states, actions, rewards, next_states, absorbings, lasts, dataset = cls.create_raw_instance() if info is None: - dataset._info = defaultdict(list) + dataset._info = ExtraInfo(1, backend) else: dataset._info = info.copy() if episode_info is None: - dataset._episode_info = defaultdict(list) + dataset._episode_info = ExtraInfo(1, backend) else: dataset._episode_info = episode_info.copy() @@ -228,7 +229,7 @@ def from_array(cls, states, actions, rewards, next_states, absorbings, lasts, def append(self, step, info): self._data.append(*step) - self._append_info(self._info, info) + self._info.append(info) def append_episode_info(self, info): self._append_info(self._episode_info, info) @@ -243,21 +244,17 @@ def get_info(self, field, index=None): return self._info[field][index] def clear(self): - self._episode_info = defaultdict(list) + self._episode_info.clear() self._theta_list = list() - self._info = defaultdict(list) + self._info.clear() self._data.clear() def get_view(self, index, copy=False): dataset = self.create_raw_instance(dataset=self) - info_slice = defaultdict(list) - for key in self._info.keys(): - info_slice[key] = self._info[key][index] - - dataset._info = info_slice - dataset._episode_info = defaultdict(list) + dataset._info = self._info.get_view(index, copy) + dataset._episode_info = self._episode_info.get_view(index, copy) dataset._data = self._data.get_view(index, copy) return dataset @@ -276,11 +273,9 @@ def __getitem__(self, index): def __add__(self, other): result = self.create_raw_instance(dataset=self) - new_info = self._merge_info(self.info, other.info) - new_episode_info = self._merge_info(self.episode_info, other.episode_info) - result._info = new_info - result._episode_info = new_episode_info + result._info = self._info + other._info + result._episode_info = self._episode_info + other._episode_info result._theta_list = self._theta_list + other._theta_list result._data = self._data + other._data @@ -525,8 +520,8 @@ def _convert(self, *arrays, to='numpy'): def _add_all_save_attr(self): self._add_save_attr( - _info='pickle', - _episode_info='pickle', + _info='mushroom', + _episode_info='mushroom', _theta_list='pickle', _data='mushroom', _array_backend='primitive', @@ -557,7 +552,7 @@ def append(self, step, info): def append_vectorized(self, step, info, mask): self._data.append(*step, mask=mask) - self._append_info(self._info, {}) # FIXME: handle properly info + self._info.append(info) def append_theta_vectorized(self, theta, mask): for i in range(len(theta)): @@ -581,11 +576,16 @@ def clear(self, n_steps_per_fit=None): mask.flatten()[n_extra_steps:] = False residual_data.mask = mask.reshape(original_shape) + residual_info = self._info.get_view(view_size, copy=True) + residual_episode_info = self._episode_info.get_view(view_size, copy=True) + super().clear() self._initialize_theta_list(n_envs) if n_steps_per_fit is not None and residual_data is not None: self._data = residual_data + self._info = residual_info + self._episode_info = residual_episode_info def flatten(self, n_steps_per_fit=None): if len(self) == 0: @@ -622,9 +622,12 @@ def flatten(self, n_steps_per_fit=None): flat_theta_list = self._flatten_theta_list() + flat_info = self._info.flatten(self.mask) + flat_episode_info = self._episode_info.flatten(self.mask) + return Dataset.from_array(states, actions, rewards, next_states, absorbings, lasts, policy_state=policy_state, policy_next_state=policy_next_state, - info=None, episode_info=None, theta_list=flat_theta_list, # FIXME: handle properly info + info=flat_info, episode_info=flat_episode_info, theta_list=flat_theta_list, horizon=self._dataset_info.horizon, gamma=self._dataset_info.gamma, backend=self._array_backend.get_backend_name()) diff --git a/mushroom_rl/core/extra_info.py b/mushroom_rl/core/extra_info.py new file mode 100644 index 00000000..1b1faf75 --- /dev/null +++ b/mushroom_rl/core/extra_info.py @@ -0,0 +1,412 @@ +from collections import UserDict +import numbers +from .array_backend import ArrayBackend +from mushroom_rl.core.serialization import Serializable + +class ExtraInfo(Serializable, UserDict): + """ + A class to to collect and parse step information + """ + def __init__(self, n_envs, backend, device=None): + """ + Constructor. + + Args: + n_envs (int): Number of parallel environments + """ + self._n_envs = n_envs + self._array_backend = ArrayBackend.get_array_backend(backend) + self._device = device + + self._storage = [] + self._key_mapping = {} #maps keys for future output to key paths + self._shape_mapping = {} #maps keys to additional shapes for arrays + self._structured_storage = {} + super().__init__() + self._add_all_save_attr() + + def append(self, info): + """ + Append new step information + + Args: + info (dict or list): Information to append either list of dicts of every environment, or a dictionary of arrays + """ + if self._n_envs > 1: + assert isinstance(info, (dict, list)) + else: + assert isinstance(info, dict) + + self._storage.append(info) + + def parse(self, to=None): + """ + Parse the stored information into an flat dictionary of arrays + + Args: + to (str): the backend to be used for the returned arrays, 'torch' or 'numpy'. + + Returns: + dict: Flat dictionary containing an array for every property of the step information + """ + if to is None: + to = self._array_backend.get_backend_name() + + #create key mapping + for step_data in self._storage: + if isinstance(step_data, dict): + self._update_key_mapping(step_data, self._n_envs == 1) + elif isinstance(step_data, list): + for env_data in step_data: + assert isinstance(env_data, dict) + self._update_key_mapping(env_data, True) + + # calculate the size for the array + if self._structured_storage: + length_structured_storage = self._structured_storage[next(iter(self._structured_storage.keys()))].shape[0] + else: + length_structured_storage = 0 + size = (len(self._storage) + length_structured_storage, self._n_envs) if self._n_envs > 1 else (len(self._storage) + length_structured_storage, ) + + #create output dictionary with empty arrays + output = { + key: ArrayBackend.get_array_backend(to).empty(size + self._shape_mapping[key], self._device) + for key in self._key_mapping + } + + #fill output with elements stored in structured storage + if self._structured_storage: + for key in output: + index = length_structured_storage + value = self._convert(self._structured_storage[key], to) + output[key][:index] = value + + #fill output with elements stored in storage + for index, step_data in enumerate(self._storage): + index = index + length_structured_storage + if isinstance(step_data, dict): + self._append_dict_to_output(output, step_data, index, to) + elif isinstance(step_data, list): + self._append_list_to_output(output, step_data, index, to) + + self._structured_storage = {key: value for key, value in output.items()} + self._storage = [] + self._array_backend = ArrayBackend.get_array_backend(to) + + self.data = output + + def flatten(self, mask=None): + """ + Flattens the arrays in data by combining the first two dimensions. + + Args: + mask + + Returns: + ExtraInfo: Flattened ExtraInfo + """ + self.parse() + + info = ExtraInfo(1, self._array_backend.get_backend_name(), self._device) + info._shape_mapping = self._shape_mapping + info._key_mapping = self._key_mapping + info._structured_storage = {} + + for key in self.data: + if mask is None: + info.data[key] = info._array_backend.flatten(self.data[key]) + else: + info.data[key] = info._array_backend.pack_padded_sequence(self.data[key], mask) + + for key in self._structured_storage: + if mask is None: + info._structured_storage[key] = info._array_backend.flatten(self._structured_storage[key]) + else: + info._structured_storage[key] = info._array_backend.pack_padded_sequence(self._structured_storage[key], mask) + + return info + + def __add__(self, other): + """ + Returns new object which combines two ExtraInfo objects. + + Args: + other(ExtraInfo): other ExtraInfo which will be combined with self + """ + assert(self._n_envs == other.n_envs) + + info = ExtraInfo(self._n_envs, self._array_backend.get_backend_name(), self._device) + info._storage = self._storage + other._storage + + info._structured_storage = self._concatenate_dictionary(self._structured_storage, other._structured_storage, self._array_backend, other._array_backend) + info.data = self._concatenate_dictionary(self.data, other.data, self._array_backend, other._array_backend) + + #combine key_mapping + info._key_mapping = self._key_mapping.copy() + info._key_mapping.update(other._key_mapping) + + #combine shape_mapping + info._shape_mapping = self._shape_mapping.copy() + info._shape_mapping.update(other._shape_mapping) + + return info + + def _concatenate_array(self, array1, array2, intended_length_array1, intended_length_array2, array1_backend, array2_backend): + """ + Concatenate array1 with array2 + + Args: + array1 (array, None) + array2 (array, None) + intended_length_array1 (int): Intended Length of array1 in case array1 is None + intended_length_array2 (int): Intended Length of array2 in case array2 is None + array1_backend (ArrayBackend): Backend of array1 + array2_backend (ArrayBackend): Backend of array2 + + Returns: + array: Concatenation of array1 and array2 + """ + if array1 is None: + shape = (intended_length_array1,) + array2_backend.shape(array2)[1:] + array1 = array1_backend.full(shape, array1_backend.none()) + if array2 is None: + shape = (intended_length_array2, ) + array1_backend.shape(array1)[1:] + array2 = array2_backend.full(shape, array2_backend.none()) + array2 = array1_backend.convert(array2, backend=array2_backend) + return array1_backend.concatenate((array1, array2)) + + def _concatenate_dictionary(self, dict1, dict2, backend1, backend2): + """ + Concatenate dict1 with dict2. + + Args: + dict1 (dict): Flat dictionary containing arrays of backend1 + dict2 (dict): Flat dictionary containing arrays of backend2 + backend1 (ArrayBackend): Backend of arrays in dict1. + backend2 (ArrayBackend): Backend of arrays in dict2. + + Returns + dict: Concatenation of dict1 and dict2 + """ + if not dict1: + return dict2 + if not dict2: + return dict1 + + array_length_dict1 = backend1.shape(dict1[next(iter(dict1.keys()))])[0] + array_length_dict2 = backend2.shape(dict2[next(iter(dict2.keys()))])[0] + + r = {} + + for key in dict1.keys() | dict2.keys(): + array1 = dict1[key] if key in dict1 else None + array2 = dict2[key] if key in dict2 else None + r[key] = self._concatenate_array(array1, array2, array_length_dict1, array_length_dict2, backend1, backend2) + return r + + + def copy(self): + info = ExtraInfo(self._n_envs, self._array_backend.get_backend_name(), self._device) + info._storage = self._storage.copy() + info._key_mapping = self._key_mapping.copy() + info._shape_mapping = self._shape_mapping.copy() + info.data = self.data.copy() + + return info + + def get_view(self, index, copy=False): + """ + Returns ExtraInfo Object which only contains the specified indexes + + Args: + index (int, slice, ndarray, tensor): indexes which the return should contain + copy (bool): wether content of ExtraInfo object should be copied + """ + self.parse() + info = ExtraInfo(self._n_envs, self._array_backend.get_backend_name(), self._device) + info._key_mapping = self._key_mapping + info._shape_mapping = self._shape_mapping + + if not copy: + info._structured_storage = {key: value[index, ...] for key, value in self._structured_storage.items()} + info.data = {key: value[index, ...] for key, value in self.data.items()} + else: + for key, value in self._structured_storage.items(): + value = value[index, ...] + info._structured_storage[key] = self._array_backend.empty(value.shape, self._device) + info._structured_storage[key][:] = value + + for key, value in self.data.items(): + value = value[index, ...] + info.data[key] = self._array_backend.empty(value.shape, self._device) + info.data[key][:] = value + + return info + + def clear(self): + self._storage = [] + self._key_mapping = {} + self._shape_mapping = {} + self._structured_storage = {} + self.data = {} + + def _add_all_save_attr(self): + self._add_save_attr( + data='primitive', + _storage='primitive', + _structured_storage='primitive', + _key_mapping='primitive', + _shape_mapping='primitive' + ) + + def _update_key_mapping(self, template, single_env): + """ + Update the pattern and the key_paths with the keys from the given template + + Args: + template (dict): Dictionary to extract the keys from + single_env (bool): Wether template contains data for only one environment + """ + assert(isinstance(template, dict)) + + # Stack to store dictionaries and their parent key + stack = [(template, [])] + + while stack: + structure_element, parent_keys = stack.pop() + assert isinstance(structure_element, dict) + + #Iterate over the dict + for key, value in structure_element.items(): + key_path = parent_keys + [key] + + # skip if key is already in key_mapping + if key_path in self._key_mapping.values(): + continue + + if isinstance(value, dict): + stack.append((value, key_path)) + else: + new_key = self._create_key(key_path) + self._store_array_shape(new_key, value, single_env) + + def _append_dict_to_output(self, output, step_data, index, to): + """ + Append a dictionary to the output arrays. + + Args: + output (dict): Flat dictionary containing the arrays + step_data (dict): Containing the step information for one step + index (int): index of the step + to (str): Target format + """ + for key, key_path in self._key_mapping.items(): + value = self._find_element_by_key_path(step_data, key_path) + value = self._convert(value, to) + output[key][index] = value + + def _append_list_to_output(self, output, step_data, index, to): + """ + Append a list to the output arrays. + + Args: + output (list): Flat dictionary containing the arrays + step_data (dict): List containing the step information in form of a dictionary for every environment + index (int): index of the step + to (str): Target format + """ + assert(self._n_envs > 1) + for key, key_path in self._key_mapping.items(): + for i, env_data in enumerate(step_data): + value = self._find_element_by_key_path(env_data, key_path) + value = self._convert(value, to) + output[key][index][i] = value + + def _find_element_by_key_path(self, source, key_path): + """ + Find the value in source corresponding to the key path. + + Args: + source (dict): Dictionary to search in. + key_path (list): List of keys. + + Returns: + The found value or None if any key is missing. + """ + current = source + for key in key_path: + if key in current: + current = current[key] + else: + return None + return current + + def _convert(self, value, to): + """ + Convert value to the target format. + + Args: + value: Value to convert. + to (str): Target format, 'torch' or 'numpy'. + + Returns: + Converted value. + """ + if isinstance(value, numbers.Number): + return value + + if value is None: + return ArrayBackend.get_array_backend(to).none() + + return ArrayBackend.convert(value, to=to, backend=self._array_backend) + + def _create_key(self, key_path): + """ + Creates single key in pattern from a list of keys. + + Args: + key_path (list): List of keys to combine. + + Returns: + key (str): Created key. + """ + key = "_".join(str(key) for key in key_path) + self._key_mapping[key] = key_path + return key + + def _store_array_shape(self, key, value, single_env): + """ + Stores the shape of the value. If value does not have a shape, an empty tuple is stored. + + Args: + key (str): Dictionary key. + value (Array, Number): Variable whose shape should be saved + sinlge_env (bool): + """ + if isinstance(value, numbers.Number): + self._shape_mapping[key] = () + else: + shape = self._array_backend.shape(value) + self._shape_mapping[key] = shape[1:] if not single_env else shape + + @property + def n_envs(self): + return self._n_envs + + def __setitem__(self, key, value): + raise TypeError("This dictionary is read-only.") + + def __delitem__(self, key): + raise TypeError("This dictionary is read-only.") + + def pop(self, key, default=None): + raise TypeError("This dictionary is read-only.") + + def popitem(self): + raise TypeError("This dictionary is read-only.") + + def setdefault(self, key, default=None): + raise TypeError("This dictionary is read-only.") + + def update(self, *args, **kwargs): + raise TypeError("This dictionary is read-only.") \ No newline at end of file diff --git a/mushroom_rl/core/multiprocess_environment.py b/mushroom_rl/core/multiprocess_environment.py index b8593d8b..6f566851 100644 --- a/mushroom_rl/core/multiprocess_environment.py +++ b/mushroom_rl/core/multiprocess_environment.py @@ -66,6 +66,7 @@ def __init__(self, env_class, *args, n_envs=-1, use_generator=False, **kwargs): **kwargs: keyword arguments to set to the constructor or to the generator; """ + assert env_class is not None, "Environment class requires not installed module." assert n_envs > 1 or n_envs == -1 if n_envs == -1: @@ -107,7 +108,7 @@ def reset_all(self, env_mask, state=None): else: episode_infos.append({}) - return self._states, episode_infos + return self._states.copy(), episode_infos.copy() def step_all(self, env_mask, action): for i, remote in enumerate(self._remotes): @@ -129,7 +130,7 @@ def step_all(self, env_mask, action): else: step_infos.append({}) - return self._states.copy(), rewards, absorbings, step_infos + return self._states.copy(), rewards.copy(), absorbings.copy(), step_infos.copy() def render_all(self, env_mask, record=False): for i, remote in enumerate(self._remotes): diff --git a/mushroom_rl/environments/isaac_env.py b/mushroom_rl/environments/isaac_env.py index 6215090c..ec94f158 100644 --- a/mushroom_rl/environments/isaac_env.py +++ b/mushroom_rl/environments/isaac_env.py @@ -90,8 +90,9 @@ def reset_all(self, env_mask, state=None): self._task.reset_idx(idxs) # self._world.step(render=self._render) # TODO Check if we can do otherwise task_obs = self._task.get_observations() + task_extras = self._task.get_extras() observation = convert_task_observation(task_obs) - return observation, [{}]*self._n_envs + return observation.clone(), [task_extras]*self._n_envs def step_all(self, env_mask, action): self._task.pre_physics_step(action) @@ -106,7 +107,7 @@ def step_all(self, env_mask, action): env_mask_cuda = torch.as_tensor(env_mask).cuda() - return observation, reward, torch.logical_and(done, env_mask_cuda), [info]*self._n_envs + return observation.clone(), reward, torch.logical_and(done, env_mask_cuda), [info]*self._n_envs def render_all(self, env_mask, record=False): self._world.render() @@ -132,3 +133,11 @@ def _convert_gym_space(space): return Box(low=space.low, high=space.high, shape=space.shape) else: raise ValueError + + @property + def world(self): + return self._world + + @property + def render_enabled(self): + return self._render \ No newline at end of file diff --git a/tests/core/test_dataset.py b/tests/core/test_dataset.py index e55a8ac1..c99c6bfa 100644 --- a/tests/core/test_dataset.py +++ b/tests/core/test_dataset.py @@ -126,4 +126,8 @@ def test_dataset_loading(tmpdir): assert dataset._dataset_info.gamma == new_dataset._dataset_info.gamma + assert len(dataset.info) == len(new_dataset.info) + for key in dataset.info: + assert np.array_equal(dataset.info[key], new_dataset.info[key]) + diff --git a/tests/core/test_extra_info.py b/tests/core/test_extra_info.py new file mode 100644 index 00000000..51c22b25 --- /dev/null +++ b/tests/core/test_extra_info.py @@ -0,0 +1,453 @@ +from mushroom_rl.core import ExtraInfo +import torch +import numpy as np + +def test_list_of_dict(): + info = ExtraInfo(6, 'numpy') + + data = [] + for i in range(6): + single_step_data = { + 'prop1': 100 + i, + 'prop2': np.arange(300 + i, 300 + i + 0.5, 0.1), + 'prop3': { + 'x': 400 + i, + 'y': 500 + i + } + } + data.append(single_step_data) + + data2 = [] + for i in range(6): + single_step_data = { + 'prop1': 110 + i, + 'prop2': np.arange(310 + i, 310 + i + 0.5, 0.1), + 'prop3': { + 'x': 410 + i, + 'y': 510 + i + } + } + data2.append(single_step_data) + + info.append(data) + info.append(data2) + + info.parse(to='torch') + + assert(len(info) == 4) + assert(torch.is_tensor(info["prop1"])) + assert(torch.is_tensor(info["prop2"])) + assert(torch.is_tensor(info["prop3_x"])) + assert(torch.is_tensor(info["prop3_y"])) + assert(info["prop1"].dim() == 2 and info["prop1"].size(0) == 2 and info["prop1"].size(1) == 6) + assert(info["prop2"].dim() == 3 and info["prop2"].size(0) == 2 and info["prop2"].size(1) == 6 and info["prop2"].size(2) == 5) + assert(info["prop3_x"].dim() == 2 and info["prop3_x"].size(0) == 2 and info["prop3_x"].size(1) == 6) + assert(info["prop3_y"].dim() == 2 and info["prop3_y"].size(0) == 2 and info["prop3_y"].size(1) == 6) + + info = info.flatten() + + assert(len(info) == 4) + assert(torch.is_tensor(info["prop1"])) + assert(torch.is_tensor(info["prop2"])) + assert(torch.is_tensor(info["prop3_x"])) + assert(torch.is_tensor(info["prop3_y"])) + assert(info["prop1"].dim() == 1 and info["prop1"].size(0) == 12) + assert(info["prop2"].dim() == 2 and info["prop2"].size(0) == 12 and info["prop2"].size(1) == 5) + assert(info["prop3_x"].dim() == 1 and info["prop3_x"].size(0) == 12) + assert(info["prop3_y"].dim() == 1 and info["prop3_y"].size(0) == 12) + + prop1 = torch.tensor([100, 110, 101, 111, 102, 112, 103, 113, 104, 114, 105, 115]) + prop3_x = torch.tensor([400, 410, 401, 411, 402, 412, 403, 413, 404, 414, 405, 415]) + prop3_y = torch.tensor([500, 510, 501, 511, 502, 512, 503, 513, 504, 514, 505, 515]) + assert torch.equal(prop1, info["prop1"]) + assert torch.equal(prop3_x, info["prop3_x"]) + assert torch.equal(prop3_y, info["prop3_y"]) + + info.parse(to='torch') + + assert(len(info) == 4) + assert(torch.is_tensor(info["prop1"])) + assert(torch.is_tensor(info["prop2"])) + assert(torch.is_tensor(info["prop3_x"])) + assert(torch.is_tensor(info["prop3_y"])) + assert(info["prop1"].dim() == 1 and info["prop1"].size(0) == 12) + assert(info["prop2"].dim() == 2 and info["prop2"].size(0) == 12 and info["prop2"].size(1) == 5) + assert(info["prop3_x"].dim() == 1 and info["prop3_x"].size(0) == 12) + assert(info["prop3_y"].dim() == 1 and info["prop3_y"].size(0) == 12) + +def test_dict_of_torch(): + info = ExtraInfo(4, 'torch') + data1 = { + 'prop1': torch.arange(100, 104), + 'prop2': torch.tensor([[200.0, 200.5], [201.0, 201.5], [202.0, 202.5], [203.0, 203.5]]), + 'prop3': { + 'x': torch.arange(300, 304) + } + } + data2 = { + 'prop1': torch.arange(110, 114), + 'prop2': torch.tensor([[210.0, 210.5], [211.0, 211.5], [212.0, 212.5], [213.0, 213.5]]), + 'prop3': { + 'x': torch.arange(310, 314) + } + } + info.append(data1) + info.append(data2) + + info.parse(to='numpy') + + assert(len(info) == 3) + assert(isinstance(info["prop1"], np.ndarray)) + assert(isinstance(info["prop2"], np.ndarray)) + assert(isinstance(info["prop3_x"], np.ndarray)) + assert(info["prop1"].ndim == 2 and info["prop1"].shape[0] == 2 and info["prop1"].shape[1] == 4) + assert(info["prop2"].ndim == 3 and info["prop2"].shape[0] == 2 and info["prop2"].shape[1] == 4 and info["prop2"].shape[2] == 2) + assert(info["prop3_x"].ndim == 2 and info["prop3_x"].shape[0] == 2 and info["prop3_x"].shape[1] == 4) + + info = info.flatten() + + assert(len(info) == 3) + assert(isinstance(info["prop1"], np.ndarray)) + assert(isinstance(info["prop2"], np.ndarray)) + assert(isinstance(info["prop3_x"], np.ndarray)) + assert(info["prop1"].ndim == 1 and info["prop1"].shape[0] == 8) + assert(info["prop2"].ndim == 2 and info["prop2"].shape[0] == 8 and info["prop2"].shape[1] == 2) + assert(info["prop3_x"].ndim == 1 and info["prop3_x"].shape[0] == 8) + + assert np.array_equal(np.array([100, 110, 101, 111, 102, 112, 103, 113]), info["prop1"]) + prop2 = np.array([[200.0, 200.5], [210.0, 210.5], [201.0, 201.5], [211.0, 211.5], + [202.0, 202.5], [212.0, 212.5], [203.0, 203.5], [213.0, 213.5]]) + assert np.array_equal(prop2, info["prop2"]) + assert np.array_equal(np.array([300, 310, 301, 311, 302, 312, 303, 313]), info["prop3_x"]) + + info.parse() + + assert(len(info) == 3) + assert(isinstance(info["prop1"], np.ndarray)) + assert(isinstance(info["prop2"], np.ndarray)) + assert(isinstance(info["prop3_x"], np.ndarray)) + assert(info["prop1"].ndim == 1 and info["prop1"].shape[0] == 8) + assert(info["prop2"].ndim == 2 and info["prop2"].shape[0] == 8 and info["prop2"].shape[1] == 2) + assert(info["prop3_x"].ndim == 1 and info["prop3_x"].shape[0] == 8) + +def test_empty_dict_in_list(): + info = ExtraInfo(3, 'torch') + + data1 = { + 'prop1': 100, + 'prop2': 200 + } + data2 = {} + data3 = { + 'prop1': 102, + 'prop2': 202 + } + info.append([data1, data2, data3]) + info = info.flatten() + print(info) + assert(len(info) == 2) + + assert("prop1" in info) + assert("prop2" in info) + + assert(torch.is_tensor(info["prop1"])) + assert(torch.is_tensor(info["prop2"])) + + assert(info["prop1"].dim() == 1 and info["prop1"].size(0) == 3) + assert(info["prop2"].dim() == 1 and info["prop2"].size(0) == 3) + + assert(info["prop1"][0] == 100 and info["prop2"][0] == 200) + assert(torch.isnan(info["prop1"][1]) and torch.isnan(info["prop2"][1])) + assert(info["prop1"][2] == 102 and info["prop2"][2] == 202) + +def test_empty_dict(): + info = ExtraInfo(2, 'numpy') + data1 = { + 'prop1': np.arange(100, 102) + } + data2 = {} + data3 = { + 'prop1': np.arange(120, 122) + } + info.append(data1) + info.append(data2) + info.append(data3) + info = info.flatten() + print(info) + + assert(len(info) == 1) + assert("prop1" in info) + assert(isinstance(info["prop1"], np.ndarray)) + assert(info["prop1"].ndim == 1 and info["prop1"].shape[0] == 6) + + assert info["prop1"][0] == 100 + assert np.isnan(info["prop1"][1]) + assert info["prop1"][2] == 120 + assert info["prop1"][3] == 101 + assert np.isnan(info["prop1"][4]) + assert info["prop1"][5] == 121 + +def test_changing_properties_dict(): + info = ExtraInfo(2, 'numpy') + data1 = { + 'prop2': np.arange(200, 202), + 'prop3': np.arange(300, 302) + } + data2 = { + 'prop2': np.arange(210, 212), + 'prop4': np.arange(410, 412) + } + data3 = { + 'prop2': np.arange(220, 222), + 'prop3': np.arange(320, 322) + } + info.append(data1) + info.append(data2) + info.append(data3) + info.parse(to='torch') + info = info.flatten() + + print(info) + + assert(len(info) == 3) + + assert("prop2" in info) + assert("prop3" in info) + assert("prop4" in info) + + assert(torch.is_tensor(info["prop2"])) + assert(torch.is_tensor(info["prop3"])) + assert(torch.is_tensor(info["prop4"])) + + assert(info["prop2"].dim() == 1 and info["prop2"].size(0) == 6) + assert(info["prop3"].dim() == 1 and info["prop3"].size(0) == 6) + assert(info["prop4"].dim() == 1 and info["prop4"].size(0) == 6) + + assert info["prop2"][0] == 200 and info["prop3"][0] == 300 and torch.isnan(info["prop4"][0]) + assert info["prop2"][1] == 210 and torch.isnan(info["prop3"][1]) and info["prop4"][1] == 410 + assert info["prop2"][2] == 220 and info["prop3"][2] == 320 and torch.isnan(info["prop4"][2]) + assert info["prop2"][3] == 201 and info["prop3"][3] == 301 and torch.isnan(info["prop4"][3]) + assert info["prop2"][4] == 211 and torch.isnan(info["prop3"][4]) and info["prop4"][4] == 411 + assert info["prop2"][5] == 221 and info["prop3"][5] == 321 and torch.isnan(info["prop4"][5]) + +def test_one_environment(): + info = ExtraInfo(1, 'torch') + data1 = { + 'prop1': torch.arange(100, 103), + 'prop2': torch.randn(3, 2), + 'prop3': 1 + } + data2 = { + 'prop1': torch.arange(110, 113), + 'prop2': torch.randn(3, 2), + 'prop3': 2 + } + data3 = { + 'prop1': torch.arange(120, 123), + 'prop2': torch.randn(3, 2), + 'prop3': 3 + } + info.append(data1) + info.append(data2) + info.append(data3) + info.parse('torch') + print(info) + + assert(len(info) == 3) + + assert("prop1" in info) + assert("prop2" in info) + assert("prop3" in info) + + assert(torch.is_tensor(info["prop1"])) + assert(torch.is_tensor(info["prop2"])) + assert(torch.is_tensor(info["prop3"])) + + assert(info["prop1"].dim() == 2 and info["prop1"].size(0) == 3 and info["prop2"].size(1) == 3) + assert(info["prop2"].dim() == 3 and info["prop2"].size(0) == 3 and info["prop2"].size(1) == 3 and info["prop2"].size(2) == 2) + assert(info["prop3"].dim() == 1 and info["prop3"].size(0) == 3) + +def test_get_view_slice(): + info = ExtraInfo(3, 'torch') + data1 = { + 'prop1': torch.arange(100, 103), + 'prop3': torch.randn(3, 2) + } + data2 = { + 'prop1': torch.arange(110, 113), + 'prop3': torch.randn(3, 2) + } + + info.append(data1) + info.append(data2) + + info = info.flatten() + info = info.get_view(slice(4)) + info.parse('torch') + + assert(len(info) == 2) + + assert("prop1" in info) + assert("prop3" in info) + + assert(torch.is_tensor(info["prop1"])) + assert(torch.is_tensor(info["prop3"])) + + assert(info["prop1"].dim() == 1 and info["prop1"].size(0) == 4) + assert(info["prop3"].dim() == 2 and info["prop3"].size(0) == 4 and info["prop3"].size(1) == 2) + + assert(info["prop1"][0] == 100) + assert(info["prop1"][1] == 110) + assert(info["prop1"][2] == 101) + assert(info["prop1"][3] == 111) + +def test_get_view_array(): + info = ExtraInfo(3, 'torch') + data1 = { + 'prop1': torch.arange(100, 103), + 'prop3': torch.randn(3, 2) + } + data2 = { + 'prop1': torch.arange(110, 113), + 'prop3': torch.randn(3, 2) + } + + info.append(data1) + info.append(data2) + + info = info.flatten() + info = info.get_view(np.array([1, 2, 5]), True) + info.parse('torch') + print(info) + + assert(len(info) == 2) + + assert("prop1" in info) + assert("prop3" in info) + + assert(torch.is_tensor(info["prop1"])) + assert(torch.is_tensor(info["prop3"])) + + assert(info["prop1"].dim() == 1 and info["prop1"].size(0) == 3) + assert(info["prop3"].dim() == 2 and info["prop3"].size(0) == 3 and info["prop3"].size(1) == 2) + + assert(info["prop1"][0] == 110) + assert(info["prop1"][1] == 101) + assert(info["prop1"][2] == 112) + +def test_add(): + info1 = ExtraInfo(10, 'numpy') + data1 = { + 'prop1': np.arange(100, 110), + 'prop2': np.arange(200, 210) + } + data2 = { + 'prop1': np.arange(110, 120), + 'prop2': np.arange(210, 220) + } + info1.append(data1) + info1.append(data2) + + info2 = ExtraInfo(10, 'torch') + data1 = { + 'prop1': torch.arange(100, 110, dtype=torch.float32), + 'prop3': torch.arange(300, 310, dtype=torch.float32) + } + data2 = { + 'prop1': torch.arange(110, 120), + 'prop3': torch.arange(310, 320) + } + info2.append(data1) + info2.append(data2) + + info1.parse('torch') + info2.parse('numpy') + + info = info1 + info2 + + assert(len(info) == 3) + + assert("prop1" in info) + assert("prop2" in info) + assert("prop3" in info) + + assert(torch.is_tensor(info["prop1"])) + assert(torch.is_tensor(info["prop2"])) + assert(torch.is_tensor(info["prop3"])) + + assert(info["prop1"].dim() == 2 and info["prop1"].size(0) == 4 and info["prop1"].size(1) == 10) + assert(info["prop2"].dim() == 2 and info["prop2"].size(0) == 4 and info["prop2"].size(1) == 10) + assert(info["prop3"].dim() == 2 and info["prop3"].size(0) == 4 and info["prop3"].size(1) == 10) + + for i in range(2): + for j in range(10): + assert(info["prop1"][i][j] == 100 + i*10 + j) + assert(info["prop2"][i][j] == 200 + i*10 + j) + assert(torch.isnan(info["prop3"][i][j])) + + for i in range(2): + for j in range(10): + assert(info["prop1"][2 + i][j] == 100 + i*10 + j) + assert(torch.isnan(info["prop2"][2 + i][j])) + assert(info["prop3"][2 + i][j] == 300 + i*10 + j) + +def test_clear(): + info = ExtraInfo(10, 'numpy') + data1 = { + 'prop1': np.arange(100, 110), + 'prop2': np.arange(200, 210) + } + data2 = { + 'prop1': np.arange(110, 120), + 'prop2': np.arange(210, 220) + } + info.append(data1) + info.append(data2) + info.parse() + info.clear() + assert(not info) + +def test_flatten_with_mask(): + info = ExtraInfo(5, 'numpy') + data1 = { + 'prop1': np.arange(100, 105), + 'prop2': np.arange(200, 205) + } + data2 = { + 'prop1': np.arange(110, 115), + 'prop2': np.arange(210, 215) + } + info.append(data1) + info.append(data2) + mask = np.array([True, True, False, False, False, True, False, False, True, False]) + info = info.flatten(mask) + + assert(len(info) == 2) + + assert("prop1" in info) + assert("prop2" in info) + + assert(isinstance(info["prop1"], np.ndarray)) + assert(isinstance(info["prop2"], np.ndarray)) + + assert(info["prop1"].ndim == 1 and info["prop1"].shape[0] == 4) + assert(info["prop2"].ndim == 1 and info["prop2"].shape[0] == 4) + + assert np.array_equal(np.array([100, 110, 112, 104]), info["prop1"]) + assert np.array_equal(np.array([200, 210, 212, 204]), info["prop2"]) + + #Test if mask is permantly applied + info.parse() + assert(len(info) == 2) + + assert("prop1" in info) + assert("prop2" in info) + + assert(isinstance(info["prop1"], np.ndarray)) + assert(isinstance(info["prop2"], np.ndarray)) + + assert(info["prop1"].ndim == 1 and info["prop1"].shape[0] == 4) + assert(info["prop2"].ndim == 1 and info["prop2"].shape[0] == 4) + + assert np.array_equal(np.array([100, 110, 112, 104]), info["prop1"]) + assert np.array_equal(np.array([200, 210, 212, 204]), info["prop2"])