From 0a840713f6b1589567c62e7c350827c4df8954d1 Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Tue, 9 Jul 2024 16:50:50 +0000 Subject: [PATCH 1/8] feat: initial start to mixed replay buffers --- flashbax/buffers/mixed_buffer.py | 82 ++++++++ flashbax/buffers/mixed_buffer_test.py | 288 ++++++++++++++++++++++++++ 2 files changed, 370 insertions(+) create mode 100644 flashbax/buffers/mixed_buffer.py create mode 100644 flashbax/buffers/mixed_buffer_test.py diff --git a/flashbax/buffers/mixed_buffer.py b/flashbax/buffers/mixed_buffer.py new file mode 100644 index 0000000..5ef2a28 --- /dev/null +++ b/flashbax/buffers/mixed_buffer.py @@ -0,0 +1,82 @@ +# Copyright 2023 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, Sequence + +import chex +import jax +import jax.numpy as jnp + +from flashbax.buffers.trajectory_buffer import BufferState + + +def mixed_sample( + buffer_state_list: Sequence[BufferState], + rng_key: chex.Array, + buffer_sample_fns: Sequence[Callable[[BufferState], Any]], + proportions: Sequence[float], + sample_batch_size: int, +) -> Any: + """ + Sample from a mixed buffer, which is a list of buffer states, each with its own sample function. + + Each buffer sample needs to be of the same pytree structure, and the samples are concatenated along the first axis i.e. the batch axis. + For example, if you are sampling trajectories, then all samples need to be sequences of the same sequence length but batch sizes can differ. + """ + assert len(buffer_state_list) == len( + buffer_sample_fns + ), "Number of buffer states and sample functions must match" + assert len(buffer_state_list) == len( + proportions + ), "Number of buffer states and proportions must match" + assert sum(proportions) == 1.0, "Proportions must sum to 1" + local_batch_sizes = [int(sample_batch_size * p) for p in proportions] + if sum(local_batch_sizes) != sample_batch_size: + local_batch_sizes[-1] += sample_batch_size - sum(local_batch_sizes) + # Sample from each buffer + buffer_samples = [] + for buffer_idx, buffer_state in enumerate(buffer_state_list): + rng_key, sample_key = jax.random.split(rng_key) + buffer_state = buffer_state_list[buffer_idx] + buffer_sample_fn = buffer_sample_fns[buffer_idx] + sampled_data = buffer_sample_fn(buffer_state, sample_key) + size_to_sample = local_batch_sizes[buffer_idx] + sampled_data = jax.tree_map(lambda x: x[:size_to_sample], sampled_data) + buffer_samples.append(sampled_data) + + # Concatenate the samples + buffer_samples = jax.tree.map( + lambda *x: jnp.concatenate(x, axis=0), *buffer_samples + ) + + return buffer_samples + + +def joint_mixed_add( + buffer_state_list: Sequence[BufferState], + data: Any, + buffer_add_fns: Sequence[Callable[[BufferState, Any], BufferState]], +) -> Sequence[BufferState]: + """ + Add data to a mixed buffer, which is a list of buffer states, each with its own add function. + """ + assert len(buffer_state_list) == len( + buffer_add_fns + ), "Number of buffer states and add functions must match" + for buffer_idx, buffer_state in enumerate(buffer_state_list): + buffer_state = buffer_state_list[buffer_idx] + buffer_add_fn = buffer_add_fns[buffer_idx] + buffer_state = buffer_add_fn(buffer_state, data) + buffer_state_list[buffer_idx] = buffer_state + return buffer_state_list diff --git a/flashbax/buffers/mixed_buffer_test.py b/flashbax/buffers/mixed_buffer_test.py new file mode 100644 index 0000000..e705885 --- /dev/null +++ b/flashbax/buffers/mixed_buffer_test.py @@ -0,0 +1,288 @@ +# Copyright 2023 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy +from functools import partial + +import chex +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +from flashbax.buffers import ( + flat_buffer, + prioritised_trajectory_buffer, + trajectory_buffer, +) +from flashbax.buffers.conftest import get_fake_batch +from flashbax.buffers.mixed_buffer import joint_mixed_add, mixed_sample + + +@pytest.fixture +def rng_key() -> chex.PRNGKey: + return jax.random.PRNGKey(0) + + +@pytest.fixture +def fake_transition() -> chex.ArrayTree: + return { + "obs": jnp.array([0.0, 0.0]), + "reward": jnp.array(0.0), + "done": jnp.array(False), + "next_obs": jnp.array([0.0, 0.0]), + } + + +@pytest.fixture +def max_length() -> int: + return 32 + + +@pytest.fixture +def min_length() -> int: + return 8 + + +@pytest.fixture +def add_batch_size() -> int: + return 4 + + +@pytest.fixture +def sample_batch_size() -> int: + return 100 + + +@pytest.fixture +def sample_sequence_length() -> int: + return 4 + + +@pytest.fixture +def sample_period() -> int: + return 1 + + +def get_fake_batch_sequence( + fake_transition: chex.ArrayTree, batch_size: int, sequence_length: int +) -> chex.ArrayTree: + return get_fake_batch(get_fake_batch(fake_transition, sequence_length), batch_size) + + +def test_mixed_trajectory_sample( + rng_key, + sample_batch_size, + sample_period, + add_batch_size, + sample_sequence_length, + fake_transition, +): + buffer_states = [] + buffer_sample_fns = [] + for i in range(3): + buffer = trajectory_buffer.make_trajectory_buffer( + max_length_time_axis=200 * (i + 1), + min_length_time_axis=0, + sample_batch_size=sample_batch_size, + add_batch_size=add_batch_size, + sample_sequence_length=sample_sequence_length, + period=sample_period, + ) + state = buffer.init( + jax.tree_map(lambda x: jnp.ones_like(x) * i, fake_transition) + ) + fake_add_data = get_fake_batch_sequence(fake_transition, add_batch_size, 50) + fake_add_data = jax.tree_map(lambda x: jnp.ones_like(x) * i, fake_add_data) + state = buffer.add(state, fake_add_data) + buffer_states.append(state) + buffer_sample_fns.append(buffer.sample) + + proportions = [0.2, 0.2, 0.6] + samples = mixed_sample( + buffer_states, rng_key, buffer_sample_fns, proportions, sample_batch_size + ) + assert samples is not None + expected_zeros = int(sample_batch_size * proportions[0]) + expected_ones = int(sample_batch_size * proportions[1]) + expected_twos = int(sample_batch_size * proportions[2]) + chex.assert_tree_shape_prefix(samples, (sample_batch_size,)) + dones = samples.experience["done"] + dones = dones[:, 0] + assert np.sum(dones == 0) == expected_zeros + assert np.sum(dones == 1) == expected_ones + assert np.sum(dones == 2) == expected_twos + + +def test_mixed_prioritised_trajectory_sample( + rng_key, + sample_batch_size, + sample_period, + add_batch_size, + sample_sequence_length, + fake_transition, +): + buffer_states = [] + buffer_sample_fns = [] + for i in range(3): + buffer = prioritised_trajectory_buffer.make_prioritised_trajectory_buffer( + max_length_time_axis=200 * (i + 1), + min_length_time_axis=0, + sample_batch_size=sample_batch_size, + add_batch_size=add_batch_size, + sample_sequence_length=sample_sequence_length, + period=sample_period, + ) + state = buffer.init( + jax.tree_map(lambda x: jnp.ones_like(x) * i, fake_transition) + ) + fake_add_data = get_fake_batch_sequence(fake_transition, add_batch_size, 50) + fake_add_data = jax.tree_map(lambda x: jnp.ones_like(x) * i, fake_add_data) + state = buffer.add(state, fake_add_data) + buffer_states.append(state) + buffer_sample_fns.append(buffer.sample) + + proportions = [0.4, 0.1, 0.5] + samples = mixed_sample( + buffer_states, rng_key, buffer_sample_fns, proportions, sample_batch_size + ) + assert samples is not None + expected_zeros = int(sample_batch_size * proportions[0]) + expected_ones = int(sample_batch_size * proportions[1]) + expected_twos = int(sample_batch_size * proportions[2]) + chex.assert_tree_shape_prefix(samples, (sample_batch_size,)) + dones = samples.experience["done"] + dones = dones[:, 0] + assert np.sum(dones == 0) == expected_zeros + assert np.sum(dones == 1) == expected_ones + assert np.sum(dones == 2) == expected_twos + + +def test_mixed_flat_buffer_sample( + rng_key, sample_batch_size, add_batch_size, fake_transition +): + buffer_states = [] + buffer_sample_fns = [] + for i in range(3): + buffer = flat_buffer.make_flat_buffer( + max_length=200 * (i + 1), + min_length=0, + sample_batch_size=sample_batch_size, + add_batch_size=add_batch_size, + add_sequences=True, + ) + state = buffer.init( + jax.tree_map(lambda x: jnp.ones_like(x) * i, fake_transition) + ) + fake_add_data = get_fake_batch_sequence(fake_transition, add_batch_size, 50) + fake_add_data = jax.tree_map(lambda x: jnp.ones_like(x) * i, fake_add_data) + state = buffer.add(state, fake_add_data) + buffer_states.append(state) + buffer_sample_fns.append(buffer.sample) + + proportions = [0.1, 0.1, 0.8] + samples = mixed_sample( + buffer_states, rng_key, buffer_sample_fns, proportions, sample_batch_size + ) + assert samples is not None + expected_zeros = int(sample_batch_size * proportions[0]) + expected_ones = int(sample_batch_size * proportions[1]) + expected_twos = int(sample_batch_size * proportions[2]) + chex.assert_tree_shape_prefix(samples, (sample_batch_size,)) + dones = samples.experience.first["done"] + assert np.sum(dones == 0) == expected_zeros + assert np.sum(dones == 1) == expected_ones + assert np.sum(dones == 2) == expected_twos + + +def test_joint_mixed_add(rng_key, fake_transition, add_batch_size): + buffer_states = [] + buffer_add_fns = [] + for i in range(3): + buffer = trajectory_buffer.make_trajectory_buffer( + max_length_time_axis=2000 * (i + 1), + min_length_time_axis=0, + sample_batch_size=100, + add_batch_size=add_batch_size, + sample_sequence_length=2, + period=1, + ) + state = buffer.init( + jax.tree_map(lambda x: jnp.ones_like(x) * i, fake_transition) + ) + fake_add_data = get_fake_batch_sequence(fake_transition, add_batch_size, 50) + fake_add_data = jax.tree_map(lambda x: jnp.ones_like(x) * i, fake_add_data) + buffer_states.append(state) + buffer_add_fns.append(buffer.add) + + fake_add_data = get_fake_batch_sequence(fake_transition, add_batch_size, 50) + fake_add_data = jax.tree_map(lambda x: jnp.ones_like(x) * 6, fake_add_data) + new_states = joint_mixed_add(buffer_states, fake_add_data, buffer_add_fns) + assert len(new_states) == len(buffer_states) + for state in new_states: + assert state is not None + assert jnp.sum(state.experience["done"] == 6) == add_batch_size * 50 + + +def test_mixed_buffer_does_not_smoke( + rng_key, + sample_batch_size, + sample_period, + add_batch_size, + sample_sequence_length, + fake_transition, +): + buffer_states = [] + buffer_sample_fns = [] + buffer_add_fns = [] + for i in range(3): + buffer = trajectory_buffer.make_trajectory_buffer( + max_length_time_axis=2000 * (i + 1), + min_length_time_axis=0, + sample_batch_size=sample_batch_size, + add_batch_size=add_batch_size, + sample_sequence_length=sample_sequence_length, + period=sample_period, + ) + state = buffer.init( + jax.tree_map(lambda x: jnp.ones_like(x) * i, fake_transition) + ) + fake_add_data = get_fake_batch_sequence(fake_transition, add_batch_size, 50) + fake_add_data = jax.tree_map(lambda x: jnp.ones_like(x) * i, fake_add_data) + state = buffer.add(state, fake_add_data) + buffer_states.append(state) + buffer_sample_fns.append(buffer.sample) + buffer_add_fns.append(buffer.add) + + proportions = [0.2, 0.2, 0.6] + # we expect to pre-instantiate the function with the buffer_sample_fns, proportions and sample_batch_size + mixed_sample_fn = partial( + mixed_sample, + buffer_sample_fns=buffer_sample_fns, + proportions=proportions, + sample_batch_size=sample_batch_size, + ) + samples = jax.jit(mixed_sample_fn)(buffer_states, rng_key) + assert samples is not None + expected_zeros = int(sample_batch_size * proportions[0]) + expected_ones = int(sample_batch_size * proportions[1]) + expected_twos = int(sample_batch_size * proportions[2]) + chex.assert_tree_shape_prefix(samples, (sample_batch_size,)) + dones = samples.experience["done"] + dones = dones[:, 0] + assert np.sum(dones == 0) == expected_zeros + assert np.sum(dones == 1) == expected_ones + assert np.sum(dones == 2) == expected_twos + + # mixed_joint_add_fn = partial(joint_mixed_add, buffer_add_fns=buffer_add_fns) From 00b1668017c65d30d2925f87c0d1a393fae5d2d5 Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Tue, 16 Jul 2024 15:00:18 +0200 Subject: [PATCH 2/8] feat: more work on mixer util. --- flashbax/__init__.py | 2 + flashbax/buffers/__init__.py | 1 + flashbax/buffers/mixer.py | 106 +++++++++++++++ mixer_new_demo.ipynb | 253 +++++++++++++++++++++++++++++++++++ 4 files changed, 362 insertions(+) create mode 100644 flashbax/buffers/mixer.py create mode 100644 mixer_new_demo.ipynb diff --git a/flashbax/__init__.py b/flashbax/__init__.py index 576f501..d2f2c14 100644 --- a/flashbax/__init__.py +++ b/flashbax/__init__.py @@ -18,11 +18,13 @@ item_buffer, make_flat_buffer, make_item_buffer, + make_mixer, make_prioritised_flat_buffer, make_prioritised_item_buffer, make_prioritised_trajectory_buffer, make_trajectory_buffer, make_trajectory_queue, + mixer, prioritised_flat_buffer, prioritised_item_buffer, prioritised_trajectory_buffer, diff --git a/flashbax/buffers/__init__.py b/flashbax/buffers/__init__.py index 1e0096f..4325a07 100644 --- a/flashbax/buffers/__init__.py +++ b/flashbax/buffers/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from flashbax.buffers.flat_buffer import make_flat_buffer from flashbax.buffers.item_buffer import make_item_buffer +from flashbax.buffers.mixer import make_mixer from flashbax.buffers.prioritised_flat_buffer import make_prioritised_flat_buffer from flashbax.buffers.prioritised_item_buffer import make_prioritised_item_buffer from flashbax.buffers.prioritised_trajectory_buffer import ( diff --git a/flashbax/buffers/mixer.py b/flashbax/buffers/mixer.py new file mode 100644 index 0000000..ef692d4 --- /dev/null +++ b/flashbax/buffers/mixer.py @@ -0,0 +1,106 @@ +# Copyright 2023 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from typing import Callable + +import jax.numpy as jnp +from chex import dataclass +from jax.tree_util import tree_map + +from flashbax.buffers.trajectory_buffer import ( + TrajectoryBufferSample, + TrajectoryBufferState, +) + + +@dataclass(frozen=True) +class Mixer: + sample: Callable + can_sample: Callable + + +def sample_mixer_fn( + states, + key, + prop_batch_sizes, + sample_fns, +): + samples_array = tree_map( + lambda state, sample, key_in: sample(state, key_in), + states, + sample_fns, + [key] * len(sample_fns), # if key.ndim == 1 else key, + is_leaf=lambda leaf: type(leaf) == TrajectoryBufferState, + ) + + def _slicer(sample, batch_slice): + return tree_map(lambda x: x[:batch_slice, ...], sample) + + prop_batch_samples_array = tree_map( + lambda x, p: _slicer(x, p), + samples_array, + prop_batch_sizes, + is_leaf=lambda leaf: type(leaf) == TrajectoryBufferSample, + ) + + joint_sample = tree_map( + lambda *x: jnp.concatenate(x, axis=0), + *prop_batch_samples_array, + ) + return joint_sample + + +def can_sample_mixer_fn( + states, + can_sample_fns, +): + each_can_sample = tree_map( + lambda state, can_sample: can_sample(state), + states, + can_sample_fns, + is_leaf=lambda leaf: type(leaf) == TrajectoryBufferState, + ) + return all(each_can_sample) + + +def make_mixer( + buffers: list, + sample_batch_size: int, + proportions: list, +): + sample_fns = [b.sample for b in buffers] + can_sample_fns = [b.can_sample for b in buffers] + + props_sum = sum(proportions) + props_norm = [p / props_sum for p in proportions] + prop_batch_sizes = [int(p * sample_batch_size) for p in props_norm] + if sum(prop_batch_sizes) != sample_batch_size: + prop_batch_sizes[0] += sample_batch_size - sum(prop_batch_sizes) + + mixer_sample_fn = functools.partial( + sample_mixer_fn, + prop_batch_sizes=prop_batch_sizes, + sample_fns=sample_fns, + ) + + mixer_can_sample_fn = functools.partial( + can_sample_mixer_fn, + can_sample_fns=can_sample_fns, + ) + + return Mixer( + sample=mixer_sample_fn, + can_sample=mixer_can_sample_fn, + ) diff --git a/mixer_new_demo.ipynb b/mixer_new_demo.ipynb new file mode 100644 index 0000000..4510ec2 --- /dev/null +++ b/mixer_new_demo.ipynb @@ -0,0 +1,253 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import flashbax as fbx\n", + "import jax.numpy as jnp\n", + "from jax.tree_util import tree_map\n", + "import jax" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "key = jax.random.PRNGKey(0)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1500\n" + ] + }, + { + "data": { + "text/plain": [ + "TrajectoryBufferState(experience={'acts': (1, 10000, 3), 'obs': (1, 10000, 2)}, current_index=(), is_full=())" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "buffer_a = fbx.make_trajectory_buffer(\n", + " add_batch_size=1,\n", + " max_length_time_axis=10_000,\n", + " min_length_time_axis=5,\n", + " sample_sequence_length=5,\n", + " period=1,\n", + " sample_batch_size=2,\n", + ")\n", + "\n", + "timestep = {\n", + " \"obs\": jnp.ones((2)),\n", + " \"acts\": jnp.ones(3),\n", + "}\n", + "\n", + "state_a = buffer_a.init(\n", + " timestep,\n", + ")\n", + "for i in range(1500):\n", + " state_a = buffer_a.add(\n", + " state_a,\n", + " tree_map(lambda x, _i=i: (x * _i)[None, None, ...], timestep),\n", + " )\n", + "\n", + "print(state_a.current_index)\n", + "tree_map(lambda x: x.shape, state_a)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6000\n" + ] + }, + { + "data": { + "text/plain": [ + "TrajectoryBufferState(experience={'acts': (1, 10000, 3), 'obs': (1, 10000, 2)}, current_index=(), is_full=())" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "buffer_b = fbx.make_trajectory_buffer(\n", + " add_batch_size=1,\n", + " max_length_time_axis=10_000,\n", + " min_length_time_axis=5,\n", + " sample_sequence_length=5,\n", + " period=1,\n", + " sample_batch_size=13,\n", + ")\n", + "\n", + "timestep = {\n", + " \"obs\": jnp.ones((2)),\n", + " \"acts\": jnp.ones(3),\n", + "}\n", + "\n", + "state_b = buffer_b.init(\n", + " timestep,\n", + ")\n", + "for i in range(6000):\n", + " state_b = buffer_b.add(\n", + " state_b,\n", + " tree_map(lambda x, _i=i: (1000 - x * _i)[None, None, ...], timestep),\n", + " )\n", + "\n", + "print(state_b.current_index)\n", + "tree_map(lambda x: x.shape, state_b)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(2, 5, 3)\n", + "(13, 5, 3)\n" + ] + } + ], + "source": [ + "sample_a = buffer_a.sample(state_a, key)\n", + "print(sample_a.experience['acts'].shape)\n", + "\n", + "sample_b = buffer_b.sample(state_b, key)\n", + "print(sample_b.experience['acts'].shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "mixer = fbx.make_mixer(\n", + " buffers=[buffer_a, buffer_b],\n", + " sample_batch_size=8,\n", + " proportions=[2,3]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(6, 5, 3)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "joint_sample = mixer.sample(\n", + " [state_a, state_b],\n", + " key,\n", + ")\n", + "\n", + "joint_sample.experience['acts'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "mixer = fbx.make_mixer(\n", + " buffers=[buffer_a, buffer_b],\n", + " sample_batch_size=8,\n", + " proportions=[0.1, 0.9]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(8, 5, 3)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "joint_sample = mixer.sample(\n", + " [state_a, state_b],\n", + " key,\n", + ")\n", + "\n", + "joint_sample.experience['acts'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "flashbax", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 890f480dc6ddd610166266a5131d301acba16241 Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Wed, 17 Jul 2024 14:43:02 +0200 Subject: [PATCH 3/8] feat: a working mixer, with comments and docs --- flashbax/buffers/mixer.py | 149 ++++++++++++++++++++++++++++++++------ 1 file changed, 128 insertions(+), 21 deletions(-) diff --git a/flashbax/buffers/mixer.py b/flashbax/buffers/mixer.py index ef692d4..af6df86 100644 --- a/flashbax/buffers/mixer.py +++ b/flashbax/buffers/mixer.py @@ -13,80 +13,187 @@ # limitations under the License. import functools -from typing import Callable +from typing import Callable, List, TypeVar +import chex +import jax import jax.numpy as jnp -from chex import dataclass +from chex import Numeric, dataclass from jax.tree_util import tree_map +from flashbax.buffers.flat_buffer import TransitionSample +from flashbax.buffers.prioritised_trajectory_buffer import ( + PrioritisedTrajectoryBuffer, + PrioritisedTrajectoryBufferSample, + PrioritisedTrajectoryBufferState, +) from flashbax.buffers.trajectory_buffer import ( + TrajectoryBuffer, TrajectoryBufferSample, TrajectoryBufferState, ) +# Support for Trajectory, Flat, Item buffers, and prioritised variants +sample_types = [ + TrajectoryBufferSample, + PrioritisedTrajectoryBufferSample, + TransitionSample, +] +SampleTypes = TypeVar( + "SampleTypes", + TrajectoryBufferSample, + PrioritisedTrajectoryBufferSample, + TransitionSample, +) + +state_types = [TrajectoryBufferState, PrioritisedTrajectoryBufferState] +StateTypes = TypeVar( + "StateTypes", TrajectoryBufferState, PrioritisedTrajectoryBufferState +) + +BufferTypes = TypeVar("BufferTypes", TrajectoryBuffer, PrioritisedTrajectoryBuffer) + @dataclass(frozen=True) class Mixer: + """Pure functions defining the mixer. + + Attributes: + sample (Callable): function to sample proportionally from all buffers, + concatenating along the batch axis + can_sample (Callable): function to check if all buffers can sample + """ + sample: Callable can_sample: Callable +def _batch_slicer( + sample: SampleTypes, + batch_start: int, + batch_end: int, +) -> SampleTypes: + """Simple utility function to slice a sample along the batch axis. + + Args: + sample (SampleTypes): incoming sample + batch_start (int): batch start index + batch_end (int): batch end index + + Returns: + SampleTypes: outgoing sliced sample + """ + return tree_map(lambda x: x[batch_start:batch_end, ...], sample) + + def sample_mixer_fn( - states, - key, - prop_batch_sizes, - sample_fns, -): + states: List[StateTypes], + key: chex.PRNGKey, + prop_batch_sizes: List[int], + sample_fns: List[Callable[[StateTypes, chex.PRNGKey], SampleTypes]], +) -> SampleTypes: + """Perform mixed sampling from provided buffer states, according to provided proportions. + + Each buffer sample needs to be of the same pytree structure, and the samples are concatenated + along the first axis i.e. the batch axis. For example, if you are sampling trajectories, then + all samples need to be sequences of the same sequence length but batch sizes can differ. + + Args: + states (List[StateTypes]): list of buffer states + key (chex.PRNGKey): random key + prop_batch_sizes (List[Numeric]): list of batch sizes sampled from each buffer, calculated + according to the proportions of joint sample size + sample_fns (List[Callable[[StateTypes, chex.PRNGKey], SampleTypes]]): list of pure sample + functions from each buffer + + Returns: + SampleTypes: proportionally concatenated samples from all buffers + """ + keys = jax.random.split( + key, len(states) + ) # Split the key for each buffer sampling operation + + # We first sample from each buffer, and get a list of samples samples_array = tree_map( lambda state, sample, key_in: sample(state, key_in), states, sample_fns, - [key] * len(sample_fns), # if key.ndim == 1 else key, - is_leaf=lambda leaf: type(leaf) == TrajectoryBufferState, + list(keys), + is_leaf=lambda leaf: type(leaf) in state_types, ) - def _slicer(sample, batch_slice): - return tree_map(lambda x: x[:batch_slice, ...], sample) - + # We then slice the samples according to the proportions prop_batch_samples_array = tree_map( - lambda x, p: _slicer(x, p), + lambda x, p: _batch_slicer(x, 0, p), samples_array, prop_batch_sizes, - is_leaf=lambda leaf: type(leaf) == TrajectoryBufferSample, + is_leaf=lambda leaf: type(leaf) in sample_types, ) + # Concatenate the samples along the batch axis joint_sample = tree_map( lambda *x: jnp.concatenate(x, axis=0), *prop_batch_samples_array, ) + return joint_sample def can_sample_mixer_fn( - states, - can_sample_fns, -): + states: List[StateTypes], can_sample_fns: List[Callable[[StateTypes], bool]] +) -> bool: + """Check if all buffers can sample. + + Args: + states (List[StateTypes]): list of buffer states + can_sample_fns (List[Callable[[StateTypes], bool]]): list of can_sample functions + from each buffer + + Returns: + bool: whether all buffers can sample + """ each_can_sample = tree_map( lambda state, can_sample: can_sample(state), states, can_sample_fns, - is_leaf=lambda leaf: type(leaf) == TrajectoryBufferState, + is_leaf=lambda leaf: type(leaf) in state_types, ) return all(each_can_sample) def make_mixer( - buffers: list, + buffers: List[BufferTypes], sample_batch_size: int, - proportions: list, -): + proportions: List[Numeric], +) -> Mixer: + """Create the mixer. + + Args: + buffers (List[BufferTypes]): list of buffers (pure functions) + sample_batch_size (int): desired batch size of joint sample + proportions (List[Numeric]): + Proportions of joint sample size to be sampled from each buffer, given as a ratio. + + Returns: + Mixer: a mixer + """ + assert len(buffers) == len( + proportions + ), "Number of buffers and proportions must match" + assert all( + isinstance(b, type(buffers[0])) for b in buffers + ), "All buffers must be of the same type" + assert sample_batch_size > 0, "Sample batch size must be greater than 0" + sample_fns = [b.sample for b in buffers] can_sample_fns = [b.can_sample for b in buffers] + # Normalize proportions and calculate resulting integer batch sizes props_sum = sum(proportions) props_norm = [p / props_sum for p in proportions] prop_batch_sizes = [int(p * sample_batch_size) for p in props_norm] if sum(prop_batch_sizes) != sample_batch_size: + # In case of rounding errors, add the remainder to the first buffer's proportion prop_batch_sizes[0] += sample_batch_size - sum(prop_batch_sizes) mixer_sample_fn = functools.partial( From c6eee8f0791f2a03b7785e2032498eabc9d6e4f4 Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Wed, 17 Jul 2024 15:00:00 +0200 Subject: [PATCH 4/8] feat: tests for mixer. --- flashbax/buffers/mixed_buffer.py | 82 ------------- .../{mixed_buffer_test.py => mixer_test.py} | 114 ++++++++---------- 2 files changed, 51 insertions(+), 145 deletions(-) delete mode 100644 flashbax/buffers/mixed_buffer.py rename flashbax/buffers/{mixed_buffer_test.py => mixer_test.py} (69%) diff --git a/flashbax/buffers/mixed_buffer.py b/flashbax/buffers/mixed_buffer.py deleted file mode 100644 index 5ef2a28..0000000 --- a/flashbax/buffers/mixed_buffer.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright 2023 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, Callable, Sequence - -import chex -import jax -import jax.numpy as jnp - -from flashbax.buffers.trajectory_buffer import BufferState - - -def mixed_sample( - buffer_state_list: Sequence[BufferState], - rng_key: chex.Array, - buffer_sample_fns: Sequence[Callable[[BufferState], Any]], - proportions: Sequence[float], - sample_batch_size: int, -) -> Any: - """ - Sample from a mixed buffer, which is a list of buffer states, each with its own sample function. - - Each buffer sample needs to be of the same pytree structure, and the samples are concatenated along the first axis i.e. the batch axis. - For example, if you are sampling trajectories, then all samples need to be sequences of the same sequence length but batch sizes can differ. - """ - assert len(buffer_state_list) == len( - buffer_sample_fns - ), "Number of buffer states and sample functions must match" - assert len(buffer_state_list) == len( - proportions - ), "Number of buffer states and proportions must match" - assert sum(proportions) == 1.0, "Proportions must sum to 1" - local_batch_sizes = [int(sample_batch_size * p) for p in proportions] - if sum(local_batch_sizes) != sample_batch_size: - local_batch_sizes[-1] += sample_batch_size - sum(local_batch_sizes) - # Sample from each buffer - buffer_samples = [] - for buffer_idx, buffer_state in enumerate(buffer_state_list): - rng_key, sample_key = jax.random.split(rng_key) - buffer_state = buffer_state_list[buffer_idx] - buffer_sample_fn = buffer_sample_fns[buffer_idx] - sampled_data = buffer_sample_fn(buffer_state, sample_key) - size_to_sample = local_batch_sizes[buffer_idx] - sampled_data = jax.tree_map(lambda x: x[:size_to_sample], sampled_data) - buffer_samples.append(sampled_data) - - # Concatenate the samples - buffer_samples = jax.tree.map( - lambda *x: jnp.concatenate(x, axis=0), *buffer_samples - ) - - return buffer_samples - - -def joint_mixed_add( - buffer_state_list: Sequence[BufferState], - data: Any, - buffer_add_fns: Sequence[Callable[[BufferState, Any], BufferState]], -) -> Sequence[BufferState]: - """ - Add data to a mixed buffer, which is a list of buffer states, each with its own add function. - """ - assert len(buffer_state_list) == len( - buffer_add_fns - ), "Number of buffer states and add functions must match" - for buffer_idx, buffer_state in enumerate(buffer_state_list): - buffer_state = buffer_state_list[buffer_idx] - buffer_add_fn = buffer_add_fns[buffer_idx] - buffer_state = buffer_add_fn(buffer_state, data) - buffer_state_list[buffer_idx] = buffer_state - return buffer_state_list diff --git a/flashbax/buffers/mixed_buffer_test.py b/flashbax/buffers/mixer_test.py similarity index 69% rename from flashbax/buffers/mixed_buffer_test.py rename to flashbax/buffers/mixer_test.py index e705885..435a7bf 100644 --- a/flashbax/buffers/mixed_buffer_test.py +++ b/flashbax/buffers/mixer_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from copy import deepcopy -from functools import partial import chex import jax @@ -27,7 +25,7 @@ trajectory_buffer, ) from flashbax.buffers.conftest import get_fake_batch -from flashbax.buffers.mixed_buffer import joint_mixed_add, mixed_sample +from flashbax.buffers.mixer import make_mixer @pytest.fixture @@ -89,8 +87,8 @@ def test_mixed_trajectory_sample( sample_sequence_length, fake_transition, ): + buffers = [] buffer_states = [] - buffer_sample_fns = [] for i in range(3): buffer = trajectory_buffer.make_trajectory_buffer( max_length_time_axis=200 * (i + 1), @@ -100,19 +98,26 @@ def test_mixed_trajectory_sample( sample_sequence_length=sample_sequence_length, period=sample_period, ) + buffers.append(buffer) + state = buffer.init( - jax.tree_map(lambda x: jnp.ones_like(x) * i, fake_transition) + jax.tree_map(lambda x, _i=i: jnp.ones_like(x) * _i, fake_transition) ) fake_add_data = get_fake_batch_sequence(fake_transition, add_batch_size, 50) - fake_add_data = jax.tree_map(lambda x: jnp.ones_like(x) * i, fake_add_data) + fake_add_data = jax.tree_map( + lambda x, _i=i: jnp.ones_like(x) * _i, fake_add_data + ) state = buffer.add(state, fake_add_data) buffer_states.append(state) - buffer_sample_fns.append(buffer.sample) proportions = [0.2, 0.2, 0.6] - samples = mixed_sample( - buffer_states, rng_key, buffer_sample_fns, proportions, sample_batch_size + mixer = make_mixer( + buffers=buffers, + proportions=proportions, + sample_batch_size=sample_batch_size, ) + samples = mixer.sample(buffer_states, rng_key) + assert samples is not None expected_zeros = int(sample_batch_size * proportions[0]) expected_ones = int(sample_batch_size * proportions[1]) @@ -133,8 +138,8 @@ def test_mixed_prioritised_trajectory_sample( sample_sequence_length, fake_transition, ): + buffers = [] buffer_states = [] - buffer_sample_fns = [] for i in range(3): buffer = prioritised_trajectory_buffer.make_prioritised_trajectory_buffer( max_length_time_axis=200 * (i + 1), @@ -144,19 +149,26 @@ def test_mixed_prioritised_trajectory_sample( sample_sequence_length=sample_sequence_length, period=sample_period, ) + buffers.append(buffer) + state = buffer.init( - jax.tree_map(lambda x: jnp.ones_like(x) * i, fake_transition) + jax.tree_map(lambda x, _i=i: jnp.ones_like(x) * _i, fake_transition) ) fake_add_data = get_fake_batch_sequence(fake_transition, add_batch_size, 50) - fake_add_data = jax.tree_map(lambda x: jnp.ones_like(x) * i, fake_add_data) + fake_add_data = jax.tree_map( + lambda x, _i=i: jnp.ones_like(x) * _i, fake_add_data + ) state = buffer.add(state, fake_add_data) buffer_states.append(state) - buffer_sample_fns.append(buffer.sample) proportions = [0.4, 0.1, 0.5] - samples = mixed_sample( - buffer_states, rng_key, buffer_sample_fns, proportions, sample_batch_size + mixer = make_mixer( + buffers=buffers, + proportions=proportions, + sample_batch_size=sample_batch_size, ) + samples = mixer.sample(buffer_states, rng_key) + assert samples is not None expected_zeros = int(sample_batch_size * proportions[0]) expected_ones = int(sample_batch_size * proportions[1]) @@ -172,8 +184,8 @@ def test_mixed_prioritised_trajectory_sample( def test_mixed_flat_buffer_sample( rng_key, sample_batch_size, add_batch_size, fake_transition ): + buffers = [] buffer_states = [] - buffer_sample_fns = [] for i in range(3): buffer = flat_buffer.make_flat_buffer( max_length=200 * (i + 1), @@ -182,19 +194,26 @@ def test_mixed_flat_buffer_sample( add_batch_size=add_batch_size, add_sequences=True, ) + buffers.append(buffer) + state = buffer.init( - jax.tree_map(lambda x: jnp.ones_like(x) * i, fake_transition) + jax.tree_map(lambda x, _i=i: jnp.ones_like(x) * _i, fake_transition) ) fake_add_data = get_fake_batch_sequence(fake_transition, add_batch_size, 50) - fake_add_data = jax.tree_map(lambda x: jnp.ones_like(x) * i, fake_add_data) + fake_add_data = jax.tree_map( + lambda x, _i=i: jnp.ones_like(x) * _i, fake_add_data + ) state = buffer.add(state, fake_add_data) buffer_states.append(state) - buffer_sample_fns.append(buffer.sample) proportions = [0.1, 0.1, 0.8] - samples = mixed_sample( - buffer_states, rng_key, buffer_sample_fns, proportions, sample_batch_size + mixer = make_mixer( + buffers=buffers, + proportions=proportions, + sample_batch_size=sample_batch_size, ) + samples = mixer.sample(buffer_states, rng_key) + assert samples is not None expected_zeros = int(sample_batch_size * proportions[0]) expected_ones = int(sample_batch_size * proportions[1]) @@ -206,35 +225,6 @@ def test_mixed_flat_buffer_sample( assert np.sum(dones == 2) == expected_twos -def test_joint_mixed_add(rng_key, fake_transition, add_batch_size): - buffer_states = [] - buffer_add_fns = [] - for i in range(3): - buffer = trajectory_buffer.make_trajectory_buffer( - max_length_time_axis=2000 * (i + 1), - min_length_time_axis=0, - sample_batch_size=100, - add_batch_size=add_batch_size, - sample_sequence_length=2, - period=1, - ) - state = buffer.init( - jax.tree_map(lambda x: jnp.ones_like(x) * i, fake_transition) - ) - fake_add_data = get_fake_batch_sequence(fake_transition, add_batch_size, 50) - fake_add_data = jax.tree_map(lambda x: jnp.ones_like(x) * i, fake_add_data) - buffer_states.append(state) - buffer_add_fns.append(buffer.add) - - fake_add_data = get_fake_batch_sequence(fake_transition, add_batch_size, 50) - fake_add_data = jax.tree_map(lambda x: jnp.ones_like(x) * 6, fake_add_data) - new_states = joint_mixed_add(buffer_states, fake_add_data, buffer_add_fns) - assert len(new_states) == len(buffer_states) - for state in new_states: - assert state is not None - assert jnp.sum(state.experience["done"] == 6) == add_batch_size * 50 - - def test_mixed_buffer_does_not_smoke( rng_key, sample_batch_size, @@ -243,9 +233,8 @@ def test_mixed_buffer_does_not_smoke( sample_sequence_length, fake_transition, ): + buffers = [] buffer_states = [] - buffer_sample_fns = [] - buffer_add_fns = [] for i in range(3): buffer = trajectory_buffer.make_trajectory_buffer( max_length_time_axis=2000 * (i + 1), @@ -255,25 +244,26 @@ def test_mixed_buffer_does_not_smoke( sample_sequence_length=sample_sequence_length, period=sample_period, ) + buffers.append(buffer) + state = buffer.init( - jax.tree_map(lambda x: jnp.ones_like(x) * i, fake_transition) + jax.tree_map(lambda x, _i=i: jnp.ones_like(x) * _i, fake_transition) ) fake_add_data = get_fake_batch_sequence(fake_transition, add_batch_size, 50) - fake_add_data = jax.tree_map(lambda x: jnp.ones_like(x) * i, fake_add_data) + fake_add_data = jax.tree_map( + lambda x, _i=i: jnp.ones_like(x) * _i, fake_add_data + ) state = buffer.add(state, fake_add_data) buffer_states.append(state) - buffer_sample_fns.append(buffer.sample) - buffer_add_fns.append(buffer.add) proportions = [0.2, 0.2, 0.6] - # we expect to pre-instantiate the function with the buffer_sample_fns, proportions and sample_batch_size - mixed_sample_fn = partial( - mixed_sample, - buffer_sample_fns=buffer_sample_fns, + mixer = make_mixer( + buffers=buffers, proportions=proportions, sample_batch_size=sample_batch_size, ) - samples = jax.jit(mixed_sample_fn)(buffer_states, rng_key) + samples = jax.jit(mixer.sample)(buffer_states, rng_key) + assert samples is not None expected_zeros = int(sample_batch_size * proportions[0]) expected_ones = int(sample_batch_size * proportions[1]) @@ -284,5 +274,3 @@ def test_mixed_buffer_does_not_smoke( assert np.sum(dones == 0) == expected_zeros assert np.sum(dones == 1) == expected_ones assert np.sum(dones == 2) == expected_twos - - # mixed_joint_add_fn = partial(joint_mixed_add, buffer_add_fns=buffer_add_fns) From 7874f384fad1e97250c50c99e9118e1436a33cfb Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Wed, 17 Jul 2024 15:13:24 +0200 Subject: [PATCH 5/8] chore: typing in mixer tests; fix List type in mixer --- flashbax/buffers/mixer.py | 32 +++++++++++++------------- flashbax/buffers/mixer_test.py | 41 ++++++++++++++++++---------------- 2 files changed, 38 insertions(+), 35 deletions(-) diff --git a/flashbax/buffers/mixer.py b/flashbax/buffers/mixer.py index af6df86..bd27d9b 100644 --- a/flashbax/buffers/mixer.py +++ b/flashbax/buffers/mixer.py @@ -13,7 +13,7 @@ # limitations under the License. import functools -from typing import Callable, List, TypeVar +from typing import Callable, Sequence, TypeVar import chex import jax @@ -87,10 +87,10 @@ def _batch_slicer( def sample_mixer_fn( - states: List[StateTypes], + states: Sequence[StateTypes], key: chex.PRNGKey, - prop_batch_sizes: List[int], - sample_fns: List[Callable[[StateTypes, chex.PRNGKey], SampleTypes]], + prop_batch_sizes: Sequence[int], + sample_fns: Sequence[Callable[[StateTypes, chex.PRNGKey], SampleTypes]], ) -> SampleTypes: """Perform mixed sampling from provided buffer states, according to provided proportions. @@ -99,12 +99,12 @@ def sample_mixer_fn( all samples need to be sequences of the same sequence length but batch sizes can differ. Args: - states (List[StateTypes]): list of buffer states + states (Sequence[StateTypes]): list of buffer states key (chex.PRNGKey): random key - prop_batch_sizes (List[Numeric]): list of batch sizes sampled from each buffer, calculated - according to the proportions of joint sample size - sample_fns (List[Callable[[StateTypes, chex.PRNGKey], SampleTypes]]): list of pure sample - functions from each buffer + prop_batch_sizes (Sequence[Numeric]): list of batch sizes sampled from each buffer, + calculated according to the proportions of joint sample size + sample_fns (Sequence[Callable[[StateTypes, chex.PRNGKey], SampleTypes]]): list of pure + sample functions from each buffer Returns: SampleTypes: proportionally concatenated samples from all buffers @@ -140,13 +140,13 @@ def sample_mixer_fn( def can_sample_mixer_fn( - states: List[StateTypes], can_sample_fns: List[Callable[[StateTypes], bool]] + states: Sequence[StateTypes], can_sample_fns: Sequence[Callable[[StateTypes], bool]] ) -> bool: """Check if all buffers can sample. Args: - states (List[StateTypes]): list of buffer states - can_sample_fns (List[Callable[[StateTypes], bool]]): list of can_sample functions + states (Sequence[StateTypes]): list of buffer states + can_sample_fns (Sequence[Callable[[StateTypes], bool]]): list of can_sample functions from each buffer Returns: @@ -162,16 +162,16 @@ def can_sample_mixer_fn( def make_mixer( - buffers: List[BufferTypes], + buffers: Sequence[BufferTypes], sample_batch_size: int, - proportions: List[Numeric], + proportions: Sequence[Numeric], ) -> Mixer: """Create the mixer. Args: - buffers (List[BufferTypes]): list of buffers (pure functions) + buffers (Sequence[BufferTypes]): list of buffers (pure functions) sample_batch_size (int): desired batch size of joint sample - proportions (List[Numeric]): + proportions (Sequence[Numeric]): Proportions of joint sample size to be sampled from each buffer, given as a ratio. Returns: diff --git a/flashbax/buffers/mixer_test.py b/flashbax/buffers/mixer_test.py index 435a7bf..ea037a4 100644 --- a/flashbax/buffers/mixer_test.py +++ b/flashbax/buffers/mixer_test.py @@ -80,12 +80,12 @@ def get_fake_batch_sequence( def test_mixed_trajectory_sample( - rng_key, - sample_batch_size, - sample_period, - add_batch_size, - sample_sequence_length, - fake_transition, + rng_key: chex.PRNGKey, + sample_batch_size: int, + sample_period: int, + add_batch_size: int, + sample_sequence_length: int, + fake_transition: chex.ArrayTree, ): buffers = [] buffer_states = [] @@ -131,12 +131,12 @@ def test_mixed_trajectory_sample( def test_mixed_prioritised_trajectory_sample( - rng_key, - sample_batch_size, - sample_period, - add_batch_size, - sample_sequence_length, - fake_transition, + rng_key: chex.PRNGKey, + sample_batch_size: int, + sample_period: int, + add_batch_size: int, + sample_sequence_length: int, + fake_transition: chex.ArrayTree, ): buffers = [] buffer_states = [] @@ -182,7 +182,10 @@ def test_mixed_prioritised_trajectory_sample( def test_mixed_flat_buffer_sample( - rng_key, sample_batch_size, add_batch_size, fake_transition + rng_key: chex.PRNGKey, + sample_batch_size: int, + add_batch_size: int, + fake_transition: chex.ArrayTree, ): buffers = [] buffer_states = [] @@ -226,12 +229,12 @@ def test_mixed_flat_buffer_sample( def test_mixed_buffer_does_not_smoke( - rng_key, - sample_batch_size, - sample_period, - add_batch_size, - sample_sequence_length, - fake_transition, + rng_key: chex.PRNGKey, + sample_batch_size: int, + sample_period: int, + add_batch_size: int, + sample_sequence_length: int, + fake_transition: chex.ArrayTree, ): buffers = [] buffer_states = [] From fc4269228bf13a4d4ba70fc03eee062edbd0b3d0 Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Wed, 17 Jul 2024 15:44:04 +0200 Subject: [PATCH 6/8] feat: simple demo notebook. --- examples/mixer_demonstration.ipynb | 303 +++++++++++++++++++++++++++++ mixer_new_demo.ipynb | 253 ------------------------ 2 files changed, 303 insertions(+), 253 deletions(-) create mode 100644 examples/mixer_demonstration.ipynb delete mode 100644 mixer_new_demo.ipynb diff --git a/examples/mixer_demonstration.ipynb b/examples/mixer_demonstration.ipynb new file mode 100644 index 0000000..15fe47b --- /dev/null +++ b/examples/mixer_demonstration.ipynb @@ -0,0 +1,303 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import flashbax as fbx\n", + "import jax.numpy as jnp\n", + "from jax.tree_util import tree_map\n", + "import jax" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "key = jax.random.PRNGKey(0)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TrajectoryBufferSample(experience={'acts': (4, 5, 3), 'obs': (4, 5, 2)})" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "buffer_a = fbx.make_trajectory_buffer(\n", + " add_batch_size=1,\n", + " max_length_time_axis=1000,\n", + " min_length_time_axis=5,\n", + " sample_sequence_length=5,\n", + " period=1,\n", + " sample_batch_size=4,\n", + ")\n", + "\n", + "timestep = {\n", + " \"obs\": jnp.ones((2)),\n", + " \"acts\": jnp.ones(3),\n", + "}\n", + "\n", + "state_a = buffer_a.init(\n", + " timestep,\n", + ")\n", + "for i in range(100):\n", + " state_a = jax.jit(buffer_a.add, donate_argnums=0)(\n", + " state_a,\n", + " tree_map(lambda x, _i=i: (x * _i)[None, None, ...], timestep),\n", + " )\n", + "\n", + "sample_a = buffer_a.sample(state_a, key)\n", + "tree_map(lambda x: x.shape, sample_a)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TrajectoryBufferSample(experience={'acts': (16, 5, 3), 'obs': (16, 5, 2)})" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "buffer_b = fbx.make_trajectory_buffer(\n", + " add_batch_size=1,\n", + " max_length_time_axis=1000,\n", + " min_length_time_axis=5,\n", + " sample_sequence_length=5,\n", + " period=1,\n", + " sample_batch_size=16,\n", + ")\n", + "\n", + "timestep = {\n", + " \"obs\": jnp.ones((2)),\n", + " \"acts\": jnp.ones(3),\n", + "}\n", + "\n", + "state_b = buffer_b.init(\n", + " timestep,\n", + ")\n", + "for i in range(100):\n", + " state_b = jax.jit(buffer_b.add, donate_argnums=0)(\n", + " state_b,\n", + " tree_map(lambda x, _i=i: (- x * _i)[None, None, ...], timestep),\n", + " )\n", + "\n", + "sample_b = buffer_b.sample(state_b, key)\n", + "tree_map(lambda x: x.shape, sample_b)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "mixer = fbx.make_mixer(\n", + " buffers=[buffer_a, buffer_b],\n", + " sample_batch_size=8,\n", + " proportions=[1,3],\n", + ")\n", + "mixer_sample = jax.jit(mixer.sample)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TrajectoryBufferSample(experience={'acts': (8, 5, 3), 'obs': (8, 5, 2)})" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "joint_sample = mixer_sample(\n", + " [state_a, state_b],\n", + " key,\n", + ")\n", + "\n", + "tree_map(lambda x: x.shape, joint_sample)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TrajectoryBufferSample(experience={'acts': Array([[[90., 90., 90.],\n", + " [91., 91., 91.],\n", + " [92., 92., 92.],\n", + " [93., 93., 93.],\n", + " [94., 94., 94.]],\n", + "\n", + " [[56., 56., 56.],\n", + " [57., 57., 57.],\n", + " [58., 58., 58.],\n", + " [59., 59., 59.],\n", + " [60., 60., 60.]]], dtype=float32), 'obs': Array([[[90., 90.],\n", + " [91., 91.],\n", + " [92., 92.],\n", + " [93., 93.],\n", + " [94., 94.]],\n", + "\n", + " [[56., 56.],\n", + " [57., 57.],\n", + " [58., 58.],\n", + " [59., 59.],\n", + " [60., 60.]]], dtype=float32)})" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Notice how the first 1/4 * 8 = 2 batches are from buffer_a (POSITIVE VALUES)\n", + "tree_map(lambda x: x[0:2], joint_sample)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TrajectoryBufferSample(experience={'acts': Array([[[-34., -34., -34.],\n", + " [-35., -35., -35.],\n", + " [-36., -36., -36.],\n", + " [-37., -37., -37.],\n", + " [-38., -38., -38.]],\n", + "\n", + " [[-88., -88., -88.],\n", + " [-89., -89., -89.],\n", + " [-90., -90., -90.],\n", + " [-91., -91., -91.],\n", + " [-92., -92., -92.]],\n", + "\n", + " [[-30., -30., -30.],\n", + " [-31., -31., -31.],\n", + " [-32., -32., -32.],\n", + " [-33., -33., -33.],\n", + " [-34., -34., -34.]],\n", + "\n", + " [[-11., -11., -11.],\n", + " [-12., -12., -12.],\n", + " [-13., -13., -13.],\n", + " [-14., -14., -14.],\n", + " [-15., -15., -15.]],\n", + "\n", + " [[-78., -78., -78.],\n", + " [-79., -79., -79.],\n", + " [-80., -80., -80.],\n", + " [-81., -81., -81.],\n", + " [-82., -82., -82.]],\n", + "\n", + " [[-15., -15., -15.],\n", + " [-16., -16., -16.],\n", + " [-17., -17., -17.],\n", + " [-18., -18., -18.],\n", + " [-19., -19., -19.]]], dtype=float32), 'obs': Array([[[-34., -34.],\n", + " [-35., -35.],\n", + " [-36., -36.],\n", + " [-37., -37.],\n", + " [-38., -38.]],\n", + "\n", + " [[-88., -88.],\n", + " [-89., -89.],\n", + " [-90., -90.],\n", + " [-91., -91.],\n", + " [-92., -92.]],\n", + "\n", + " [[-30., -30.],\n", + " [-31., -31.],\n", + " [-32., -32.],\n", + " [-33., -33.],\n", + " [-34., -34.]],\n", + "\n", + " [[-11., -11.],\n", + " [-12., -12.],\n", + " [-13., -13.],\n", + " [-14., -14.],\n", + " [-15., -15.]],\n", + "\n", + " [[-78., -78.],\n", + " [-79., -79.],\n", + " [-80., -80.],\n", + " [-81., -81.],\n", + " [-82., -82.]],\n", + "\n", + " [[-15., -15.],\n", + " [-16., -16.],\n", + " [-17., -17.],\n", + " [-18., -18.],\n", + " [-19., -19.]]], dtype=float32)})" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# and how the second 3/4 * 8 = 6 batches are from buffer_b (NEGATIVE VALUES)\n", + "tree_map(lambda x: x[2:], joint_sample)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "flashbax", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/mixer_new_demo.ipynb b/mixer_new_demo.ipynb deleted file mode 100644 index 4510ec2..0000000 --- a/mixer_new_demo.ipynb +++ /dev/null @@ -1,253 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import flashbax as fbx\n", - "import jax.numpy as jnp\n", - "from jax.tree_util import tree_map\n", - "import jax" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "key = jax.random.PRNGKey(0)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1500\n" - ] - }, - { - "data": { - "text/plain": [ - "TrajectoryBufferState(experience={'acts': (1, 10000, 3), 'obs': (1, 10000, 2)}, current_index=(), is_full=())" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "buffer_a = fbx.make_trajectory_buffer(\n", - " add_batch_size=1,\n", - " max_length_time_axis=10_000,\n", - " min_length_time_axis=5,\n", - " sample_sequence_length=5,\n", - " period=1,\n", - " sample_batch_size=2,\n", - ")\n", - "\n", - "timestep = {\n", - " \"obs\": jnp.ones((2)),\n", - " \"acts\": jnp.ones(3),\n", - "}\n", - "\n", - "state_a = buffer_a.init(\n", - " timestep,\n", - ")\n", - "for i in range(1500):\n", - " state_a = buffer_a.add(\n", - " state_a,\n", - " tree_map(lambda x, _i=i: (x * _i)[None, None, ...], timestep),\n", - " )\n", - "\n", - "print(state_a.current_index)\n", - "tree_map(lambda x: x.shape, state_a)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "6000\n" - ] - }, - { - "data": { - "text/plain": [ - "TrajectoryBufferState(experience={'acts': (1, 10000, 3), 'obs': (1, 10000, 2)}, current_index=(), is_full=())" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "buffer_b = fbx.make_trajectory_buffer(\n", - " add_batch_size=1,\n", - " max_length_time_axis=10_000,\n", - " min_length_time_axis=5,\n", - " sample_sequence_length=5,\n", - " period=1,\n", - " sample_batch_size=13,\n", - ")\n", - "\n", - "timestep = {\n", - " \"obs\": jnp.ones((2)),\n", - " \"acts\": jnp.ones(3),\n", - "}\n", - "\n", - "state_b = buffer_b.init(\n", - " timestep,\n", - ")\n", - "for i in range(6000):\n", - " state_b = buffer_b.add(\n", - " state_b,\n", - " tree_map(lambda x, _i=i: (1000 - x * _i)[None, None, ...], timestep),\n", - " )\n", - "\n", - "print(state_b.current_index)\n", - "tree_map(lambda x: x.shape, state_b)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(2, 5, 3)\n", - "(13, 5, 3)\n" - ] - } - ], - "source": [ - "sample_a = buffer_a.sample(state_a, key)\n", - "print(sample_a.experience['acts'].shape)\n", - "\n", - "sample_b = buffer_b.sample(state_b, key)\n", - "print(sample_b.experience['acts'].shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "mixer = fbx.make_mixer(\n", - " buffers=[buffer_a, buffer_b],\n", - " sample_batch_size=8,\n", - " proportions=[2,3]\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(6, 5, 3)" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "joint_sample = mixer.sample(\n", - " [state_a, state_b],\n", - " key,\n", - ")\n", - "\n", - "joint_sample.experience['acts'].shape" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "mixer = fbx.make_mixer(\n", - " buffers=[buffer_a, buffer_b],\n", - " sample_batch_size=8,\n", - " proportions=[0.1, 0.9]\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(8, 5, 3)" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "joint_sample = mixer.sample(\n", - " [state_a, state_b],\n", - " key,\n", - ")\n", - "\n", - "joint_sample.experience['acts'].shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "flashbax", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.16" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} From d34b81630736e1f4427b0df9d1f95b3a2693360f Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Wed, 17 Jul 2024 15:46:25 +0200 Subject: [PATCH 7/8] feat: add some comments to demo notebook. --- examples/mixer_demonstration.ipynb | 42 ++++++++++++++++-------------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/examples/mixer_demonstration.ipynb b/examples/mixer_demonstration.ipynb index 15fe47b..7b07a40 100644 --- a/examples/mixer_demonstration.ipynb +++ b/examples/mixer_demonstration.ipynb @@ -9,21 +9,14 @@ "import flashbax as fbx\n", "import jax.numpy as jnp\n", "from jax.tree_util import tree_map\n", - "import jax" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ + "import jax\n", + "\n", "key = jax.random.PRNGKey(0)" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -32,12 +25,13 @@ "TrajectoryBufferSample(experience={'acts': (4, 5, 3), 'obs': (4, 5, 2)})" ] }, - "execution_count": 7, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "# Create our first buffer, with a sample batch size of 4\n", "buffer_a = fbx.make_trajectory_buffer(\n", " add_batch_size=1,\n", " max_length_time_axis=1000,\n", @@ -56,6 +50,7 @@ " timestep,\n", ")\n", "for i in range(100):\n", + " # Fill with POSITIVE values\n", " state_a = jax.jit(buffer_a.add, donate_argnums=0)(\n", " state_a,\n", " tree_map(lambda x, _i=i: (x * _i)[None, None, ...], timestep),\n", @@ -67,7 +62,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -76,12 +71,13 @@ "TrajectoryBufferSample(experience={'acts': (16, 5, 3), 'obs': (16, 5, 2)})" ] }, - "execution_count": 8, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "# Create our second buffer, with a sample batch size of 16\n", "buffer_b = fbx.make_trajectory_buffer(\n", " add_batch_size=1,\n", " max_length_time_axis=1000,\n", @@ -100,6 +96,7 @@ " timestep,\n", ")\n", "for i in range(100):\n", + " # Fill with NEGATIVE values\n", " state_b = jax.jit(buffer_b.add, donate_argnums=0)(\n", " state_b,\n", " tree_map(lambda x, _i=i: (- x * _i)[None, None, ...], timestep),\n", @@ -111,21 +108,24 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ + "# Make the mixer, with a ratio of 1:3 from buffer_a:buffer_b\n", "mixer = fbx.make_mixer(\n", " buffers=[buffer_a, buffer_b],\n", " sample_batch_size=8,\n", " proportions=[1,3],\n", ")\n", + "\n", + "# jittable sampling!\n", "mixer_sample = jax.jit(mixer.sample)" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -134,23 +134,25 @@ "TrajectoryBufferSample(experience={'acts': (8, 5, 3), 'obs': (8, 5, 2)})" ] }, - "execution_count": 13, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "# Sample from the mixer, using the usual flashbax API\n", "joint_sample = mixer_sample(\n", " [state_a, state_b],\n", " key,\n", ")\n", "\n", + "# Notice the resulting shape\n", "tree_map(lambda x: x.shape, joint_sample)" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -179,7 +181,7 @@ " [60., 60.]]], dtype=float32)})" ] }, - "execution_count": 11, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -191,7 +193,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -268,7 +270,7 @@ " [-19., -19.]]], dtype=float32)})" ] }, - "execution_count": 12, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } From 989976c272cba666277efc7038c27a54df2880df Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Fri, 20 Sep 2024 15:08:07 +0200 Subject: [PATCH 8/8] fix: issue with can_sample not being jittable --- flashbax/buffers/mixer.py | 22 +++++++++++++--------- flashbax/buffers/mixer_test.py | 4 ++++ pyproject.toml | 2 +- 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/flashbax/buffers/mixer.py b/flashbax/buffers/mixer.py index bd27d9b..95417a8 100644 --- a/flashbax/buffers/mixer.py +++ b/flashbax/buffers/mixer.py @@ -19,6 +19,7 @@ import jax import jax.numpy as jnp from chex import Numeric, dataclass +from jax import Array from jax.tree_util import tree_map from flashbax.buffers.flat_buffer import TransitionSample @@ -140,25 +141,28 @@ def sample_mixer_fn( def can_sample_mixer_fn( - states: Sequence[StateTypes], can_sample_fns: Sequence[Callable[[StateTypes], bool]] -) -> bool: + states: Sequence[StateTypes], + can_sample_fns: Sequence[Callable[[StateTypes], Array]], +) -> Array: """Check if all buffers can sample. Args: states (Sequence[StateTypes]): list of buffer states - can_sample_fns (Sequence[Callable[[StateTypes], bool]]): list of can_sample functions + can_sample_fns (Sequence[Callable[[StateTypes], Array]]): list of can_sample functions from each buffer Returns: bool: whether all buffers can sample """ - each_can_sample = tree_map( - lambda state, can_sample: can_sample(state), - states, - can_sample_fns, - is_leaf=lambda leaf: type(leaf) in state_types, + each_can_sample = jnp.asarray( + tree_map( + lambda state, can_sample: can_sample(state), + states, + can_sample_fns, + is_leaf=lambda leaf: type(leaf) in state_types, + ) ) - return all(each_can_sample) + return jnp.all(each_can_sample) def make_mixer( diff --git a/flashbax/buffers/mixer_test.py b/flashbax/buffers/mixer_test.py index ea037a4..bfcf6f2 100644 --- a/flashbax/buffers/mixer_test.py +++ b/flashbax/buffers/mixer_test.py @@ -265,6 +265,10 @@ def test_mixed_buffer_does_not_smoke( proportions=proportions, sample_batch_size=sample_batch_size, ) + + can_sample = jax.jit(mixer.can_sample)(buffer_states) + assert can_sample + samples = jax.jit(mixer.sample)(buffer_states, rng_key) assert samples is not None diff --git a/pyproject.toml b/pyproject.toml index 1f7f379..daf68a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,7 @@ dev = [ 'mkdocs-mermaid2-plugin==1.1.1', 'mkdocstrings[python]==0.23.0', 'mknotebooks==0.8.0', - 'mypy>=0.982', + 'mypy>=1.8.0', 'pre-commit>=2.20.0', 'pytest>=7.4.2', 'pytest-cov>=4.00',