diff --git a/flashbax/buffers/prioritised_trajectory_buffer.py b/flashbax/buffers/prioritised_trajectory_buffer.py
index 2337bf8..ea1206f 100644
--- a/flashbax/buffers/prioritised_trajectory_buffer.py
+++ b/flashbax/buffers/prioritised_trajectory_buffer.py
@@ -440,7 +440,7 @@ def prioritised_add(
     add_batch_size, max_length_time_axis = utils.get_tree_shape_prefix(
         state.experience, n_axes=2
     )
-    chex.assert_axis_dimension_lt(
+    chex.assert_axis_dimension_lteq(
         jax.tree_util.tree_leaves(batch)[0], 1, max_length_time_axis
     )
 
diff --git a/flashbax/buffers/prioritised_trajectory_buffer_test.py b/flashbax/buffers/prioritised_trajectory_buffer_test.py
index 5600247..af7f71c 100644
--- a/flashbax/buffers/prioritised_trajectory_buffer_test.py
+++ b/flashbax/buffers/prioritised_trajectory_buffer_test.py
@@ -73,6 +73,43 @@ def get_fake_batch_sequence(
     return get_fake_batch(get_fake_batch(fake_transition, sequence_length), batch_size)
 
 
+def test_add_max_length(
+    fake_transition: chex.ArrayTree,
+    device: str,
+    sample_period: int,
+    max_length: int,
+    add_batch_size: int,
+    sample_sequence_length: int,
+) -> None:
+    """Check the `add` function works when adding the max length."""
+    prioritised_state = prioritised_trajectory_buffer.prioritised_init(
+        fake_transition,
+        add_batch_size,
+        max_length,
+        sample_period,
+    )
+    fake_batch_sequence = get_fake_batch_sequence(
+        fake_transition, add_batch_size, max_length
+    )
+    assert not prioritised_state.is_full
+    prioritised_state = prioritised_trajectory_buffer.prioritised_add(
+        prioritised_state,
+        fake_batch_sequence,
+        sample_sequence_length,
+        sample_period,
+        device,
+    )
+    assert prioritised_state.is_full
+    sampled = prioritised_trajectory_buffer.prioritised_sample(
+        prioritised_state,
+        jax.random.PRNGKey(0),
+        1,
+        sample_sequence_length,
+        sample_period,
+    )
+    assert sampled.experience["reward"].shape == (1, sample_sequence_length)
+
+
 def test_add_and_can_sample_prioritised(
     prioritised_state: prioritised_trajectory_buffer.PrioritisedTrajectoryBufferState,
     fake_transition: chex.ArrayTree,