Skip to content

Commit

Permalink
chore: typing in mixer tests; fix List type in mixer
Browse files Browse the repository at this point in the history
  • Loading branch information
callumtilbury committed Jul 17, 2024
1 parent c6eee8f commit 7874f38
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 35 deletions.
32 changes: 16 additions & 16 deletions flashbax/buffers/mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import functools
from typing import Callable, List, TypeVar
from typing import Callable, Sequence, TypeVar

import chex
import jax
Expand Down Expand Up @@ -87,10 +87,10 @@ def _batch_slicer(


def sample_mixer_fn(
states: List[StateTypes],
states: Sequence[StateTypes],
key: chex.PRNGKey,
prop_batch_sizes: List[int],
sample_fns: List[Callable[[StateTypes, chex.PRNGKey], SampleTypes]],
prop_batch_sizes: Sequence[int],
sample_fns: Sequence[Callable[[StateTypes, chex.PRNGKey], SampleTypes]],
) -> SampleTypes:
"""Perform mixed sampling from provided buffer states, according to provided proportions.
Expand All @@ -99,12 +99,12 @@ def sample_mixer_fn(
all samples need to be sequences of the same sequence length but batch sizes can differ.
Args:
states (List[StateTypes]): list of buffer states
states (Sequence[StateTypes]): list of buffer states
key (chex.PRNGKey): random key
prop_batch_sizes (List[Numeric]): list of batch sizes sampled from each buffer, calculated
according to the proportions of joint sample size
sample_fns (List[Callable[[StateTypes, chex.PRNGKey], SampleTypes]]): list of pure sample
functions from each buffer
prop_batch_sizes (Sequence[Numeric]): list of batch sizes sampled from each buffer,
calculated according to the proportions of joint sample size
sample_fns (Sequence[Callable[[StateTypes, chex.PRNGKey], SampleTypes]]): list of pure
sample functions from each buffer
Returns:
SampleTypes: proportionally concatenated samples from all buffers
Expand Down Expand Up @@ -140,13 +140,13 @@ def sample_mixer_fn(


def can_sample_mixer_fn(
states: List[StateTypes], can_sample_fns: List[Callable[[StateTypes], bool]]
states: Sequence[StateTypes], can_sample_fns: Sequence[Callable[[StateTypes], bool]]
) -> bool:
"""Check if all buffers can sample.
Args:
states (List[StateTypes]): list of buffer states
can_sample_fns (List[Callable[[StateTypes], bool]]): list of can_sample functions
states (Sequence[StateTypes]): list of buffer states
can_sample_fns (Sequence[Callable[[StateTypes], bool]]): list of can_sample functions
from each buffer
Returns:
Expand All @@ -162,16 +162,16 @@ def can_sample_mixer_fn(


def make_mixer(
buffers: List[BufferTypes],
buffers: Sequence[BufferTypes],
sample_batch_size: int,
proportions: List[Numeric],
proportions: Sequence[Numeric],
) -> Mixer:
"""Create the mixer.
Args:
buffers (List[BufferTypes]): list of buffers (pure functions)
buffers (Sequence[BufferTypes]): list of buffers (pure functions)
sample_batch_size (int): desired batch size of joint sample
proportions (List[Numeric]):
proportions (Sequence[Numeric]):
Proportions of joint sample size to be sampled from each buffer, given as a ratio.
Returns:
Expand Down
41 changes: 22 additions & 19 deletions flashbax/buffers/mixer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,12 @@ def get_fake_batch_sequence(


def test_mixed_trajectory_sample(
rng_key,
sample_batch_size,
sample_period,
add_batch_size,
sample_sequence_length,
fake_transition,
rng_key: chex.PRNGKey,
sample_batch_size: int,
sample_period: int,
add_batch_size: int,
sample_sequence_length: int,
fake_transition: chex.ArrayTree,
):
buffers = []
buffer_states = []
Expand Down Expand Up @@ -131,12 +131,12 @@ def test_mixed_trajectory_sample(


def test_mixed_prioritised_trajectory_sample(
rng_key,
sample_batch_size,
sample_period,
add_batch_size,
sample_sequence_length,
fake_transition,
rng_key: chex.PRNGKey,
sample_batch_size: int,
sample_period: int,
add_batch_size: int,
sample_sequence_length: int,
fake_transition: chex.ArrayTree,
):
buffers = []
buffer_states = []
Expand Down Expand Up @@ -182,7 +182,10 @@ def test_mixed_prioritised_trajectory_sample(


def test_mixed_flat_buffer_sample(
rng_key, sample_batch_size, add_batch_size, fake_transition
rng_key: chex.PRNGKey,
sample_batch_size: int,
add_batch_size: int,
fake_transition: chex.ArrayTree,
):
buffers = []
buffer_states = []
Expand Down Expand Up @@ -226,12 +229,12 @@ def test_mixed_flat_buffer_sample(


def test_mixed_buffer_does_not_smoke(
rng_key,
sample_batch_size,
sample_period,
add_batch_size,
sample_sequence_length,
fake_transition,
rng_key: chex.PRNGKey,
sample_batch_size: int,
sample_period: int,
add_batch_size: int,
sample_sequence_length: int,
fake_transition: chex.ArrayTree,
):
buffers = []
buffer_states = []
Expand Down

0 comments on commit 7874f38

Please sign in to comment.