diff --git a/flashbax/utils.py b/flashbax/utils.py index f8a3428..88ad680 100644 --- a/flashbax/utils.py +++ b/flashbax/utils.py @@ -68,3 +68,27 @@ def wrapper(*args, **kwargs): return func(*args, **kwargs) return wrapper + + +def get_timestep_count(buffer_state: chex.ArrayTree) -> int: + """Utility to compute the total number of timesteps currently in the buffer state. + + Args: + buffer_state (BufferStateTypes): the buffer state to compute the total timesteps for. + + Returns: + int: the total number of timesteps in the buffer state. + """ + # Ensure the buffer state is a valid buffer state. + assert hasattr(buffer_state, "experience") + assert hasattr(buffer_state, "current_index") + assert hasattr(buffer_state, "is_full") + + b_size, t_size_max = get_tree_shape_prefix(buffer_state.experience, 2) + t_size = jax.lax.cond( + buffer_state.is_full, + lambda: t_size_max, + lambda: buffer_state.current_index, + ) + timestep_count: int = b_size * t_size + return timestep_count diff --git a/flashbax/vault/vault.py b/flashbax/vault/vault.py index 921a05b..c494ba4 100644 --- a/flashbax/vault/vault.py +++ b/flashbax/vault/vault.py @@ -101,7 +101,8 @@ def __init__( # noqa: CCR001 """ # Get the base path for the vault and the metadata path vault_str = vault_uid if vault_uid else datetime.now().strftime("%Y%m%d%H%M%S") - self._base_path = os.path.join(os.getcwd(), rel_dir, vault_name, vault_str) + base_path_unnorm = os.path.join(os.getcwd(), rel_dir, vault_name, vault_str) + self._base_path = os.path.normpath(base_path_unnorm) metadata_path = epath.Path(os.path.join(self._base_path, METADATA_FILE)) # Check if the vault exists, otherwise create the necessary dirs and files