Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache pre-existing Zarr arrays in Zarr backend #9861

Merged
merged 48 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
9503c71
cache array keys on ZarrStore instances
d-v-b Dec 4, 2024
94b48ac
refactor get_array_keys
d-v-b Dec 6, 2024
2dacb2a
add test for get_array_keys
d-v-b Dec 6, 2024
c7cfacf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2024
b3cf2e5
lint
d-v-b Dec 6, 2024
d9bd7fd
Merge branch 'perf/cache-array-keys' of https://github.com/d-v-b/xarr…
d-v-b Dec 6, 2024
8d47dd3
Merge branch 'main' of https://www.github.com/pydata/xarray into perf…
d-v-b Dec 6, 2024
bec8b42
add type annotation for _cache
d-v-b Dec 6, 2024
61fe759
make type hint accurate
d-v-b Dec 6, 2024
0952c93
make type hint pass mypy
d-v-b Dec 6, 2024
96ddcf4
make type hint pass mypy, correctly
d-v-b Dec 6, 2024
60761a8
Update xarray/backends/zarr.py
d-v-b Dec 9, 2024
f82659d
Merge branch 'main' of https://github.com/pydata/xarray into perf/cac…
d-v-b Dec 10, 2024
107c4e7
use roundtrip method instead of explicit alternative
d-v-b Dec 10, 2024
2d3c13a
cache members instead of just array keys
d-v-b Dec 10, 2024
be5e389
adjust test to members caching
d-v-b Dec 10, 2024
1eaf3ea
Merge branch 'perf/cache-array-keys' of https://github.com/d-v-b/xarr…
d-v-b Dec 10, 2024
9d147b9
refactor members cache
d-v-b Dec 10, 2024
bdb20a6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 10, 2024
6578caa
explictly revert changes to test_write_store
d-v-b Dec 10, 2024
9596bcf
Merge branch 'perf/cache-array-keys' of https://github.com/d-v-b/xarr…
d-v-b Dec 10, 2024
df4274f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 10, 2024
5bf7aff
indent correctly
d-v-b Dec 10, 2024
df51855
Merge branch 'perf/cache-array-keys' of https://github.com/d-v-b/xarr…
d-v-b Dec 10, 2024
5a818a5
update instrumented tests, remove cache update functionality, and set…
d-v-b Dec 10, 2024
c18f81b
update instrumented tests, remove cache update functionality, and set…
d-v-b Dec 10, 2024
7b13099
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 10, 2024
8925020
lint
d-v-b Dec 10, 2024
5efbe30
Merge branch 'perf/cache-array-keys' of https://github.com/d-v-b/xarr…
d-v-b Dec 10, 2024
192fc8b
update tests
d-v-b Dec 10, 2024
2ce5b2e
make check_requests more permissive
d-v-b Dec 10, 2024
87a40c7
update instrumented tests for zarr==2.18
d-v-b Dec 10, 2024
7e79b79
Merge branch 'main' into perf/cache-array-keys
dcherian Dec 11, 2024
d987529
Update xarray/backends/zarr.py
d-v-b Dec 14, 2024
169b736
Update xarray/backends/zarr.py
d-v-b Dec 14, 2024
e20d11d
Update xarray/backends/zarr.py
d-v-b Dec 14, 2024
a1c25f2
Update xarray/backends/zarr.py
d-v-b Dec 14, 2024
f8d78cb
Update xarray/backends/zarr.py
d-v-b Dec 16, 2024
650937e
doc: add whats new content
d-v-b Dec 17, 2024
fcf2d64
fixup
d-v-b Dec 17, 2024
bf64fd2
Merge branch 'main' of https://www.github.com/pydata/xarray into perf…
d-v-b Dec 17, 2024
dcecf43
update whats new
d-v-b Dec 17, 2024
d0bb5a2
fixup
d-v-b Dec 17, 2024
65cadcb
remove vestigial cache_array_keys
d-v-b Dec 17, 2024
e8cd3c8
make cache_members default to True
d-v-b Dec 17, 2024
8d31f0e
tweak
dcherian Dec 17, 2024
3d15d3d
Remove from public API
dcherian Dec 17, 2024
f1ef905
Merge branch 'main' into perf/cache-array-keys
dcherian Dec 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,8 @@ class ZarrStore(AbstractWritableDataStore):

