Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: vault #5

Merged
merged 34 commits into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
f150637
feat: first port of vault into fbx, wip.
callumtilbury Dec 9, 2023
d6ddadf
feat: return write_length.
callumtilbury Dec 11, 2023
bb96ec1
feat: save fbx structure alongside vault for better reloading (wip!!!)
callumtilbury Dec 11, 2023
e9a3e59
chore: unpin typing_extensions from < 4.6.0, as causing issues elsewh…
callumtilbury Dec 11, 2023
1d87ed2
chore: unpin typing_extensions from < 4.6.0, as it's causing issues e…
callumtilbury Dec 11, 2023
64f9d21
chore: precommit
callumtilbury Dec 11, 2023
750f632
Merge branch 'feat/vault' of github.com:instadeepai/flashbax into fea…
callumtilbury Dec 11, 2023
415dbd3
chore: remove bottom level init for Vault. Should be imported as .
callumtilbury Dec 11, 2023
1f7f325
Merge branch 'main' into feat/vault
callumtilbury Dec 12, 2023
7ca7c32
Merge branch 'main' into feat/vault
callumtilbury Jan 12, 2024
6816107
feat: big update and refactor in order to checkpoint namedtuples.
callumtilbury Jan 16, 2024
99fc9c9
docs: big comment update.
callumtilbury Jan 16, 2024
4a7d6ba
chore: precommit.
callumtilbury Jan 16, 2024
ce0724b
chore: minor docs fix.
callumtilbury Jan 16, 2024
9c16bee
chore: bump version to first major (hopefully stable) release.
callumtilbury Jan 16, 2024
aef9e22
feat: first few tests for vault.
callumtilbury Jan 17, 2024
249b296
fix: use temp dirs for all of the tests.
callumtilbury Jan 17, 2024
e01752b
feat: improved test for reloading vault.
callumtilbury Jan 17, 2024
a878bcc
docs: vault explainer in readme, along with demonstrative notebook.
callumtilbury Jan 17, 2024
37b940e
fix: add fbx install to vault example notebook.
callumtilbury Jan 17, 2024
cfc30db
feat: print messages after loading or creating vault.
callumtilbury Jan 17, 2024
d8abdc9
chore: minor
callumtilbury Jan 17, 2024
f85c70d
chore: minor docs fix
callumtilbury Jan 17, 2024
e0a6fdc
chore: minor docs update
callumtilbury Jan 17, 2024
804335a
chore: nits from code review
callumtilbury Jan 24, 2024
a8c9fbb
feat: save structure metadata of shape and dtype in separate trees.
callumtilbury Jan 24, 2024
dc4a88e
minor: move print lower down, in case of code block failure.
callumtilbury Feb 5, 2024
ff4fac3
feat: use different obs dim in example for clarity's sake.
callumtilbury Feb 5, 2024
df94746
chore: minor var rename.
callumtilbury Feb 5, 2024
0629f29
minor: only print after code block success.
callumtilbury Feb 5, 2024
ded832c
docs: add important consideration of vault ring buffer challenge.
callumtilbury Feb 5, 2024
f031f55
chore: minor text
callumtilbury Feb 5, 2024
222cb6d
chore: fix location of summary sentence
callumtilbury Feb 5, 2024
8709c65
feat: add timesteps overwrite warning when creating vault
callumtilbury Feb 5, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,12 @@ from CleanRLs DQN JAX example.
- 🦎 [Jumanji](https://github.com/instadeepai/jumanji/) - utilise Jumanji's JAX based environments
like Snake for our fully jitted examples.

## Vault 💾
Vault is an efficient mechanism for saving Flashbax buffers to persistent data storage, e.g. for use in offline reinforcement learning. Consider a Flashbax buffer which has experience data of dimensionality $(B, T, *E)$, where $B$ is a batch dimension (for the sake of recording independent trajectories synchronously), $T$ is a temporal/sequential dimension, and $*E$ indicates the one or more dimensions of the experience data itself. Since large quantities of data may be generated for a given environment, Vault extends the $T$ dimension to a virtually unconstrained degree by reading and writing slices of buffers along this temporal axis. In doing so, gigantic buffer stores can reside on disk, from which sub-buffers can be loaded into RAM/VRAM for efficient offline training. Vault has been tested with the item, flat, and trajectory buffers.

For more information, see the demonstrative notebook: [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/flashbax/blob/main/examples/vault_demonstration.ipynb)


## Important Considerations ⚠️

When working with Flashbax buffers, it's crucial to be mindful of certain considerations to ensure the proper functionality of your RL agent.
Expand Down Expand Up @@ -188,6 +194,10 @@ train_state, buffer_state = jax.jit(train, donate_argnums=(1,))(

It is important to include `donate_argnums` when calling `jax.jit` to enable JAX to perform an in-place update of the replay buffer state. Omitting `donate_argnums` would force JAX to create a copy of the state for any modifications to the replay buffer state, potentially negating all performance benefits. More information about buffer donation in JAX can be found in the [documentation](https://jax.readthedocs.io/en/latest/faq.html#buffer-donation).


### Storing Data with Vault
As mentioned [above](./README.md#vault-💾), Vault stores experience data to disk by extending the temporal axis of a Flashbax buffer state. By default, Vault conveniently handles the bookkeeping of this process: consuming a buffer state and saving any fresh, previously unseen data. e.g. Suppose we write 10 timesteps to our Flashbax buffer, and then save this state to a Vault; since all of this data is fresh, all of it will be written to disk. However, if we then write one more timestep and save the state to the Vault, only that new timestep will be written, preventing any duplication of data that has already been saved. Importantly, one must remember that Flashbax states are implemented as _ring buffers_, meaning the Vault must be updated sufficiently frequently before unseen data in the Flashbax buffer state is overwritten. i.e. If our buffer state has a time-axis length of $\tau$, then we must save to the vault every $\tau - 1$ steps, lest we overwrite (and lose) unsaved data.

In summary, understanding and addressing these considerations will help you navigate potential pitfalls and ensure the effectiveness of your reinforcement learning strategies while utilising Flashbax buffers.

## Benchmarks 📈
Expand Down Expand Up @@ -242,6 +252,7 @@ Previous benchmarks added only a single timestep at a time, we now evaluate addi

Ultimately, we see improved or comparable performance to benchmarked buffers whilst providing buffers that are fully JAX-compatible in addition to other features such as batched adding as well as being able to add sequences of varying length. We do note that due to JAX having different XLA backends for CPU, GPU, and TPU, the performance of the buffers can vary depending on the device and the specific operation being called.


## Contributing 🤝

Contributions are welcome! See our issue tracker for
Expand Down
354 changes: 354 additions & 0 deletions examples/vault_demonstration.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,354 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Vault demonstration"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"try:\n",
" import flashbax as fbx\n",
"except ModuleNotFoundError:\n",
" print('installing flashbax')\n",
" %pip install -q flashbax\n",
" import flashbax as fbx"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"from typing import NamedTuple\n",
"import jax.numpy as jnp\n",
"from flashbax.vault import Vault\n",
"import flashbax as fbx\n",
"from chex import Array"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We create a simple timestep structure, with a corresponding flat buffer."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/claude/flashbax/flashbax/buffers/trajectory_buffer.py:473: UserWarning: Setting max_size dynamically sets the `max_length_time_axis` to be `max_size`//`add_batch_size = 5`.This allows one to control exactly how many timesteps are stored in the buffer.Note that this overrides the `max_length_time_axis` argument.\n",
" warnings.warn(\n"
]
}
],
"source": [
"class FbxTransition(NamedTuple):\n",
" obs: Array\n",
"\n",
"tx = FbxTransition(obs=jnp.zeros(shape=(2,)))\n",
"\n",
"buffer = fbx.make_flat_buffer(\n",
" max_length=5,\n",
" min_length=1,\n",
" sample_batch_size=1,\n",
")\n",
"buffer_state = buffer.init(tx)\n",
"buffer_add = jax.jit(buffer.add, donate_argnums=0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The shape of this buffer is $(B = 1, T = 5, E = 2)$, meaning the buffer can hold 5 timesteps, where each observation is of shape $(2,)$."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1, 5, 2)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"buffer_state.experience.obs.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We create the vault, based on the buffer's experience structure."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"New vault created at /tmp/demo/20240205140817\n"
]
}
],
"source": [
"v = Vault(\n",
" vault_name=\"demo\",\n",
" experience_structure=buffer_state.experience,\n",
" rel_dir=\"/tmp\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We now add 10 timesteps to the buffer, and write that buffer to the vault. We inspect the buffer and vault state after each timestep."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"------------------\n",
"Buffer state:\n",
"[[[0. 0.]\n",
" [0. 0.]\n",
" [0. 0.]\n",
" [0. 0.]\n",
" [0. 0.]]]\n",
"\n",
"Vault state:\n",
"[]\n",
"------------------\n",
"------------------\n",
"Buffer state:\n",
"[[[1. 1.]\n",
" [0. 0.]\n",
" [0. 0.]\n",
" [0. 0.]\n",
" [0. 0.]]]\n",
"\n",
"Vault state:\n",
"[[[1. 1.]]]\n",
"------------------\n",
"------------------\n",
"Buffer state:\n",
"[[[1. 1.]\n",
" [2. 2.]\n",
" [0. 0.]\n",
" [0. 0.]\n",
" [0. 0.]]]\n",
"\n",
"Vault state:\n",
"[[[1. 1.]\n",
" [2. 2.]]]\n",
"------------------\n",
"------------------\n",
"Buffer state:\n",
"[[[1. 1.]\n",
" [2. 2.]\n",
" [3. 3.]\n",
" [0. 0.]\n",
" [0. 0.]]]\n",
"\n",
"Vault state:\n",
"[[[1. 1.]\n",
" [2. 2.]\n",
" [3. 3.]]]\n",
"------------------\n",
"------------------\n",
"Buffer state:\n",
"[[[1. 1.]\n",
" [2. 2.]\n",
" [3. 3.]\n",
" [4. 4.]\n",
" [0. 0.]]]\n",
"\n",
"Vault state:\n",
"[[[1. 1.]\n",
" [2. 2.]\n",
" [3. 3.]\n",
" [4. 4.]]]\n",
"------------------\n",
"------------------\n",
"Buffer state:\n",
"[[[1. 1.]\n",
" [2. 2.]\n",
" [3. 3.]\n",
" [4. 4.]\n",
" [5. 5.]]]\n",
"\n",
"Vault state:\n",
"[[[1. 1.]\n",
" [2. 2.]\n",
" [3. 3.]\n",
" [4. 4.]\n",
" [5. 5.]]]\n",
"------------------\n",
"------------------\n",
"Buffer state:\n",
"[[[6. 6.]\n",
" [2. 2.]\n",
" [3. 3.]\n",
" [4. 4.]\n",
" [5. 5.]]]\n",
"\n",
"Vault state:\n",
"[[[1. 1.]\n",
" [2. 2.]\n",
" [3. 3.]\n",
" [4. 4.]\n",
" [5. 5.]\n",
" [6. 6.]]]\n",
"------------------\n",
"------------------\n",
"Buffer state:\n",
"[[[6. 6.]\n",
" [7. 7.]\n",
" [3. 3.]\n",
" [4. 4.]\n",
" [5. 5.]]]\n",
"\n",
"Vault state:\n",
"[[[1. 1.]\n",
" [2. 2.]\n",
" [3. 3.]\n",
" [4. 4.]\n",
" [5. 5.]\n",
" [6. 6.]\n",
" [7. 7.]]]\n",
"------------------\n",
"------------------\n",
"Buffer state:\n",
"[[[6. 6.]\n",
" [7. 7.]\n",
" [8. 8.]\n",
" [4. 4.]\n",
" [5. 5.]]]\n",
"\n",
"Vault state:\n",
"[[[1. 1.]\n",
" [2. 2.]\n",
" [3. 3.]\n",
" [4. 4.]\n",
" [5. 5.]\n",
" [6. 6.]\n",
" [7. 7.]\n",
" [8. 8.]]]\n",
"------------------\n"
]
}
],
"source": [
"for i in range(1, 10):\n",
" print('------------------')\n",
" print(\"Buffer state:\")\n",
" print(buffer_state.experience.obs)\n",
" print()\n",
"\n",
" v.write(buffer_state)\n",
"\n",
" print(\"Vault state:\")\n",
" print(v.read().experience.obs)\n",
" print('------------------')\n",
"\n",
" buffer_state = buffer_add(\n",
" buffer_state,\n",
" FbxTransition(obs=i * jnp.ones(1))\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Notice that when the buffer (implemented as a ring buffer) wraps around, the vault continues storing the data:\n",
"```\n",
"Buffer state:\n",
"[[[6. 6.]\n",
" [2. 2.]\n",
" [3. 3.]\n",
" [4. 4.]\n",
" [5. 5.]]]\n",
"\n",
"Vault state:\n",
"[[[1. 1.]\n",
" [2. 2.]\n",
" [3. 3.]\n",
" [4. 4.]\n",
" [5. 5.]\n",
" [6. 6.]]]\n",
"```\n",
"\n",
"Note: the vault must be given the buffer state at least every `max_steps` number of timesteps (i.e. before stale data is overwritten in the ring buffer)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "flashbax",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading
Loading