diff --git a/examples/mixer_demonstration.ipynb b/examples/mixer_demonstration.ipynb new file mode 100644 index 0000000..7b07a40 --- /dev/null +++ b/examples/mixer_demonstration.ipynb @@ -0,0 +1,305 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import flashbax as fbx\n", + "import jax.numpy as jnp\n", + "from jax.tree_util import tree_map\n", + "import jax\n", + "\n", + "key = jax.random.PRNGKey(0)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TrajectoryBufferSample(experience={'acts': (4, 5, 3), 'obs': (4, 5, 2)})" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Create our first buffer, with a sample batch size of 4\n", + "buffer_a = fbx.make_trajectory_buffer(\n", + " add_batch_size=1,\n", + " max_length_time_axis=1000,\n", + " min_length_time_axis=5,\n", + " sample_sequence_length=5,\n", + " period=1,\n", + " sample_batch_size=4,\n", + ")\n", + "\n", + "timestep = {\n", + " \"obs\": jnp.ones((2)),\n", + " \"acts\": jnp.ones(3),\n", + "}\n", + "\n", + "state_a = buffer_a.init(\n", + " timestep,\n", + ")\n", + "for i in range(100):\n", + " # Fill with POSITIVE values\n", + " state_a = jax.jit(buffer_a.add, donate_argnums=0)(\n", + " state_a,\n", + " tree_map(lambda x, _i=i: (x * _i)[None, None, ...], timestep),\n", + " )\n", + "\n", + "sample_a = buffer_a.sample(state_a, key)\n", + "tree_map(lambda x: x.shape, sample_a)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TrajectoryBufferSample(experience={'acts': (16, 5, 3), 'obs': (16, 5, 2)})" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Create our second buffer, with a sample batch size of 16\n", + "buffer_b = fbx.make_trajectory_buffer(\n", + " add_batch_size=1,\n", + " max_length_time_axis=1000,\n", + " min_length_time_axis=5,\n", + " sample_sequence_length=5,\n", + " period=1,\n", + " sample_batch_size=16,\n", + ")\n", + "\n", + "timestep = {\n", + " \"obs\": jnp.ones((2)),\n", + " \"acts\": jnp.ones(3),\n", + "}\n", + "\n", + "state_b = buffer_b.init(\n", + " timestep,\n", + ")\n", + "for i in range(100):\n", + " # Fill with NEGATIVE values\n", + " state_b = jax.jit(buffer_b.add, donate_argnums=0)(\n", + " state_b,\n", + " tree_map(lambda x, _i=i: (- x * _i)[None, None, ...], timestep),\n", + " )\n", + "\n", + "sample_b = buffer_b.sample(state_b, key)\n", + "tree_map(lambda x: x.shape, sample_b)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# Make the mixer, with a ratio of 1:3 from buffer_a:buffer_b\n", + "mixer = fbx.make_mixer(\n", + " buffers=[buffer_a, buffer_b],\n", + " sample_batch_size=8,\n", + " proportions=[1,3],\n", + ")\n", + "\n", + "# jittable sampling!\n", + "mixer_sample = jax.jit(mixer.sample)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TrajectoryBufferSample(experience={'acts': (8, 5, 3), 'obs': (8, 5, 2)})" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Sample from the mixer, using the usual flashbax API\n", + "joint_sample = mixer_sample(\n", + " [state_a, state_b],\n", + " key,\n", + ")\n", + "\n", + "# Notice the resulting 
shape\n", + "tree_map(lambda x: x.shape, joint_sample)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TrajectoryBufferSample(experience={'acts': Array([[[90., 90., 90.],\n", + " [91., 91., 91.],\n", + " [92., 92., 92.],\n", + " [93., 93., 93.],\n", + " [94., 94., 94.]],\n", + "\n", + " [[56., 56., 56.],\n", + " [57., 57., 57.],\n", + " [58., 58., 58.],\n", + " [59., 59., 59.],\n", + " [60., 60., 60.]]], dtype=float32), 'obs': Array([[[90., 90.],\n", + " [91., 91.],\n", + " [92., 92.],\n", + " [93., 93.],\n", + " [94., 94.]],\n", + "\n", + " [[56., 56.],\n", + " [57., 57.],\n", + " [58., 58.],\n", + " [59., 59.],\n", + " [60., 60.]]], dtype=float32)})" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Notice how the first 1/4 * 8 = 2 batches are from buffer_a (POSITIVE VALUES)\n", + "tree_map(lambda x: x[0:2], joint_sample)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TrajectoryBufferSample(experience={'acts': Array([[[-34., -34., -34.],\n", + " [-35., -35., -35.],\n", + " [-36., -36., -36.],\n", + " [-37., -37., -37.],\n", + " [-38., -38., -38.]],\n", + "\n", + " [[-88., -88., -88.],\n", + " [-89., -89., -89.],\n", + " [-90., -90., -90.],\n", + " [-91., -91., -91.],\n", + " [-92., -92., -92.]],\n", + "\n", + " [[-30., -30., -30.],\n", + " [-31., -31., -31.],\n", + " [-32., -32., -32.],\n", + " [-33., -33., -33.],\n", + " [-34., -34., -34.]],\n", + "\n", + " [[-11., -11., -11.],\n", + " [-12., -12., -12.],\n", + " [-13., -13., -13.],\n", + " [-14., -14., -14.],\n", + " [-15., -15., -15.]],\n", + "\n", + " [[-78., -78., -78.],\n", + " [-79., -79., -79.],\n", + " [-80., -80., -80.],\n", + " [-81., -81., -81.],\n", + " [-82., -82., -82.]],\n", + "\n", + " [[-15., -15., -15.],\n", + " [-16., -16., -16.],\n", + " [-17., -17., -17.],\n", + " [-18., -18., -18.],\n", + " [-19., -19., -19.]]], dtype=float32), 'obs': Array([[[-34., -34.],\n", + " [-35., -35.],\n", + " [-36., -36.],\n", + " [-37., -37.],\n", + " [-38., -38.]],\n", + "\n", + " [[-88., -88.],\n", + " [-89., -89.],\n", + " [-90., -90.],\n", + " [-91., -91.],\n", + " [-92., -92.]],\n", + "\n", + " [[-30., -30.],\n", + " [-31., -31.],\n", + " [-32., -32.],\n", + " [-33., -33.],\n", + " [-34., -34.]],\n", + "\n", + " [[-11., -11.],\n", + " [-12., -12.],\n", + " [-13., -13.],\n", + " [-14., -14.],\n", + " [-15., -15.]],\n", + "\n", + " [[-78., -78.],\n", + " [-79., -79.],\n", + " [-80., -80.],\n", + " [-81., -81.],\n", + " [-82., -82.]],\n", + "\n", + " [[-15., -15.],\n", + " [-16., -16.],\n", + " [-17., -17.],\n", + " [-18., -18.],\n", + " [-19., -19.]]], dtype=float32)})" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# and how the second 3/4 * 8 = 6 batches are from buffer_b (NEGATIVE VALUES)\n", + "tree_map(lambda x: x[2:], joint_sample)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "flashbax", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/flashbax/__init__.py b/flashbax/__init__.py index 576f501..d2f2c14 100644 
--- a/flashbax/__init__.py +++ b/flashbax/__init__.py @@ -18,11 +18,13 @@ item_buffer, make_flat_buffer, make_item_buffer, + make_mixer, make_prioritised_flat_buffer, make_prioritised_item_buffer, make_prioritised_trajectory_buffer, make_trajectory_buffer, make_trajectory_queue, + mixer, prioritised_flat_buffer, prioritised_item_buffer, prioritised_trajectory_buffer, diff --git a/flashbax/buffers/__init__.py b/flashbax/buffers/__init__.py index 1e0096f..4325a07 100644 --- a/flashbax/buffers/__init__.py +++ b/flashbax/buffers/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from flashbax.buffers.flat_buffer import make_flat_buffer from flashbax.buffers.item_buffer import make_item_buffer +from flashbax.buffers.mixer import make_mixer from flashbax.buffers.prioritised_flat_buffer import make_prioritised_flat_buffer from flashbax.buffers.prioritised_item_buffer import make_prioritised_item_buffer from flashbax.buffers.prioritised_trajectory_buffer import ( diff --git a/flashbax/buffers/mixer.py b/flashbax/buffers/mixer.py new file mode 100644 index 0000000..95417a8 --- /dev/null +++ b/flashbax/buffers/mixer.py @@ -0,0 +1,217 @@ +# Copyright 2023 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from typing import Callable, Sequence, TypeVar + +import chex +import jax +import jax.numpy as jnp +from chex import Numeric, dataclass +from jax import Array +from jax.tree_util import tree_map + +from flashbax.buffers.flat_buffer import TransitionSample +from flashbax.buffers.prioritised_trajectory_buffer import ( + PrioritisedTrajectoryBuffer, + PrioritisedTrajectoryBufferSample, + PrioritisedTrajectoryBufferState, +) +from flashbax.buffers.trajectory_buffer import ( + TrajectoryBuffer, + TrajectoryBufferSample, + TrajectoryBufferState, +) + +# Support for Trajectory, Flat, Item buffers, and prioritised variants +sample_types = [ + TrajectoryBufferSample, + PrioritisedTrajectoryBufferSample, + TransitionSample, +] +SampleTypes = TypeVar( + "SampleTypes", + TrajectoryBufferSample, + PrioritisedTrajectoryBufferSample, + TransitionSample, +) + +state_types = [TrajectoryBufferState, PrioritisedTrajectoryBufferState] +StateTypes = TypeVar( + "StateTypes", TrajectoryBufferState, PrioritisedTrajectoryBufferState +) + +BufferTypes = TypeVar("BufferTypes", TrajectoryBuffer, PrioritisedTrajectoryBuffer) + + +@dataclass(frozen=True) +class Mixer: + """Pure functions defining the mixer. + + Attributes: + sample (Callable): function to sample proportionally from all buffers, + concatenating along the batch axis + can_sample (Callable): function to check if all buffers can sample + """ + + sample: Callable + can_sample: Callable + + +def _batch_slicer( + sample: SampleTypes, + batch_start: int, + batch_end: int, +) -> SampleTypes: + """Simple utility function to slice a sample along the batch axis. 
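+
+    For example, a sample whose leaves have shape (16, 5, 3), sliced with
+    batch_start=0 and batch_end=4, yields leaves of shape (4, 5, 3).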
+ + Args: + sample (SampleTypes): incoming sample + batch_start (int): batch start index + batch_end (int): batch end index + + Returns: + SampleTypes: outgoing sliced sample + """ + return tree_map(lambda x: x[batch_start:batch_end, ...], sample) + + +def sample_mixer_fn( + states: Sequence[StateTypes], + key: chex.PRNGKey, + prop_batch_sizes: Sequence[int], + sample_fns: Sequence[Callable[[StateTypes, chex.PRNGKey], SampleTypes]], +) -> SampleTypes: + """Perform mixed sampling from provided buffer states, according to provided proportions. + + Each buffer sample needs to be of the same pytree structure, and the samples are concatenated + along the first axis i.e. the batch axis. For example, if you are sampling trajectories, then + all samples need to be sequences of the same sequence length but batch sizes can differ. + + Args: + states (Sequence[StateTypes]): list of buffer states + key (chex.PRNGKey): random key + prop_batch_sizes (Sequence[int]): list of batch sizes sampled from each buffer, + calculated from the proportions of the joint sample size + sample_fns (Sequence[Callable[[StateTypes, chex.PRNGKey], SampleTypes]]): list of pure + sample functions from each buffer + + Returns: + SampleTypes: proportionally concatenated samples from all buffers + """ + keys = jax.random.split( + key, len(states) + ) # Split the key for each buffer sampling operation + + # We first sample from each buffer, and get a list of samples + samples_array = tree_map( + lambda state, sample, key_in: sample(state, key_in), + states, + sample_fns, + list(keys), + is_leaf=lambda leaf: type(leaf) in state_types, + ) + + # We then slice the samples according to the proportions + prop_batch_samples_array = tree_map( + lambda x, p: _batch_slicer(x, 0, p), + samples_array, + prop_batch_sizes, + is_leaf=lambda leaf: type(leaf) in sample_types, + ) + + # Concatenate the samples along the batch axis + joint_sample = tree_map( + lambda *x: jnp.concatenate(x, axis=0), + *prop_batch_samples_array, + ) + + return joint_sample + + +def can_sample_mixer_fn( + states: Sequence[StateTypes], + can_sample_fns: Sequence[Callable[[StateTypes], Array]], +) -> Array: + """Check if all buffers can sample. + + Args: + states (Sequence[StateTypes]): list of buffer states + can_sample_fns (Sequence[Callable[[StateTypes], Array]]): list of can_sample functions + from each buffer + + Returns: + Array: whether all buffers can sample (scalar boolean array) + """ + each_can_sample = jnp.asarray( + tree_map( + lambda state, can_sample: can_sample(state), + states, + can_sample_fns, + is_leaf=lambda leaf: type(leaf) in state_types, + ) + ) + return jnp.all(each_can_sample) + + +def make_mixer( + buffers: Sequence[BufferTypes], + sample_batch_size: int, + proportions: Sequence[Numeric], +) -> Mixer: + """Create the mixer. + + Args: + buffers (Sequence[BufferTypes]): list of buffers (pure functions) + sample_batch_size (int): desired batch size of joint sample + proportions (Sequence[Numeric]): + Proportions of joint sample size to be sampled from each buffer, given as a ratio.
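+            For example, sample_batch_size=8 with proportions=[1, 3] draws
+            2 items from the first buffer and 6 from the second; [0.25, 0.75]
+            yields the same split, since proportions are normalized internally.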
+ + Returns: + Mixer: a mixer of pure sample and can_sample functions + """ + assert len(buffers) == len( + proportions + ), "Number of buffers and proportions must match" + assert all( + isinstance(b, type(buffers[0])) for b in buffers + ), "All buffers must be of the same type" + assert sample_batch_size > 0, "Sample batch size must be greater than 0" + + sample_fns = [b.sample for b in buffers] + can_sample_fns = [b.can_sample for b in buffers] + + # Normalize proportions and calculate resulting integer batch sizes + props_sum = sum(proportions) + props_norm = [p / props_sum for p in proportions] + prop_batch_sizes = [int(p * sample_batch_size) for p in props_norm] + if sum(prop_batch_sizes) != sample_batch_size: + # int() truncation can leave a shortfall; add the remainder to the first buffer's batch size + prop_batch_sizes[0] += sample_batch_size - sum(prop_batch_sizes) + + mixer_sample_fn = functools.partial( + sample_mixer_fn, + prop_batch_sizes=prop_batch_sizes, + sample_fns=sample_fns, + ) + + mixer_can_sample_fn = functools.partial( + can_sample_mixer_fn, + can_sample_fns=can_sample_fns, + ) + + return Mixer( + sample=mixer_sample_fn, + can_sample=mixer_can_sample_fn, + ) diff --git a/flashbax/buffers/mixer_test.py b/flashbax/buffers/mixer_test.py new file mode 100644 index 0000000..bfcf6f2 --- /dev/null +++ b/flashbax/buffers/mixer_test.py @@ -0,0 +1,283 @@ +# Copyright 2023 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
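+
+# Shared test strategy: each buffer below is filled with a constant value equal
+# to its index (0, 1, or 2), so the composition of a mixed sample can be checked
+# by counting how many sampled items carry each value.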
+ + +import chex +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +from flashbax.buffers import ( + flat_buffer, + prioritised_trajectory_buffer, + trajectory_buffer, +) +from flashbax.buffers.conftest import get_fake_batch +from flashbax.buffers.mixer import make_mixer + + +@pytest.fixture +def rng_key() -> chex.PRNGKey: + return jax.random.PRNGKey(0) + + +@pytest.fixture +def fake_transition() -> chex.ArrayTree: + return { + "obs": jnp.array([0.0, 0.0]), + "reward": jnp.array(0.0), + "done": jnp.array(False), + "next_obs": jnp.array([0.0, 0.0]), + } + + +@pytest.fixture +def max_length() -> int: + return 32 + + +@pytest.fixture +def min_length() -> int: + return 8 + + +@pytest.fixture +def add_batch_size() -> int: + return 4 + + +@pytest.fixture +def sample_batch_size() -> int: + return 100 + + +@pytest.fixture +def sample_sequence_length() -> int: + return 4 + + +@pytest.fixture +def sample_period() -> int: + return 1 + + +def get_fake_batch_sequence( + fake_transition: chex.ArrayTree, batch_size: int, sequence_length: int +) -> chex.ArrayTree: + return get_fake_batch(get_fake_batch(fake_transition, sequence_length), batch_size) + + +def test_mixed_trajectory_sample( + rng_key: chex.PRNGKey, + sample_batch_size: int, + sample_period: int, + add_batch_size: int, + sample_sequence_length: int, + fake_transition: chex.ArrayTree, +): + buffers = [] + buffer_states = [] + for i in range(3): + buffer = trajectory_buffer.make_trajectory_buffer( + max_length_time_axis=200 * (i + 1), + min_length_time_axis=0, + sample_batch_size=sample_batch_size, + add_batch_size=add_batch_size, + sample_sequence_length=sample_sequence_length, + period=sample_period, + ) + buffers.append(buffer) + + state = buffer.init( + jax.tree_map(lambda x, _i=i: jnp.ones_like(x) * _i, fake_transition) + ) + fake_add_data = get_fake_batch_sequence(fake_transition, add_batch_size, 50) + fake_add_data = jax.tree_map( + lambda x, _i=i: jnp.ones_like(x) * _i, fake_add_data + ) + state = buffer.add(state, fake_add_data) + buffer_states.append(state) + + proportions = [0.2, 0.2, 0.6] + mixer = make_mixer( + buffers=buffers, + proportions=proportions, + sample_batch_size=sample_batch_size, + ) + samples = mixer.sample(buffer_states, rng_key) + + assert samples is not None + expected_zeros = int(sample_batch_size * proportions[0]) + expected_ones = int(sample_batch_size * proportions[1]) + expected_twos = int(sample_batch_size * proportions[2]) + chex.assert_tree_shape_prefix(samples, (sample_batch_size,)) + dones = samples.experience["done"] + dones = dones[:, 0] + assert np.sum(dones == 0) == expected_zeros + assert np.sum(dones == 1) == expected_ones + assert np.sum(dones == 2) == expected_twos + + +def test_mixed_prioritised_trajectory_sample( + rng_key: chex.PRNGKey, + sample_batch_size: int, + sample_period: int, + add_batch_size: int, + sample_sequence_length: int, + fake_transition: chex.ArrayTree, +): + buffers = [] + buffer_states = [] + for i in range(3): + buffer = prioritised_trajectory_buffer.make_prioritised_trajectory_buffer( + max_length_time_axis=200 * (i + 1), + min_length_time_axis=0, + sample_batch_size=sample_batch_size, + add_batch_size=add_batch_size, + sample_sequence_length=sample_sequence_length, + period=sample_period, + ) + buffers.append(buffer) + + state = buffer.init( + jax.tree_map(lambda x, _i=i: jnp.ones_like(x) * _i, fake_transition) + ) + fake_add_data = get_fake_batch_sequence(fake_transition, add_batch_size, 50) + fake_add_data = jax.tree_map( + lambda x, _i=i: 
jnp.ones_like(x) * _i, fake_add_data + ) + state = buffer.add(state, fake_add_data) + buffer_states.append(state) + + proportions = [0.4, 0.1, 0.5] + mixer = make_mixer( + buffers=buffers, + proportions=proportions, + sample_batch_size=sample_batch_size, + ) + samples = mixer.sample(buffer_states, rng_key) + + assert samples is not None + expected_zeros = int(sample_batch_size * proportions[0]) + expected_ones = int(sample_batch_size * proportions[1]) + expected_twos = int(sample_batch_size * proportions[2]) + chex.assert_tree_shape_prefix(samples, (sample_batch_size,)) + dones = samples.experience["done"] + dones = dones[:, 0] + assert np.sum(dones == 0) == expected_zeros + assert np.sum(dones == 1) == expected_ones + assert np.sum(dones == 2) == expected_twos + + +def test_mixed_flat_buffer_sample( + rng_key: chex.PRNGKey, + sample_batch_size: int, + add_batch_size: int, + fake_transition: chex.ArrayTree, +): + buffers = [] + buffer_states = [] + for i in range(3): + buffer = flat_buffer.make_flat_buffer( + max_length=200 * (i + 1), + min_length=0, + sample_batch_size=sample_batch_size, + add_batch_size=add_batch_size, + add_sequences=True, + ) + buffers.append(buffer) + + state = buffer.init( + jax.tree_map(lambda x, _i=i: jnp.ones_like(x) * _i, fake_transition) + ) + fake_add_data = get_fake_batch_sequence(fake_transition, add_batch_size, 50) + fake_add_data = jax.tree_map( + lambda x, _i=i: jnp.ones_like(x) * _i, fake_add_data + ) + state = buffer.add(state, fake_add_data) + buffer_states.append(state) + + proportions = [0.1, 0.1, 0.8] + mixer = make_mixer( + buffers=buffers, + proportions=proportions, + sample_batch_size=sample_batch_size, + ) + samples = mixer.sample(buffer_states, rng_key) + + assert samples is not None + expected_zeros = int(sample_batch_size * proportions[0]) + expected_ones = int(sample_batch_size * proportions[1]) + expected_twos = int(sample_batch_size * proportions[2]) + chex.assert_tree_shape_prefix(samples, (sample_batch_size,)) + dones = samples.experience.first["done"] + assert np.sum(dones == 0) == expected_zeros + assert np.sum(dones == 1) == expected_ones + assert np.sum(dones == 2) == expected_twos + + +def test_mixed_buffer_does_not_smoke( + rng_key: chex.PRNGKey, + sample_batch_size: int, + sample_period: int, + add_batch_size: int, + sample_sequence_length: int, + fake_transition: chex.ArrayTree, +): + buffers = [] + buffer_states = [] + for i in range(3): + buffer = trajectory_buffer.make_trajectory_buffer( + max_length_time_axis=2000 * (i + 1), + min_length_time_axis=0, + sample_batch_size=sample_batch_size, + add_batch_size=add_batch_size, + sample_sequence_length=sample_sequence_length, + period=sample_period, + ) + buffers.append(buffer) + + state = buffer.init( + jax.tree_map(lambda x, _i=i: jnp.ones_like(x) * _i, fake_transition) + ) + fake_add_data = get_fake_batch_sequence(fake_transition, add_batch_size, 50) + fake_add_data = jax.tree_map( + lambda x, _i=i: jnp.ones_like(x) * _i, fake_add_data + ) + state = buffer.add(state, fake_add_data) + buffer_states.append(state) + + proportions = [0.2, 0.2, 0.6] + mixer = make_mixer( + buffers=buffers, + proportions=proportions, + sample_batch_size=sample_batch_size, + ) + + can_sample = jax.jit(mixer.can_sample)(buffer_states) + assert can_sample + + samples = jax.jit(mixer.sample)(buffer_states, rng_key) + + assert samples is not None + expected_zeros = int(sample_batch_size * proportions[0]) + expected_ones = int(sample_batch_size * proportions[1]) + expected_twos = int(sample_batch_size * 
proportions[2]) + chex.assert_tree_shape_prefix(samples, (sample_batch_size,)) + dones = samples.experience["done"] + dones = dones[:, 0] + assert np.sum(dones == 0) == expected_zeros + assert np.sum(dones == 1) == expected_ones + assert np.sum(dones == 2) == expected_twos diff --git a/pyproject.toml b/pyproject.toml index 1f7f379..daf68a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,7 @@ dev = [ 'mkdocs-mermaid2-plugin==1.1.1', 'mkdocstrings[python]==0.23.0', 'mknotebooks==0.8.0', - 'mypy>=0.982', + 'mypy>=1.8.0', 'pre-commit>=2.20.0', 'pytest>=7.4.2', 'pytest-cov>=4.00',
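
Note: as a sanity check on the rounding behaviour above, here is a minimal standalone sketch of make_mixer's proportion-to-batch-size arithmetic (not part of the diff; split_batch is an illustrative name):

# Mirrors make_mixer's proportion -> per-buffer batch size logic (illustrative only).
def split_batch(sample_batch_size, proportions):
    total = sum(proportions)
    sizes = [int(p / total * sample_batch_size) for p in proportions]
    # int() truncates, so the sizes can undershoot the target batch size;
    # make_mixer gives any remainder to the first buffer.
    sizes[0] += sample_batch_size - sum(sizes)
    return sizes

assert split_batch(8, [1, 3]) == [2, 6]                    # ratio used in the notebook
assert split_batch(100, [0.2, 0.2, 0.6]) == [20, 20, 60]   # as in the tests
assert split_batch(10, [1, 1, 1]) == [4, 3, 3]             # remainder of 1 goes to buffer 0

Handing the truncation remainder to the first buffer keeps the joint batch size exact, at the cost of slightly over-weighting the first buffer when the ratio does not divide the batch evenly.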