Skip to content

Commit

Permalink
feat: more work on mixer util.
Browse files: browse the repository at this point in the history
  • Loading branch information
callumtilbury committed Jul 16, 2024
1 parent 0a84071 commit 00b1668
Show file tree
Hide file tree
Showing 4 changed files with 362 additions and 0 deletions.
2 changes: 2 additions & 0 deletions flashbax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
item_buffer,
make_flat_buffer,
make_item_buffer,
make_mixer,
make_prioritised_flat_buffer,
make_prioritised_item_buffer,
make_prioritised_trajectory_buffer,
make_trajectory_buffer,
make_trajectory_queue,
mixer,
prioritised_flat_buffer,
prioritised_item_buffer,
prioritised_trajectory_buffer,
Expand Down
1 change: 1 addition & 0 deletions flashbax/buffers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from flashbax.buffers.flat_buffer import make_flat_buffer
from flashbax.buffers.item_buffer import make_item_buffer
from flashbax.buffers.mixer import make_mixer
from flashbax.buffers.prioritised_flat_buffer import make_prioritised_flat_buffer
from flashbax.buffers.prioritised_item_buffer import make_prioritised_item_buffer
from flashbax.buffers.prioritised_trajectory_buffer import (
Expand Down
106 changes: 106 additions & 0 deletions flashbax/buffers/mixer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright 2023 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from typing import Callable

import jax.numpy as jnp
from chex import dataclass
from jax.tree_util import tree_map

from flashbax.buffers.trajectory_buffer import (
TrajectoryBufferSample,
TrajectoryBufferState,
)


@dataclass(frozen=True)
class Mixer:
    """Immutable pair of callables that sample from several buffers jointly.

    Instances are built by ``make_mixer``, which binds the per-buffer
    functions and batch sizes so both callables only need the buffer states
    (and, for ``sample``, a PRNG key) at call time.
    """

    # Called as ``sample(states, key)``; returns the concatenated joint batch.
    sample: Callable
    # Called as ``can_sample(states)``; truthy only if every buffer can sample.
    can_sample: Callable


def sample_mixer_fn(
    states,
    key,
    prop_batch_sizes,
    sample_fns,
):
    """Sample every buffer once and concatenate the trimmed results.

    Args:
        states: Sequence of buffer states, aligned with ``sample_fns``.
        key: JAX PRNG key. It is split so that each buffer receives its own
            independent sub-key rather than all buffers reusing one key.
        prop_batch_sizes: Sequence of ints; the number of leading rows kept
            from each buffer's sample.
        sample_fns: Sequence of callables invoked as ``sample(state, key)``.

    Returns:
        A pytree with the structure of a single sample whose leaves are the
        per-buffer slices concatenated along axis 0.

    Raises:
        ValueError: If no sample functions are given, or if ``states``,
            ``prop_batch_sizes`` and ``sample_fns`` differ in length.
    """
    # Local import: the module only binds ``jax.numpy`` and ``tree_map``.
    import jax

    num_buffers = len(sample_fns)
    if num_buffers == 0:
        raise ValueError("At least one buffer is required to sample from.")
    if len(states) != num_buffers or len(prop_batch_sizes) != num_buffers:
        raise ValueError(
            "`states`, `prop_batch_sizes` and `sample_fns` must have the same "
            f"length, got {len(states)}, {len(prop_batch_sizes)} and "
            f"{num_buffers}."
        )

    # One independent sub-key per buffer; reusing a single key would make the
    # random draws of the different buffers correlated.
    sub_keys = list(jax.random.split(key, num_buffers))

    # Pair states/functions positionally instead of relying on an exact-type
    # ``is_leaf`` test, so any buffer state/sample class is supported.
    samples = [
        sample_fn(state, sub_key)
        for state, sample_fn, sub_key in zip(states, sample_fns, sub_keys)
    ]

    # NOTE: slicing can only shrink a sample. If a buffer's own batch is
    # smaller than its requested share, fewer rows are returned and the joint
    # batch ends up smaller than the sum of ``prop_batch_sizes``.
    trimmed = [
        tree_map(lambda leaf, size=size: leaf[:size, ...], sample)
        for sample, size in zip(samples, prop_batch_sizes)
    ]

    return tree_map(
        lambda *parts: jnp.concatenate(parts, axis=0),
        *trimmed,
    )


def can_sample_mixer_fn(
    states,
    can_sample_fns,
):
    """Report whether every buffer is currently able to produce a sample.

    Args:
        states: Sequence of buffer states, aligned with ``can_sample_fns``.
        can_sample_fns: Sequence of callables invoked as ``can_sample(state)``.

    Returns:
        A scalar boolean JAX array that is true only when every per-buffer
        check is true (and true for an empty sequence, like ``all([])``).

    Raises:
        ValueError: If ``states`` and ``can_sample_fns`` differ in length.
    """
    if len(states) != len(can_sample_fns):
        raise ValueError(
            "`states` and `can_sample_fns` must have the same length, got "
            f"{len(states)} and {len(can_sample_fns)}."
        )

    flags = [
        can_sample(state) for state, can_sample in zip(states, can_sample_fns)
    ]

    # Reduce with ``jnp.all`` rather than the builtin ``all``: the builtin
    # converts each flag to a concrete Python bool, which fails when the
    # flags are traced values (e.g. when called inside ``jax.jit``).
    return jnp.all(jnp.asarray(flags, dtype=bool))


def make_mixer(
    buffers: list,
    sample_batch_size: int,
    proportions: list,
):
    """Build a ``Mixer`` that draws a joint batch from several buffers.

    Args:
        buffers: Buffers to mix; each must expose ``sample`` and
            ``can_sample`` attributes.
        sample_batch_size: Total number of rows requested in a joint batch.
        proportions: Non-negative relative weights, one per buffer. They are
            normalised by their sum, so they need not add up to one.

    Returns:
        A ``Mixer`` whose ``sample(states, key)`` and ``can_sample(states)``
        callables have the per-buffer functions and batch sizes bound.

    Raises:
        ValueError: If ``buffers`` is empty, if ``buffers`` and
            ``proportions`` differ in length, if any proportion is negative,
            or if the proportions do not sum to a positive number.
    """
    # Validate first so misuse fails here with a clear message instead of
    # surfacing later as a ZeroDivisionError or a malformed batch.
    if not buffers:
        raise ValueError("`buffers` must contain at least one buffer.")
    if len(buffers) != len(proportions):
        raise ValueError(
            "`buffers` and `proportions` must have the same length, got "
            f"{len(buffers)} and {len(proportions)}."
        )
    if any(p < 0 for p in proportions):
        raise ValueError("`proportions` must all be non-negative.")
    props_sum = sum(proportions)
    if props_sum <= 0:
        raise ValueError("`proportions` must sum to a positive number.")

    sample_fns = [b.sample for b in buffers]
    can_sample_fns = [b.can_sample for b in buffers]

    # Truncate each buffer's share to a whole number of rows ...
    prop_batch_sizes = [
        int((p / props_sum) * sample_batch_size) for p in proportions
    ]
    # ... and give whatever truncation dropped to the first buffer so the
    # sizes add up to ``sample_batch_size`` exactly.
    prop_batch_sizes[0] += sample_batch_size - sum(prop_batch_sizes)

    # NOTE(review): each share is obtained by slicing that buffer's own
    # sample, so every buffer presumably needs a sample batch at least as
    # large as its share - confirm against how the buffers are constructed.
    mixer_sample_fn = functools.partial(
        sample_mixer_fn,
        prop_batch_sizes=prop_batch_sizes,
        sample_fns=sample_fns,
    )

    mixer_can_sample_fn = functools.partial(
        can_sample_mixer_fn,
        can_sample_fns=can_sample_fns,
    )

    return Mixer(
        sample=mixer_sample_fn,
        can_sample=mixer_can_sample_fn,
    )
253 changes: 253 additions & 0 deletions mixer_new_demo.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import flashbax as fbx\n",
"import jax.numpy as jnp\n",
"from jax.tree_util import tree_map\n",
"import jax"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"key = jax.random.PRNGKey(0)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1500\n"
]
},
{
"data": {
"text/plain": [
"TrajectoryBufferState(experience={'acts': (1, 10000, 3), 'obs': (1, 10000, 2)}, current_index=(), is_full=())"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"buffer_a = fbx.make_trajectory_buffer(\n",
" add_batch_size=1,\n",
" max_length_time_axis=10_000,\n",
" min_length_time_axis=5,\n",
" sample_sequence_length=5,\n",
" period=1,\n",
" sample_batch_size=2,\n",
")\n",
"\n",
"timestep = {\n",
" \"obs\": jnp.ones((2)),\n",
" \"acts\": jnp.ones(3),\n",
"}\n",
"\n",
"state_a = buffer_a.init(\n",
" timestep,\n",
")\n",
"for i in range(1500):\n",
" state_a = buffer_a.add(\n",
" state_a,\n",
" tree_map(lambda x, _i=i: (x * _i)[None, None, ...], timestep),\n",
" )\n",
"\n",
"print(state_a.current_index)\n",
"tree_map(lambda x: x.shape, state_a)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"6000\n"
]
},
{
"data": {
"text/plain": [
"TrajectoryBufferState(experience={'acts': (1, 10000, 3), 'obs': (1, 10000, 2)}, current_index=(), is_full=())"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"buffer_b = fbx.make_trajectory_buffer(\n",
" add_batch_size=1,\n",
" max_length_time_axis=10_000,\n",
" min_length_time_axis=5,\n",
" sample_sequence_length=5,\n",
" period=1,\n",
" sample_batch_size=13,\n",
")\n",
"\n",
"timestep = {\n",
" \"obs\": jnp.ones((2)),\n",
" \"acts\": jnp.ones(3),\n",
"}\n",
"\n",
"state_b = buffer_b.init(\n",
" timestep,\n",
")\n",
"for i in range(6000):\n",
" state_b = buffer_b.add(\n",
" state_b,\n",
" tree_map(lambda x, _i=i: (1000 - x * _i)[None, None, ...], timestep),\n",
" )\n",
"\n",
"print(state_b.current_index)\n",
"tree_map(lambda x: x.shape, state_b)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(2, 5, 3)\n",
"(13, 5, 3)\n"
]
}
],
"source": [
"sample_a = buffer_a.sample(state_a, key)\n",
"print(sample_a.experience['acts'].shape)\n",
"\n",
"sample_b = buffer_b.sample(state_b, key)\n",
"print(sample_b.experience['acts'].shape)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"mixer = fbx.make_mixer(\n",
" buffers=[buffer_a, buffer_b],\n",
" sample_batch_size=8,\n",
" proportions=[2,3]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(6, 5, 3)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"joint_sample = mixer.sample(\n",
" [state_a, state_b],\n",
" key,\n",
")\n",
"\n",
"joint_sample.experience['acts'].shape"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"mixer = fbx.make_mixer(\n",
" buffers=[buffer_a, buffer_b],\n",
" sample_batch_size=8,\n",
" proportions=[0.1, 0.9]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(8, 5, 3)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"joint_sample = mixer.sample(\n",
" [state_a, state_b],\n",
" key,\n",
")\n",
"\n",
"joint_sample.experience['acts'].shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "flashbax",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 00b1668

Please sign in to comment.