diff --git a/flashbax/__init__.py b/flashbax/__init__.py index 576f501..d2f2c14 100644 --- a/flashbax/__init__.py +++ b/flashbax/__init__.py @@ -18,11 +18,13 @@ item_buffer, make_flat_buffer, make_item_buffer, + make_mixer, make_prioritised_flat_buffer, make_prioritised_item_buffer, make_prioritised_trajectory_buffer, make_trajectory_buffer, make_trajectory_queue, + mixer, prioritised_flat_buffer, prioritised_item_buffer, prioritised_trajectory_buffer, diff --git a/flashbax/buffers/__init__.py b/flashbax/buffers/__init__.py index 1e0096f..4325a07 100644 --- a/flashbax/buffers/__init__.py +++ b/flashbax/buffers/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from flashbax.buffers.flat_buffer import make_flat_buffer from flashbax.buffers.item_buffer import make_item_buffer +from flashbax.buffers.mixer import make_mixer from flashbax.buffers.prioritised_flat_buffer import make_prioritised_flat_buffer from flashbax.buffers.prioritised_item_buffer import make_prioritised_item_buffer from flashbax.buffers.prioritised_trajectory_buffer import ( diff --git a/flashbax/buffers/mixer.py b/flashbax/buffers/mixer.py new file mode 100644 index 0000000..ef692d4 --- /dev/null +++ b/flashbax/buffers/mixer.py @@ -0,0 +1,106 @@ +# Copyright 2023 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import functools
from typing import Callable, Sequence

import jax
import jax.numpy as jnp
from chex import dataclass
from jax.tree_util import tree_map

from flashbax.buffers.trajectory_buffer import (
    TrajectoryBufferSample,
    TrajectoryBufferState,
)


@dataclass(frozen=True)
class Mixer:
    """Pair of pure functions that treat several buffers as one sampling source.

    Attributes:
        sample: Called as ``sample(states, key)``; returns a single batch made of
            rows taken from every underlying buffer.
        can_sample: Called as ``can_sample(states)``; true only when every
            underlying buffer reports that it can be sampled.
    """

    sample: Callable
    can_sample: Callable


def sample_mixer_fn(
    states: Sequence[TrajectoryBufferState],
    key,
    prop_batch_sizes: Sequence[int],
    sample_fns: Sequence[Callable],
) -> TrajectoryBufferSample:
    """Sample every buffer and concatenate a fixed number of rows from each.

    Args:
        states: One buffer state per sampling function, in the same order as
            ``sample_fns``.
        key: PRNG key. It is split so that each buffer draws with its own
            sub-key instead of all buffers reusing the same randomness.
        prop_batch_sizes: Number of leading rows to keep from each buffer's
            sample, in the same order as ``sample_fns``.
        sample_fns: One ``sample(state, key)`` callable per buffer.

    Returns:
        A sample with the same tree structure as the individual samples, whose
        leaves are the per-buffer leaves concatenated along axis 0.

    Raises:
        ValueError: If no sampling function is given or the three sequences do
            not have the same length.

    Note:
        Rows are taken from the front of each buffer's own batch. A buffer that
        returns fewer rows than its entry in ``prop_batch_sizes`` contributes
        only the rows it has, so the result is shorter than the requested total.
    """
    if not sample_fns:
        raise ValueError("sample_mixer_fn needs at least one sampling function.")
    if not len(states) == len(sample_fns) == len(prop_batch_sizes):
        raise ValueError(
            "states, sample_fns and prop_batch_sizes must have the same length, "
            f"got {len(states)}, {len(sample_fns)} and {len(prop_batch_sizes)}."
        )

    # One independent sub-key per buffer: reusing a single key would make the
    # random draws of the different buffers correlated.
    sub_keys = jax.random.split(key, len(sample_fns))

    per_buffer_samples = [
        sample(state, sub_key)
        for state, sample, sub_key in zip(states, sample_fns, sub_keys)
    ]

    # Keep only the share of rows allotted to each buffer.
    trimmed_samples = [
        tree_map(lambda leaf, rows=rows: leaf[:rows, ...], sample)
        for sample, rows in zip(per_buffer_samples, prop_batch_sizes)
    ]

    # Stitch the per-buffer pieces together leaf by leaf along the batch axis.
    return tree_map(
        lambda *leaves: jnp.concatenate(leaves, axis=0),
        *trimmed_samples,
    )


def can_sample_mixer_fn(
    states: Sequence[TrajectoryBufferState],
    can_sample_fns: Sequence[Callable],
):
    """Report whether every buffer can currently be sampled.

    Args:
        states: One buffer state per predicate, in the same order as
            ``can_sample_fns``.
        can_sample_fns: One ``can_sample(state)`` callable per buffer.

    Returns:
        A scalar boolean array that is true only if all predicates are true.
        The reduction uses ``jnp.all`` rather than the built-in ``all`` so the
        function also works when the per-buffer flags are traced values.

    Raises:
        ValueError: If ``states`` and ``can_sample_fns`` differ in length.
    """
    if len(states) != len(can_sample_fns):
        raise ValueError(
            "states and can_sample_fns must have the same length, "
            f"got {len(states)} and {len(can_sample_fns)}."
        )

    flags = [
        can_sample(state) for state, can_sample in zip(states, can_sample_fns)
    ]
    return jnp.all(jnp.asarray(flags))


def make_mixer(
    buffers: Sequence,
    sample_batch_size: int,
    proportions: Sequence[float],
) -> Mixer:
    """Build a ``Mixer`` that draws a fixed share of each batch from every buffer.

    Args:
        buffers: Buffers to mix. Only their ``sample`` and ``can_sample``
            attributes are used.
        sample_batch_size: Total number of rows requested in the joint batch.
        proportions: Relative weight of each buffer, in the same order as
            ``buffers``. Weights are normalised by their sum, e.g. ``[2, 3]``
            requests 40% of the rows from the first buffer and 60% from the
            second.

    Returns:
        A ``Mixer`` whose ``sample`` takes ``(states, key)`` and whose
        ``can_sample`` takes ``(states)``, where ``states`` holds one state per
        buffer in the same order as ``buffers``.

    Raises:
        ValueError: If ``buffers`` is empty, ``buffers`` and ``proportions``
            differ in length, ``sample_batch_size`` is not positive, any
            proportion is negative, or the proportions do not sum to a positive
            number.

    Note:
        Each buffer's share is cut from the front of the batch returned by that
        buffer's own ``sample``. A buffer that returns fewer rows than its share
        contributes fewer rows, and the joint batch is then smaller than
        ``sample_batch_size``.
    """
    if not buffers:
        raise ValueError("make_mixer needs at least one buffer.")
    if len(buffers) != len(proportions):
        raise ValueError(
            "buffers and proportions must have the same length, "
            f"got {len(buffers)} and {len(proportions)}."
        )
    if sample_batch_size <= 0:
        raise ValueError(
            f"sample_batch_size must be positive, got {sample_batch_size}."
        )
    if any(p < 0 for p in proportions):
        raise ValueError(f"proportions must be non-negative, got {proportions}.")

    total_weight = sum(proportions)
    if total_weight <= 0:
        raise ValueError(
            f"proportions must sum to a positive number, got {proportions}."
        )

    # Truncating each share can leave the total short of the requested size;
    # the missing rows are assigned to the first buffer.
    prop_batch_sizes = [
        int(p / total_weight * sample_batch_size) for p in proportions
    ]
    shortfall = sample_batch_size - sum(prop_batch_sizes)
    if shortfall:
        prop_batch_sizes[0] += shortfall

    mixer_sample_fn = functools.partial(
        sample_mixer_fn,
        prop_batch_sizes=prop_batch_sizes,
        sample_fns=[buffer.sample for buffer in buffers],
    )
    mixer_can_sample_fn = functools.partial(
        can_sample_mixer_fn,
        can_sample_fns=[buffer.can_sample for buffer in buffers],
    )

    return Mixer(
        sample=mixer_sample_fn,
        can_sample=mixer_can_sample_fn,
    )
sample=mixer_sample_fn, + can_sample=mixer_can_sample_fn, + ) diff --git a/mixer_new_demo.ipynb b/mixer_new_demo.ipynb new file mode 100644 index 0000000..4510ec2 --- /dev/null +++ b/mixer_new_demo.ipynb @@ -0,0 +1,253 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import flashbax as fbx\n", + "import jax.numpy as jnp\n", + "from jax.tree_util import tree_map\n", + "import jax" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "key = jax.random.PRNGKey(0)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1500\n" + ] + }, + { + "data": { + "text/plain": [ + "TrajectoryBufferState(experience={'acts': (1, 10000, 3), 'obs': (1, 10000, 2)}, current_index=(), is_full=())" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "buffer_a = fbx.make_trajectory_buffer(\n", + " add_batch_size=1,\n", + " max_length_time_axis=10_000,\n", + " min_length_time_axis=5,\n", + " sample_sequence_length=5,\n", + " period=1,\n", + " sample_batch_size=2,\n", + ")\n", + "\n", + "timestep = {\n", + " \"obs\": jnp.ones((2)),\n", + " \"acts\": jnp.ones(3),\n", + "}\n", + "\n", + "state_a = buffer_a.init(\n", + " timestep,\n", + ")\n", + "for i in range(1500):\n", + " state_a = buffer_a.add(\n", + " state_a,\n", + " tree_map(lambda x, _i=i: (x * _i)[None, None, ...], timestep),\n", + " )\n", + "\n", + "print(state_a.current_index)\n", + "tree_map(lambda x: x.shape, state_a)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6000\n" + ] + }, + { + "data": { + "text/plain": [ + "TrajectoryBufferState(experience={'acts': (1, 10000, 3), 'obs': (1, 10000, 2)}, current_index=(), is_full=())" + ] 
+ }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "buffer_b = fbx.make_trajectory_buffer(\n", + " add_batch_size=1,\n", + " max_length_time_axis=10_000,\n", + " min_length_time_axis=5,\n", + " sample_sequence_length=5,\n", + " period=1,\n", + " sample_batch_size=13,\n", + ")\n", + "\n", + "timestep = {\n", + " \"obs\": jnp.ones((2)),\n", + " \"acts\": jnp.ones(3),\n", + "}\n", + "\n", + "state_b = buffer_b.init(\n", + " timestep,\n", + ")\n", + "for i in range(6000):\n", + " state_b = buffer_b.add(\n", + " state_b,\n", + " tree_map(lambda x, _i=i: (1000 - x * _i)[None, None, ...], timestep),\n", + " )\n", + "\n", + "print(state_b.current_index)\n", + "tree_map(lambda x: x.shape, state_b)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(2, 5, 3)\n", + "(13, 5, 3)\n" + ] + } + ], + "source": [ + "sample_a = buffer_a.sample(state_a, key)\n", + "print(sample_a.experience['acts'].shape)\n", + "\n", + "sample_b = buffer_b.sample(state_b, key)\n", + "print(sample_b.experience['acts'].shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "mixer = fbx.make_mixer(\n", + " buffers=[buffer_a, buffer_b],\n", + " sample_batch_size=8,\n", + " proportions=[2,3]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(6, 5, 3)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "joint_sample = mixer.sample(\n", + " [state_a, state_b],\n", + " key,\n", + ")\n", + "\n", + "joint_sample.experience['acts'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "mixer = fbx.make_mixer(\n", + " buffers=[buffer_a, buffer_b],\n", + " sample_batch_size=8,\n", 
+ " proportions=[0.1, 0.9]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(8, 5, 3)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "joint_sample = mixer.sample(\n", + " [state_a, state_b],\n", + " key,\n", + ")\n", + "\n", + "joint_sample.experience['acts'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "flashbax", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}