From f3cad12a62af8766eabbaa382e3f6cc6660c0ac5 Mon Sep 17 00:00:00 2001 From: Mick van Gelderen Date: Fri, 8 Nov 2024 13:13:50 -0800 Subject: [PATCH] Pass max_length_time_axis instead of max_size Makes it so that the warning: ``` Setting max_size dynamically sets the `max_length_time_axis` to be `max_size`//`add_batch_size = .*` ``` will no longer be triggered by legitimate use of `create_flat_buffer` and `make_prioritised_flat_buffer`. --- flashbax/buffers/flat_buffer.py | 4 ++-- flashbax/buffers/prioritised_flat_buffer.py | 4 ++-- pyproject.toml | 1 - 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/flashbax/buffers/flat_buffer.py b/flashbax/buffers/flat_buffer.py index 41305a7..e42c94b 100644 --- a/flashbax/buffers/flat_buffer.py +++ b/flashbax/buffers/flat_buffer.py @@ -123,13 +123,13 @@ def create_flat_buffer( ) buffer = make_trajectory_buffer( - max_length_time_axis=None, # Unused because max_size is specified + max_length_time_axis=max_length // add_batch_size, min_length_time_axis=min_length // add_batch_size + 1, add_batch_size=add_batch_size, sample_batch_size=sample_batch_size, sample_sequence_length=2, period=1, - max_size=max_length, + max_size=None, ) add_fn = buffer.add diff --git a/flashbax/buffers/prioritised_flat_buffer.py b/flashbax/buffers/prioritised_flat_buffer.py index a4f6a1d..bf9d630 100644 --- a/flashbax/buffers/prioritised_flat_buffer.py +++ b/flashbax/buffers/prioritised_flat_buffer.py @@ -110,13 +110,13 @@ def make_prioritised_flat_buffer( ) buffer = make_prioritised_trajectory_buffer( - max_length_time_axis=None, # Unused because max_size is specified + max_length_time_axis=max_length // add_batch_size, min_length_time_axis=min_length // add_batch_size + 1, add_batch_size=add_batch_size, sample_batch_size=sample_batch_size, sample_sequence_length=2, period=1, - max_size=max_length, + max_size=None, priority_exponent=priority_exponent, device=device, ) diff --git a/pyproject.toml b/pyproject.toml index 29fefe7..5a98cb2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,6 @@ filterwarnings = [ "error", "ignore:`sample_sequence_length` greater than `min_length_time_axis`:UserWarning:flashbax", "ignore:Setting period greater than sample_sequence_length will result in no overlap betweentrajectories:UserWarning:flashbax", - "ignore:Setting max_size dynamically sets the `max_length_time_axis` to be `max_size`//`add_batch_size = .*`:UserWarning:flashbax", "ignore:jax.tree_map is deprecated:DeprecationWarning:flashbax", ]