diff --git a/test/test_replay_buffer.py b/test/test_replay_buffer.py index e510f5e..1cb8b30 100644 --- a/test/test_replay_buffer.py +++ b/test/test_replay_buffer.py @@ -5,109 +5,19 @@ from copy import deepcopy from itertools import product from pprint import pformat -from typing import Any, Literal +from typing import Literal import pytest import torch from utils import ( - are_equivalent_sequences, + are_equivalent, + create_memory, + create_shapes, + element_to_bulk_dim_swap, generate_random_dict_data, generate_random_flat_data, ) -from wingman.replay_buffer import FlatReplayBuffer -from wingman.replay_buffer.core import ReplayBuffer -from wingman.replay_buffer.wrappers.dict_wrapper import DictReplayBufferWrapper - - -def create_shapes( - use_dict: bool, bulk_size: int = 0 -) -> list[tuple[int, ...] | dict[str, Any]]: - """create_shapes. - - Args: - ---- - use_dict (bool): use_dict - bulk_size (int): bulk_size - - Returns: - ------- - list[tuple[int, ...] | dict[str, Any]]: - - """ - if bulk_size: - bulk_shape = (bulk_size,) - else: - bulk_shape = () - - if use_dict: - return [ - (*bulk_shape, 3, 3), - ( - *bulk_shape, - 3, - ), - (*bulk_shape,), - { - "a": (*bulk_shape, 4, 3), - "b": (*bulk_shape,), - "c": { - "d": (*bulk_shape, 11, 2), - }, - }, - { - "e": (*bulk_shape, 3, 2), - }, - (*bulk_shape, 4), - ] - else: - return [ - (*bulk_shape, 3, 3), - (*bulk_shape, 3), - (*bulk_shape,), - ] - - -def create_memory( - mem_size: int, - mode: Literal["numpy", "torch"], - device: torch.device, - store_on_device: bool, - random_rollover: bool, - use_dict: bool, -) -> ReplayBuffer: - """create_memory. - - Args: - ---- - mem_size (int): mem_size - mode (Literal["numpy", "torch"]): mode - device (torch.device): device - store_on_device (bool): store_on_device - random_rollover (bool): random_rollover - use_dict (bool): use_dict - - Returns: - ------- - ReplayBuffer: - - """ - memory = FlatReplayBuffer( - mem_size=mem_size, - mode=mode, - device=device, - store_on_device=store_on_device, - random_rollover=random_rollover, - ) - - if use_dict: - memory = DictReplayBufferWrapper( - replay_buffer=memory, - ) - - return memory - - # define the test configurations _random_rollovers = [True, False] _modes = ["numpy", "torch"] @@ -116,12 +26,14 @@ def create_memory( _devices.append(torch.device("cuda:0")) _store_on_devices = [True, False] _use_dict = [False, True] -ALL_CONFIGURATIONS = product( - _random_rollovers, - _modes, - _devices, - _store_on_devices, - _use_dict, +ALL_CONFIGURATIONS = list( + product( + _random_rollovers, + _modes, + _devices, + _store_on_devices, + _use_dict, + ) ) @@ -165,16 +77,21 @@ def test_bulk( raise ValueError memory.push(data, bulk=True) - # reverse the data to make indexing for checking easier - reversed_data = [list(item) for item in zip(*data)] + # reverse the data + # on insertion we had [element_dim, bulk_dim, *data_shapes] + # on comparison we want [bulk_dim, element_dim, *data_shapes] + serialized_data = element_to_bulk_dim_swap( + element_first_data=data, + bulk_size=bulk_size, + ) # if random rollover and we're more than full, different matching technique if random_rollover and memory.is_full: num_matches = 0 # match according to meshgrid - for item1 in reversed_data: + for item1 in serialized_data: for item2 in memory: - num_matches += int(are_equivalent_sequences(item1, item2)) + num_matches += int(are_equivalent(item1, item2)) assert ( num_matches == bulk_size @@ -183,9 +100,9 @@ def test_bulk( continue for step in range(bulk_size): - item1 = reversed_data[step] + item1 = serialized_data[step] item2 = memory[(iteration * bulk_size + step) % mem_size] - assert are_equivalent_sequences( + assert are_equivalent( item1, item2 ), f"""Something went wrong with rollover at iteration {iteration}, step {step}, expected \n{pformat(item1)}, got \n{pformat(item2)}.""" @@ -231,10 +148,8 @@ def test_non_bulk( num_current_matches = 0 num_previous_matches = 0 for item in memory: - num_current_matches += int(are_equivalent_sequences(item, current_data)) - num_previous_matches += int( - are_equivalent_sequences(item, previous_data) - ) + num_current_matches += int(are_equivalent(item, current_data)) + num_previous_matches += int(are_equivalent(item, previous_data)) assert ( num_current_matches == 1 @@ -247,7 +162,7 @@ def test_non_bulk( # check the current data output = memory.__getitem__(iteration % mem_size) - assert are_equivalent_sequences( + assert are_equivalent( output, current_data ), f"""Something went wrong with rollover at iteration {iteration}, expected \n{pformat(current_data)}, got \n{pformat(output)}.""" @@ -255,7 +170,7 @@ def test_non_bulk( # check the previous data if iteration > 0: output = memory[(iteration - 1) % mem_size] - assert are_equivalent_sequences( + assert are_equivalent( output, previous_data ), f"""Something went wrong with rollover at iteration {iteration}, expected \n{pformat(previous_data)}, got \n{pformat(output)}.""" diff --git a/test/utils.py b/test/utils.py index 03350b6..ebeb317 100644 --- a/test/utils.py +++ b/test/utils.py @@ -7,57 +7,97 @@ import numpy as np import torch +from wingman.replay_buffer import FlatReplayBuffer +from wingman.replay_buffer.core import ReplayBuffer +from wingman.replay_buffer.wrappers.dict_wrapper import DictReplayBufferWrapper -def _cast(array: np.ndarray | torch.Tensor | float | int) -> np.ndarray: - """_cast. + +def create_shapes( + use_dict: bool, bulk_size: int = 0 +) -> list[tuple[int, ...] | dict[str, Any]]: + """create_shapes. Args: ---- - array (np.ndarray | torch.Tensor | float | int): array + use_dict (bool): use_dict + bulk_size (int): bulk_size Returns: ------- - np.ndarray: + list[tuple[int, ...] | dict[str, Any]]: """ - if isinstance(array, np.ndarray): - return array - elif isinstance(array, torch.Tensor): - return array.cpu().numpy() # pyright: ignore[reportAttributeAccessIssue] + if bulk_size: + bulk_shape = (bulk_size,) else: - return np.asarray(array) + bulk_shape = () + if use_dict: + return [ + (*bulk_shape, 3, 3), + ( + *bulk_shape, + 3, + ), + (*bulk_shape,), + { + "a": (*bulk_shape, 4, 3), + "b": (*bulk_shape,), + "c": { + "d": (*bulk_shape, 11, 2), + }, + }, + { + "e": (*bulk_shape, 3, 2), + }, + (*bulk_shape, 4), + ] + else: + return [ + (*bulk_shape, 3, 3), + (*bulk_shape, 3), + (*bulk_shape,), + ] -def are_equivalent_sequences( - item1: Any, - item2: Any, -): - """Check if two pieces of data are equivalent. + +def create_memory( + mem_size: int, + mode: Literal["numpy", "torch"], + device: torch.device, + store_on_device: bool, + random_rollover: bool, + use_dict: bool, +) -> ReplayBuffer: + """create_memory. Args: ---- - item1 (Any): item1 - item2 (Any): item2 + mem_size (int): mem_size + mode (Literal["numpy", "torch"]): mode + device (torch.device): device + store_on_device (bool): store_on_device + random_rollover (bool): random_rollover + use_dict (bool): use_dict - """ - # comparison for array-able types - if isinstance(item1, (int, float, bool, torch.Tensor, np.ndarray)) or item1 is None: - return np.isclose(_cast(item1), _cast(item2)).all() + Returns: + ------- + ReplayBuffer: - # comparison for lists and tuples - if isinstance(item1, (list, tuple)) and isinstance(item2, (list, tuple)): - return len(item1) == len(item2) and all( - are_equivalent_sequences(d1, d2) for d1, d2 in zip(item1, item2) - ) + """ + memory = FlatReplayBuffer( + mem_size=mem_size, + mode=mode, + device=device, + store_on_device=store_on_device, + random_rollover=random_rollover, + ) - # comparison for dictionaries - if isinstance(item1, dict) and isinstance(item2, dict): - return item1.keys() != item2.keys() and all( - are_equivalent_sequences(item1[key], item2[key]) for key in item1 + if use_dict: + memory = DictReplayBufferWrapper( + replay_buffer=memory, ) - # non of the checks passed - return False + return memory def generate_random_flat_data( @@ -109,3 +149,121 @@ def generate_random_dict_data( data[key] = generate_random_flat_data(shape=val, mode=mode) return data + + +def _dict_element_to_bulk_dim_swap( + data_dict: dict[str, Any], + bulk_size: int, +) -> list[dict[str, Any]]: + """Given a nested dictionary where each leaf is an n-long array, returns an n-long sequence where each item is the same nested dictionary structure. + + Args: + ---- + data_dict (dict[str, Any]): data_dict + bulk_size (int): bulk_size + + Returns: + ------- + list[dict[str, Any]]: + + """ + bulk_first_dicts: list[dict[str, Any]] = [dict() for _ in range(bulk_size)] + for key, value in data_dict.items(): + if isinstance(value, dict): + for i, element in enumerate( + _dict_element_to_bulk_dim_swap(data_dict=value, bulk_size=bulk_size) + ): + bulk_first_dicts[i][key] = element + else: + for i, element in enumerate(value): + bulk_first_dicts[i][key] = element + + return bulk_first_dicts + + +def element_to_bulk_dim_swap( + element_first_data: list[Any], + bulk_size: int, +) -> list[Any]: + """Given a tuple of elements, each with `bulk_size` items, returns a `bulk_size` sequence, with each item being the size of the tuple. + + Args: + ---- + element_first_data (list[Any]): element_first_data + bulk_size (int): bulk_size + + Returns: + ------- + list[Any]: + + """ + bulk_first_data = [[] for _ in range(bulk_size)] + for element in element_first_data: + # if not a dictionary, can do a plain axis extract + if not isinstance(element, dict): + for i in range(bulk_size): + bulk_first_data[i].append(element[i]) + + # if it's a dictionary, then we need to unpack the dictionary into each item + else: + for i, dict_element in enumerate( + _dict_element_to_bulk_dim_swap( + data_dict=element, + bulk_size=bulk_size, + ) + ): + bulk_first_data[i].append(dict_element) + + return bulk_first_data + + +def _cast(array: np.ndarray | torch.Tensor | float | int) -> np.ndarray: + """_cast. + + Args: + ---- + array (np.ndarray | torch.Tensor | float | int): array + + Returns: + ------- + np.ndarray: + + """ + if isinstance(array, np.ndarray): + return array + elif isinstance(array, torch.Tensor): + return array.cpu().numpy() # pyright: ignore[reportAttributeAccessIssue] + else: + return np.asarray(array) + + +def are_equivalent( + item1: Any, + item2: Any, +): + """Check if two pieces of data are equivalent. + + Args: + ---- + item1 (Any): item1 + item2 (Any): item2 + + """ + # comparison for array-able types + if isinstance(item1, (int, float, bool, torch.Tensor, np.ndarray)) or item1 is None: + return np.isclose(_cast(item1), _cast(item2)).all() + + # comparison for lists and tuples + if isinstance(item1, (list, tuple)) and isinstance(item2, (list, tuple)): + return len(item1) == len(item2) and all( + are_equivalent(d1, d2) for d1, d2 in zip(item1, item2) + ) + + # comparison for dictionaries + if isinstance(item1, dict) and isinstance(item2, dict): + return item1.keys() == item2.keys() and all( + are_equivalent(item1[key], item2[key]) for key in item1.keys() + ) + + # non of the checks passed + return False diff --git a/wingman/replay_buffer/core.py b/wingman/replay_buffer/core.py index 651640c..f0e3849 100644 --- a/wingman/replay_buffer/core.py +++ b/wingman/replay_buffer/core.py @@ -46,20 +46,6 @@ def __repr__(self) -> str: {self.memory} """ - def __getitem__(self, idx: int) -> Sequence[Any]: - """__getitem__. - - Args: - ---- - idx (int): idx - - Returns: - ------- - Sequence[Any]: - - """ - raise NotImplementedError - @property def is_full(self) -> bool: """Whether or not the replay buffer has reached capacity. @@ -90,6 +76,21 @@ def iter_sample( for _ in range(num_iter): yield (self.sample(batch_size=batch_size)) + @abstractmethod + def sample(self, batch_size: int) -> Sequence[Any]: + """sample. + + Args: + ---- + batch_size (int): batch_size + + Returns: + ------- + Sequence[Any]: + + """ + raise NotImplementedError + @abstractmethod def push( self, @@ -111,12 +112,12 @@ def push( raise NotImplementedError @abstractmethod - def sample(self, batch_size: int) -> Sequence[Any]: - """sample. + def __getitem__(self, idx: int) -> Sequence[Any]: + """__getitem__. Args: ---- - batch_size (int): batch_size + idx (int): idx Returns: -------