Skip to content

Commit

Permalink
Merge branch 'main' into feat/prioritised_item_buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
callumtilbury authored Jul 4, 2024
2 parents c80f9da + c72a39b commit d583210
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 22 deletions.
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,8 @@ our [contributing guidelines](https://github.com/instadeepai/flashbax/blob/main/
details on how to submit pull requests, our Contributor License Agreement, and community guidelines.

## See Also 📚

### Other Buffers
Check out some of the other buffer libraries from the community that we have highlighted in our
benchmarks.

Expand All @@ -271,6 +273,11 @@ benchmarks.
- 🍰 [Dopamine](https://github.com/google/dopamine/blob/master/dopamine/replay_memory/): research framework for fast prototyping, providing several core replay buffers.
- 🤖 [StableBaselines3](https://stable-baselines3.readthedocs.io/en/master/): suite of reliable RL baselines with its own, easy-to-use replay buffers.

### Example Usage
Check out some libraries from the community that utilise flashbax:
- 🦁 [Mava](https://github.com/instadeepai/Mava): end-to-end JAX implementations of multi-agent algorithms utilising flashbax.
- 🏛️ [Stoix](https://github.com/EdanToledo/Stoix): end-to-end JAX implementations of single-agent algorithms utilising flashbax.

## Citing Flashbax ✏️

If you use Flashbax in your work, please cite the library using:
Expand Down
2 changes: 1 addition & 1 deletion examples/anakin_dqn_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@
"):\n",
" \"\"\"Sets up the experiment.\"\"\"\n",
" cores_count = len(jax.devices()) # get available TPU cores.\n",
" network = get_network_fn(env.action_spec().num_values) # define network.\n",
" network = get_network_fn(env.action_spec.num_values) # define network.\n",
" optim = optax.adam(step_size) # define optimiser.\n",
"\n",
" rng, rng_e, rng_p = random.split(random.PRNGKey(seed), num=3) # prng keys.\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/anakin_ppo_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@
" \"\"\"Sets up the experiment and returns the necessary information.\"\"\"\n",
"\n",
" cores_count = len(jax.devices()) # get available TPU cores.\n",
" network = get_network_fn(env.action_spec().num_values) # define network.\n",
" network = get_network_fn(env.action_spec.num_values) # define network.\n",
" optim = optax.adam(step_size) # define optimiser.\n",
"\n",
" rng, rng_e, rng_p = random.split(random.PRNGKey(seed), num=3) # prng keys.\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/anakin_prioritised_dqn_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@
"):\n",
" \"\"\"Sets up the experiment.\"\"\"\n",
" cores_count = len(jax.devices()) # get available TPU cores.\n",
" network = get_network_fn(env.action_spec().num_values) # define network.\n",
" network = get_network_fn(env.action_spec.num_values) # define network.\n",
" optim = optax.adam(step_size) # define optimiser.\n",
"\n",
" rng, rng_e, rng_p = random.split(random.PRNGKey(seed), num=3) # prng keys.\n",
Expand Down
36 changes: 18 additions & 18 deletions flashbax/vault/vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
"id": "gzip",
"level": 5,
}
VERSION = 1.1
VERSION = 1.2


def _path_to_ds_name(path: Tuple[Union[DictKey, GetAttrKey], ...]) -> str:
Expand Down Expand Up @@ -87,8 +87,8 @@ def __init__( # noqa: CCR001
vault_uid (Optional[str], optional): Unique identifier for this vault.
Defaults to None, which will use the current timestamp.
compression (Optional[dict], optional):
Compression settings for the vault. Defaults to None, which will use
the default settings.
Compression settings used when creating the vault.
Defaults to None, which will use the default compression.
metadata (Optional[dict], optional):
Any additional metadata to save. Defaults to None.
Expand All @@ -115,6 +115,11 @@ def __init__( # noqa: CCR001

print(f"Loading vault found at {self._base_path}")

if compression is not None:
print(
"Requested compression settings will be ignored as the vault already exists."
)

elif experience_structure is not None:
# Create the necessary dirs for the vault
os.makedirs(self._base_path)
Expand Down Expand Up @@ -145,7 +150,6 @@ def __init__( # noqa: CCR001
"version": VERSION,
"structure_shape": serialised_experience_structure_shape,
"structure_dtype": serialised_experience_structure_dtype,
"compression": compression or COMPRESSION_DEFAULT,
**(metadata_json_ready or {}), # Allow user to save extra metadata
}
# Dump metadata to file
Expand Down Expand Up @@ -184,12 +188,8 @@ def __init__( # noqa: CCR001
target=experience_structure,
)

# Load compression settings from metadata
self._compression = (
self._metadata["compression"]
if "compression" in self._metadata
else COMPRESSION_DEFAULT
)
# Keep the compression settings, to be used in init_leaf, in case we're creating the vault
self._compression = compression

# Each leaf of the fbx_state.experience maps to a data store, so we tree map over the
# tree structure to create each of the data stores.
Expand Down Expand Up @@ -235,11 +235,6 @@ def _get_base_spec(self, name: str) -> dict:
"base": f"{DRIVER}{self._base_path}",
"path": name,
},
"metadata": {
"compressor": {
**self._compression,
}
},
}

def _init_leaf(
Expand All @@ -260,14 +255,19 @@ def _init_leaf(

leaf_shape, leaf_dtype = None, None
if create_ds:
# Only specify dtype and shape if we are creating a vault
# (i.e. don't impose dtype and shape if we are _loading_ a vault)
# Only specify dtype, shape, and compression if we are creating a vault
# (i.e. don't impose these fields if we are _loading_ a vault)
leaf_shape = (
shape[0], # Batch dim
TIME_AXIS_MAX_LENGTH, # Time dim, which we extend
*shape[2:], # Experience dim(s)
)
leaf_dtype = dtype
spec["metadata"] = {
"compressor": COMPRESSION_DEFAULT
if self._compression is None
else self._compression
}

leaf_ds = ts.open(
spec,
Expand Down Expand Up @@ -478,6 +478,6 @@ def read(
# Return the read result as a fbx buffer state
return TrajectoryBufferState(
experience=read_result,
current_index=jnp.array(self.vault_index, dtype=int),
current_index=jnp.array(0, dtype=int),
is_full=jnp.array(True, dtype=bool),
)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ authors = [
{name="InstaDeep" , email = "[email protected]"},
]
requires-python = ">=3.9"
version = "0.1.0"
version = "0.1.2"
classifiers=[
"Development Status :: 2 - Pre-Alpha",
"Environment :: Console",
Expand Down

0 comments on commit d583210

Please sign in to comment.