__slots__ = (
"_append_dim",
"_cache",
"_cache_array_keys",
dcherian marked this conversation as resolved.
Show resolved Hide resolved
"_close_store_on_close",
"_consolidate_on_close",
"_group",
Expand Down Expand Up @@ -633,6 +635,7 @@ def open_store(
zarr_format=None,
use_zarr_fill_value_as_mask=None,
write_empty: bool | None = None,
cache_array_keys: bool = True,
d-v-b marked this conversation as resolved.
Show resolved Hide resolved
):
(
zarr_group,
Expand Down Expand Up @@ -664,6 +667,7 @@ def open_store(
write_empty,
close_store_on_close,
use_zarr_fill_value_as_mask,
cache_array_keys=cache_array_keys,
)
for group in group_paths
}
Expand All @@ -686,6 +690,7 @@ def open_group(
zarr_format=None,
use_zarr_fill_value_as_mask=None,
write_empty: bool | None = None,
cache_array_keys: bool = True,
):
(
zarr_group,
Expand Down Expand Up @@ -716,6 +721,7 @@ def open_group(
write_empty,
close_store_on_close,
use_zarr_fill_value_as_mask,
cache_array_keys,
)

def __init__(
Expand All @@ -729,6 +735,7 @@ def __init__(
write_empty: bool | None = None,
close_store_on_close: bool = False,
use_zarr_fill_value_as_mask=None,
cache_array_keys: bool = True,
):
self.zarr_group = zarr_group
self._read_only = self.zarr_group.read_only
Expand All @@ -743,6 +750,10 @@ def __init__(
self._close_store_on_close = close_store_on_close
self._use_zarr_fill_value_as_mask = use_zarr_fill_value_as_mask

self._cache: dict[str, Any] = {}
if cache_array_keys:
self._cache["array_keys"] = None

@property
def ds(self):
# TODO: consider deprecating this in favor of zarr_group
Expand Down Expand Up @@ -802,6 +813,17 @@ def get_variables(self):
(k, self.open_store_variable(k, v)) for k, v in self.zarr_group.arrays()
)

def get_array_keys(self) -> tuple[str, ...]:
key = "array_keys"
if key not in self._cache:
result = tuple(self.zarr_group.array_keys())
elif self._cache[key] is None:
result = tuple(self.zarr_group.array_keys())
self._cache[key] = result
else:
result = self._cache[key]
return result

def get_attrs(self):
return {
k: v
Expand Down Expand Up @@ -881,7 +903,7 @@ def store(
existing_keys = {}
existing_variable_names = {}
else:
existing_keys = tuple(self.zarr_group.array_keys())
existing_keys = self.get_array_keys()
existing_variable_names = {
vn for vn in variables if _encode_variable_name(vn) in existing_keys
}
Expand Down Expand Up @@ -1059,7 +1081,7 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
dimensions.
"""

existing_keys = tuple(self.zarr_group.array_keys())
existing_keys = self.get_array_keys()
is_zarr_v3_format = _zarr_v3() and self.zarr_group.metadata.zarr_format == 3

for vn, v in variables.items():
Expand Down Expand Up @@ -1251,7 +1273,7 @@ def _validate_and_autodetect_region(self, ds: Dataset) -> Dataset:

def _validate_encoding(self, encoding) -> None:
if encoding and self._mode in ["a", "a-", "r+"]:
existing_var_names = set(self.zarr_group.array_keys())
existing_var_names = self.get_array_keys()
for var_name in existing_var_names:
if var_name in encoding:
raise ValueError(
Expand Down Expand Up @@ -1281,6 +1303,7 @@ def open_zarr(
use_zarr_fill_value_as_mask=None,
chunked_array_type: str | None = None,
from_array_kwargs: dict[str, Any] | None = None,
cache_array_keys: bool = False,
d-v-b marked this conversation as resolved.
Show resolved Hide resolved
**kwargs,
):
"""Load and decode a dataset from a Zarr store.
Expand Down Expand Up @@ -1434,6 +1457,7 @@ def open_zarr(
"storage_options": storage_options,
"zarr_version": zarr_version,
"zarr_format": zarr_format,
"cache_array_keys": cache_array_keys,
}

ds = open_dataset(
Expand Down Expand Up @@ -1505,6 +1529,7 @@ def open_dataset(
store=None,
engine=None,
use_zarr_fill_value_as_mask=None,
cache_array_keys: bool = True,
) -> Dataset:
filename_or_obj = _normalize_path(filename_or_obj)
if not store:
Expand All @@ -1520,6 +1545,7 @@ def open_dataset(
zarr_version=zarr_version,
use_zarr_fill_value_as_mask=None,
zarr_format=zarr_format,
cache_array_keys=cache_array_keys,
)

store_entrypoint = StoreBackendEntrypoint()
Expand Down
38 changes: 36 additions & 2 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
)
from xarray.backends.pydap_ import PydapDataStore
from xarray.backends.scipy_ import ScipyBackendEntrypoint
from xarray.backends.zarr import ZarrStore
from xarray.coding.cftime_offsets import cftime_range
from xarray.coding.strings import check_vlen_dtype, create_vlen_dtype
from xarray.coding.variables import SerializationWarning
Expand Down Expand Up @@ -3261,6 +3262,41 @@ def test_chunked_cftime_datetime(self) -> None:
assert original[name].chunks == actual_var.chunks
assert original.chunks == actual.chunks

@pytest.mark.parametrize("cache_array_keys", [True, False])
def test_get_array_keys(self, cache_array_keys: bool) -> None:
"""
Ensure that if `ZarrStore` is created with `cache_array_keys` set to `True`,
a `ZarrStore.get_array_keys` only invokes the `array_keys` function on the
`ZarrStore.zarr_group` instance once, and that the results of that call are cached.

Otherwise, `ZarrStore.get_array_keys` instance should invoke the `array_keys`
each time it is called.
"""
with self.create_zarr_target() as store_target:
zstore = backends.ZarrStore.open_group(
store_target, mode="w", cache_array_keys=cache_array_keys
)

# ensure that the keys are sorted
array_keys = sorted(("foo", "bar"))

# create some arrays
for ak in array_keys:
zstore.zarr_group.create(name=ak, shape=(1,), dtype="uint8")

observed_keys_0 = sorted(zstore.get_array_keys())
assert observed_keys_0 == array_keys

# create a new array
new_key = "baz"
zstore.zarr_group.create(name=new_key, shape=(1,), dtype="uint8")
observed_keys_1 = sorted(zstore.get_array_keys())

if cache_array_keys:
assert observed_keys_1 == array_keys
else:
assert observed_keys_1 == sorted(array_keys + [new_key])


@requires_zarr
@pytest.mark.skipif(
Expand Down Expand Up @@ -6158,8 +6194,6 @@ def test_zarr_region_auto_new_coord_vals(self, tmp_path):
ds_region.to_zarr(tmp_path / "test.zarr", region={"x": "auto", "y": "auto"})

def test_zarr_region_index_write(self, tmp_path):
from xarray.backends.zarr import ZarrStore

x = np.arange(0, 50, 10)
y = np.arange(0, 20, 2)
data = np.ones((5, 10))
Expand Down
Loading