
Commit

complete!
jjshoots committed Jul 3, 2024
1 parent ab746a6 commit 45deddd
Showing 3 changed files with 235 additions and 161 deletions.
141 changes: 28 additions & 113 deletions test/test_replay_buffer.py
@@ -5,109 +5,19 @@
from copy import deepcopy
from itertools import product
from pprint import pformat
from typing import Any, Literal
from typing import Literal

import pytest
import torch
from utils import (
    are_equivalent_sequences,
    are_equivalent,
    create_memory,
    create_shapes,
    element_to_bulk_dim_swap,
    generate_random_dict_data,
    generate_random_flat_data,
)

from wingman.replay_buffer import FlatReplayBuffer
from wingman.replay_buffer.core import ReplayBuffer
from wingman.replay_buffer.wrappers.dict_wrapper import DictReplayBufferWrapper


def create_shapes(
    use_dict: bool, bulk_size: int = 0
) -> list[tuple[int, ...] | dict[str, Any]]:
    """create_shapes.
    Args:
    ----
        use_dict (bool): use_dict
        bulk_size (int): bulk_size
    Returns:
    -------
        list[tuple[int, ...] | dict[str, Any]]:
    """
    if bulk_size:
        bulk_shape = (bulk_size,)
    else:
        bulk_shape = ()

    if use_dict:
        return [
            (*bulk_shape, 3, 3),
            (
                *bulk_shape,
                3,
            ),
            (*bulk_shape,),
            {
                "a": (*bulk_shape, 4, 3),
                "b": (*bulk_shape,),
                "c": {
                    "d": (*bulk_shape, 11, 2),
                },
            },
            {
                "e": (*bulk_shape, 3, 2),
            },
            (*bulk_shape, 4),
        ]
    else:
        return [
            (*bulk_shape, 3, 3),
            (*bulk_shape, 3),
            (*bulk_shape,),
        ]


def create_memory(
    mem_size: int,
    mode: Literal["numpy", "torch"],
    device: torch.device,
    store_on_device: bool,
    random_rollover: bool,
    use_dict: bool,
) -> ReplayBuffer:
    """create_memory.
    Args:
    ----
        mem_size (int): mem_size
        mode (Literal["numpy", "torch"]): mode
        device (torch.device): device
        store_on_device (bool): store_on_device
        random_rollover (bool): random_rollover
        use_dict (bool): use_dict
    Returns:
    -------
        ReplayBuffer:
    """
    memory = FlatReplayBuffer(
        mem_size=mem_size,
        mode=mode,
        device=device,
        store_on_device=store_on_device,
        random_rollover=random_rollover,
    )

    if use_dict:
        memory = DictReplayBufferWrapper(
            replay_buffer=memory,
        )

    return memory
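
As an aside (editorial note, not part of the diff): a minimal usage sketch of the helper above, assuming the push/indexing API that the tests below exercise; the buffer size and the numpy payload are illustrative only.

# hypothetical usage sketch, not taken from the commit
import numpy as np
import torch

buffer = create_memory(
    mem_size=100,
    mode="numpy",
    device=torch.device("cpu"),
    store_on_device=False,
    random_rollover=False,
    use_dict=True,
)

# one transition, shaped to mirror create_shapes(use_dict=True)
transition = [
    np.zeros((3, 3)),
    np.zeros((3,)),
    np.zeros(()),
    {"a": np.zeros((4, 3)), "b": np.zeros(()), "c": {"d": np.zeros((11, 2))}},
    {"e": np.zeros((3, 2))},
    np.zeros((4,)),
]
buffer.push(transition)
oldest = buffer[0]  # retrieve the transition that was just stored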


# define the test configurations
_random_rollovers = [True, False]
_modes = ["numpy", "torch"]
@@ -116,12 +26,14 @@ def create_memory(
_devices.append(torch.device("cuda:0"))
_store_on_devices = [True, False]
_use_dict = [False, True]
ALL_CONFIGURATIONS = product(
    _random_rollovers,
    _modes,
    _devices,
    _store_on_devices,
    _use_dict,
ALL_CONFIGURATIONS = list(
    product(
        _random_rollovers,
        _modes,
        _devices,
        _store_on_devices,
        _use_dict,
    )
)
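
A note on the change above (editorial, not from the commit): itertools.product returns a single-use iterator, so materializing it with list() lets ALL_CONFIGURATIONS be consumed more than once, e.g. by several parametrized tests. A sketch of how the configurations are presumably consumed; the decorator text is an assumption, not taken from this diff.

# hypothetical sketch: reusing ALL_CONFIGURATIONS across multiple tests
@pytest.mark.parametrize(
    "random_rollover, mode, device, store_on_device, use_dict",
    ALL_CONFIGURATIONS,
)
def test_something(random_rollover, mode, device, store_on_device, use_dict):
    ...

# a bare itertools.product would be exhausted after the first such decorator;
# list(...) makes the same configuration set reusable by every test below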


@@ -165,16 +77,21 @@ def test_bulk(
                raise ValueError
        memory.push(data, bulk=True)

        # reverse the data to make indexing for checking easier
        reversed_data = [list(item) for item in zip(*data)]
        # reverse the data
        # on insertion we had [element_dim, bulk_dim, *data_shapes]
        # on comparison we want [bulk_dim, element_dim, *data_shapes]
        serialized_data = element_to_bulk_dim_swap(
            element_first_data=data,
            bulk_size=bulk_size,
        )

        # if random rollover and we're more than full, different matching technique
        if random_rollover and memory.is_full:
            num_matches = 0
            # match according to meshgrid
            for item1 in reversed_data:
            for item1 in serialized_data:
                for item2 in memory:
                    num_matches += int(are_equivalent_sequences(item1, item2))
                    num_matches += int(are_equivalent(item1, item2))

            assert (
                num_matches == bulk_size
@@ -183,9 +100,9 @@
            continue

        for step in range(bulk_size):
            item1 = reversed_data[step]
            item1 = serialized_data[step]
            item2 = memory[(iteration * bulk_size + step) % mem_size]
            assert are_equivalent_sequences(
            assert are_equivalent(
                item1, item2
            ), f"""Something went wrong with rollover at iteration {iteration},
            step {step}, expected \n{pformat(item1)}, got \n{pformat(item2)}."""
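
The comment in the hunk above describes swapping the leading element dimension with the bulk dimension before comparison. A minimal sketch of what such a helper could look like (the real element_to_bulk_dim_swap lives in the test utils and is not shown in this diff):

# hypothetical sketch of an element-to-bulk dimension swap; the real helper
# in utils may differ, e.g. in how it slices dict-valued elements
def element_to_bulk_dim_swap_sketch(element_first_data, bulk_size):
    def take(item, b):
        # recurse into dict elements, otherwise take the b-th row
        if isinstance(item, dict):
            return {k: take(v, b) for k, v in item.items()}
        return item[b]

    # [element_dim, bulk_dim, *data_shapes] -> [bulk_dim, element_dim, *data_shapes]
    return [
        [take(element, b) for element in element_first_data]
        for b in range(bulk_size)
    ]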
@@ -231,10 +148,8 @@ def test_non_bulk(
            num_current_matches = 0
            num_previous_matches = 0
            for item in memory:
                num_current_matches += int(are_equivalent_sequences(item, current_data))
                num_previous_matches += int(
                    are_equivalent_sequences(item, previous_data)
                )
                num_current_matches += int(are_equivalent(item, current_data))
                num_previous_matches += int(are_equivalent(item, previous_data))

            assert (
                num_current_matches == 1
@@ -247,15 +162,15 @@

        # check the current data
        output = memory.__getitem__(iteration % mem_size)
        assert are_equivalent_sequences(
        assert are_equivalent(
            output, current_data
        ), f"""Something went wrong with rollover at iteration {iteration},
        expected \n{pformat(current_data)}, got \n{pformat(output)}."""

        # check the previous data
        if iteration > 0:
            output = memory[(iteration - 1) % mem_size]
            assert are_equivalent_sequences(
            assert are_equivalent(
                output, previous_data
            ), f"""Something went wrong with rollover at iteration {iteration},
            expected \n{pformat(previous_data)}, got \n{pformat(output)}."""
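
The are_equivalent helper that replaces are_equivalent_sequences throughout this commit also comes from the test utils and is not part of the diff. A plausible sketch, assuming it must compare nested dict transitions as well as flat sequences of arrays or tensors:

# hypothetical sketch of an equivalence check over nested transitions;
# the actual utils.are_equivalent may differ
import numpy as np
import torch

def are_equivalent_sketch(item1, item2) -> bool:
    # dicts: same keys, equivalent values
    if isinstance(item1, dict) and isinstance(item2, dict):
        return item1.keys() == item2.keys() and all(
            are_equivalent_sketch(item1[k], item2[k]) for k in item1
        )
    # sequences: same length, pairwise equivalent
    if isinstance(item1, (list, tuple)) and isinstance(item2, (list, tuple)):
        return len(item1) == len(item2) and all(
            are_equivalent_sketch(a, b) for a, b in zip(item1, item2)
        )
    # leaves: move tensors to numpy, then compare numerically
    if isinstance(item1, torch.Tensor):
        item1 = item1.cpu().numpy()
    if isinstance(item2, torch.Tensor):
        item2 = item2.cpu().numpy()
    return np.allclose(item1, item2)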

0 comments on commit 45deddd
