Skip to content

Commit

Permalink
prioritised_trajectory_buffer: allow adding max_length_time_axis (#40)
Browse files Browse the repository at this point in the history
* prioritised_trajectory_buffer: allow adding max_length_time_axis

* test sampling in test_add_max_length
  • Loading branch information
garymm authored Oct 29, 2024
1 parent e0199d7 commit a8ff66c
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 1 deletion.
2 changes: 1 addition & 1 deletion flashbax/buffers/prioritised_trajectory_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ def prioritised_add(
add_batch_size, max_length_time_axis = utils.get_tree_shape_prefix(
state.experience, n_axes=2
)
chex.assert_axis_dimension_lt(
chex.assert_axis_dimension_lteq(
jax.tree_util.tree_leaves(batch)[0], 1, max_length_time_axis
)

Expand Down
37 changes: 37 additions & 0 deletions flashbax/buffers/prioritised_trajectory_buffer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,43 @@ def get_fake_batch_sequence(
return get_fake_batch(get_fake_batch(fake_transition, sequence_length), batch_size)


def test_add_max_length(
    fake_transition: chex.ArrayTree,
    device: str,
    sample_period: int,
    max_length: int,
    add_batch_size: int,
    sample_sequence_length: int,
) -> None:
    """Verify that a batch spanning the full `max_length` can be added.

    A fresh prioritised buffer state is created, a fake batch whose sequence
    length equals `max_length` is added in one call, and the state is expected
    to flip from not-full to full. A single sequence is then sampled to confirm
    the buffer is still usable: its "reward" entry must have shape
    `(1, sample_sequence_length)`.

    All arguments are presumably supplied by pytest fixtures defined elsewhere
    in the test module -- confirm against the fixture definitions.
    """
    # Number of sequences drawn in the sampling check below.
    sample_batch_size = 1

    buffer_state = prioritised_trajectory_buffer.prioritised_init(
        fake_transition, add_batch_size, max_length, sample_period
    )

    # Build a batch whose sequence length is exactly the buffer's max length.
    full_length_batch = get_fake_batch_sequence(
        fake_transition, add_batch_size, max_length
    )

    # A freshly initialised buffer must not report itself as full.
    assert not buffer_state.is_full

    buffer_state = prioritised_trajectory_buffer.prioritised_add(
        buffer_state,
        full_length_batch,
        sample_sequence_length,
        sample_period,
        device,
    )

    # One max-length add is expected to fill the buffer.
    assert buffer_state.is_full

    rng_key = jax.random.PRNGKey(0)
    sample = prioritised_trajectory_buffer.prioritised_sample(
        buffer_state,
        rng_key,
        sample_batch_size,
        sample_sequence_length,
        sample_period,
    )

    expected_shape = (sample_batch_size, sample_sequence_length)
    assert sample.experience["reward"].shape == expected_shape


def test_add_and_can_sample_prioritised(
prioritised_state: prioritised_trajectory_buffer.PrioritisedTrajectoryBufferState,
fake_transition: chex.ArrayTree,
Expand Down

0 comments on commit a8ff66c

Please sign in to comment.