diff --git a/pyproject.toml b/pyproject.toml index fdb87b2..94a1f1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "jj_wingman" -version = "0.12.0" +version = "0.13.0" authors = [ { name="Jet", email="taijunjet@hotmail.com" }, ] diff --git a/test/test_replay_buffer.py b/test/test_replay_buffer.py index 6ab862a..5b8c9d2 100644 --- a/test/test_replay_buffer.py +++ b/test/test_replay_buffer.py @@ -115,7 +115,11 @@ def test_bulk( mem_size = 11 element_shapes = [(3, 3), (3,), ()] memory = ReplayBuffer( - mem_size=mem_size, mode=mode, device=device, store_on_device=store_on_device + mem_size=mem_size, + mode=mode, + device=device, + store_on_device=store_on_device, + random_rollover=random_rollover, ) for iteration in range(10): @@ -126,7 +130,7 @@ def test_bulk( for shape in element_shapes: data.append(_randn(shape=(bulk_size, *shape), mode=mode)) print([d.shape for d in data]) - memory.push(data, bulk=True, random_rollover=random_rollover) + memory.push(data, bulk=True) # reverse the data to make indexing for checking easier reversed_data = [list(item) for item in zip(*data)] @@ -168,7 +172,11 @@ def test_non_bulk( mem_size = 11 element_shapes = [(3, 3), (3,), ()] memory = ReplayBuffer( - mem_size=mem_size, mode=mode, device=device, store_on_device=store_on_device + mem_size=mem_size, + mode=mode, + device=device, + store_on_device=store_on_device, + random_rollover=random_rollover, ) previous_data = [] @@ -176,7 +184,7 @@ def test_non_bulk( current_data = [] for shape in element_shapes: current_data.append(_randn(shape=shape, mode=mode)) - memory.push(current_data, random_rollover=random_rollover) + memory.push(current_data) # if random rollover and we're more than full, different matching method if random_rollover and memory.is_full: