From 5dc7daa18828bcf574555ea8a1ef295d4ea68eaa Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Mon, 19 Feb 2024 18:45:59 +0200 Subject: [PATCH] feat: vault compression --- flashbax/vault/vault.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/flashbax/vault/vault.py b/flashbax/vault/vault.py index 08e0353..09cb4a7 100644 --- a/flashbax/vault/vault.py +++ b/flashbax/vault/vault.py @@ -34,7 +34,11 @@ DRIVER = "file://" METADATA_FILE = "metadata.json" TIME_AXIS_MAX_LENGTH = int(10e12) # Upper bound on the length of the time axis -VERSION = 1.0 +COMPRESSION_DEFAULT = { + "id": "gzip", + "level": 5, +} +VERSION = 1.1 def _path_to_ds_name(path: Tuple[Union[DictKey, GetAttrKey], ...]) -> str: @@ -61,12 +65,13 @@ def _path_to_ds_name(path: Tuple[Union[DictKey, GetAttrKey], ...]) -> str: class Vault: - def __init__( + def __init__( # noqa: CCR001 self, vault_name: str, experience_structure: Optional[Experience] = None, rel_dir: str = "vaults", vault_uid: Optional[str] = None, + compression: Optional[dict] = None, metadata: Optional[dict] = None, ) -> None: """Flashbax utility for storing buffers to disk efficiently. @@ -81,6 +86,9 @@ def __init__( Base directory of all vaults. Defaults to "vaults". vault_uid (Optional[str], optional): Unique identifier for this vault. Defaults to None, which will use the current timestamp. + compression (Optional[dict], optional): + Compression settings for the vault. Defaults to None, which will use + the default settings. metadata (Optional[dict], optional): Any additional metadata to save. Defaults to None. @@ -137,6 +145,7 @@ def __init__( "version": VERSION, "structure_shape": serialised_experience_structure_shape, "structure_dtype": serialised_experience_structure_dtype, + "compression": compression or COMPRESSION_DEFAULT, **(metadata_json_ready or {}), # Allow user to save extra metadata } # Dump metadata to file @@ -175,6 +184,13 @@ def __init__( target=experience_structure, ) + # Load compression settings from metadata + self._compression = ( + self._metadata["compression"] + if "compression" in self._metadata + else COMPRESSION_DEFAULT + ) + # Each leaf of the fbx_state.experience maps to a data store, so we tree map over the # tree structure to create each of the data stores. self._all_datastores = jax.tree_util.tree_map_with_path( @@ -219,6 +235,11 @@ def _get_base_spec(self, name: str) -> dict: "base": f"{DRIVER}{self._base_path}", "path": name, }, + "metadata": { + "compressor": { + **self._compression, + } + }, } def _init_leaf(