Skip to content

Commit

Permalink
fix problem
Browse files Browse the repository at this point in the history
  • Loading branch information
jjshoots committed Jul 2, 2024
1 parent d2f9a04 commit 4a960ef
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "jj_wingman"
version = "0.12.0"
version = "0.13.0"
authors = [
{ name="Jet", email="[email protected]" },
]
Expand Down
16 changes: 12 additions & 4 deletions test/test_replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,11 @@ def test_bulk(
mem_size = 11
element_shapes = [(3, 3), (3,), ()]
memory = ReplayBuffer(
mem_size=mem_size, mode=mode, device=device, store_on_device=store_on_device
mem_size=mem_size,
mode=mode,
device=device,
store_on_device=store_on_device,
random_rollover=random_rollover,
)

for iteration in range(10):
Expand All @@ -126,7 +130,7 @@ def test_bulk(
for shape in element_shapes:
data.append(_randn(shape=(bulk_size, *shape), mode=mode))
print([d.shape for d in data])
memory.push(data, bulk=True, random_rollover=random_rollover)
memory.push(data, bulk=True)

# reverse the data to make indexing for checking easier
reversed_data = [list(item) for item in zip(*data)]
Expand Down Expand Up @@ -168,15 +172,19 @@ def test_non_bulk(
mem_size = 11
element_shapes = [(3, 3), (3,), ()]
memory = ReplayBuffer(
mem_size=mem_size, mode=mode, device=device, store_on_device=store_on_device
mem_size=mem_size,
mode=mode,
device=device,
store_on_device=store_on_device,
random_rollover=random_rollover,
)

previous_data = []
for iteration in range(20):
current_data = []
for shape in element_shapes:
current_data.append(_randn(shape=shape, mode=mode))
memory.push(current_data, random_rollover=random_rollover)
memory.push(current_data)

# if random rollover and we're more than full, different matching method
if random_rollover and memory.is_full:
Expand Down

0 comments on commit 4a960ef

Please sign in to comment.