Skip to content

Commit

Permalink
feature(store,group,array): stores learn to delete prefixes when
Browse files Browse the repository at this point in the history
overwriting nodes

- add Store.delete_dir and Store.delete_prefix
- update array and group creation methods to call delete_dir
- change list_prefix to return absolue keys
  • Loading branch information
jhamman committed Oct 22, 2024
1 parent 8726734 commit 858d4fb
Show file tree
Hide file tree
Showing 14 changed files with 115 additions and 30 deletions.
30 changes: 30 additions & 0 deletions src/zarr/abc/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,36 @@ def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
"""
...

async def delete_dir(self, prefix: str, recursive: bool = True) -> None:
"""
Remove all keys and prefixes in the store that begin with a given prefix.
"""
if not self.supports_deletes:
raise NotImplementedError
if not self.supports_listing:
raise NotImplementedError
self._check_writable()
if recursive:
if not prefix.endswith("/"):
prefix += "/"
async for key in self.list_prefix(prefix):
await self.delete(f"{key}")
else:
async for key in self.list_dir(prefix):
await self.delete(f"{prefix}/{key}")

async def delete_prefix(self, prefix: str) -> None:
"""
Remove all keys in the store that begin with a given prefix.
"""
if not self.supports_deletes:
raise NotImplementedError
if not self.supports_listing:
raise NotImplementedError
self._check_writable()
async for key in self.list_prefix(prefix):
await self.delete(f"{key}")

def close(self) -> None:
"""Close the store."""
self._is_open = False
Expand Down
14 changes: 12 additions & 2 deletions src/zarr/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,12 @@ async def _create_v3(
attributes: dict[str, JSON] | None = None,
exists_ok: bool = False,
) -> AsyncArray[ArrayV3Metadata]:
if not exists_ok:
if exists_ok:
if store_path.store.supports_deletes:
await store_path.delete_dir(recursive=True)
else:
await ensure_no_existing_node(store_path, zarr_format=3)
else:
await ensure_no_existing_node(store_path, zarr_format=3)

shape = parse_shapelike(shape)
Expand Down Expand Up @@ -606,7 +611,12 @@ async def _create_v2(
attributes: dict[str, JSON] | None = None,
exists_ok: bool = False,
) -> AsyncArray[ArrayV2Metadata]:
if not exists_ok:
if exists_ok:
if store_path.store.supports_deletes:
await store_path.delete_dir(recursive=True)
else:
await ensure_no_existing_node(store_path, zarr_format=2)
else:
await ensure_no_existing_node(store_path, zarr_format=2)

if order is None:
Expand Down
21 changes: 8 additions & 13 deletions src/zarr/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,13 @@ async def from_store(
zarr_format: ZarrFormat = 3,
) -> AsyncGroup:
store_path = await make_store_path(store)
if not exists_ok:

if exists_ok:
if store_path.store.supports_deletes:
await store_path.delete_dir(recursive=True)
else:
await ensure_no_existing_node(store_path, zarr_format=zarr_format)
else:
await ensure_no_existing_node(store_path, zarr_format=zarr_format)
attributes = attributes or {}
group = cls(
Expand Down Expand Up @@ -710,19 +716,8 @@ def _getitem_consolidated(

async def delitem(self, key: str) -> None:
store_path = self.store_path / key
if self.metadata.zarr_format == 3:
await (store_path / ZARR_JSON).delete()

elif self.metadata.zarr_format == 2:
await asyncio.gather(
(store_path / ZGROUP_JSON).delete(), # TODO: missing_ok=False
(store_path / ZARRAY_JSON).delete(), # TODO: missing_ok=False
(store_path / ZATTRS_JSON).delete(), # TODO: missing_ok=True
)

else:
raise ValueError(f"unexpected zarr_format: {self.metadata.zarr_format}")

await store_path.delete_dir(recursive=True)
if self.metadata.consolidated_metadata:
self.metadata.consolidated_metadata.metadata.pop(key, None)
await self._save_metadata()
Expand Down
12 changes: 12 additions & 0 deletions src/zarr/storage/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,18 @@ async def delete(self) -> None:
"""
await self.store.delete(self.path)

async def delete_dir(self, recursive: bool = False) -> None:
"""
Delete all keys with the given prefix from the store.
"""
await self.store.delete_dir(self.path, recursive=recursive)

async def delete_prefix(self) -> None:
"""
Delete all keys with the given prefix from the store.
"""
await self.store.delete_prefix(self.path)

async def set_if_not_exists(self, default: Buffer) -> None:
"""
Store a key to ``value`` if the key is not already present.
Expand Down
4 changes: 3 additions & 1 deletion src/zarr/storage/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,9 @@ async def list(self) -> AsyncGenerator[str, None]:

async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
# docstring inherited
to_strip = os.path.join(str(self.root / prefix))
to_strip = str(self.root)
if prefix.endswith("/"):
prefix = prefix[:-1]
for p in (self.root / prefix).rglob("*"):
if p.is_file():
yield str(p.relative_to(to_strip))
Expand Down
3 changes: 3 additions & 0 deletions src/zarr/storage/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,3 +230,6 @@ def with_mode(self, mode: AccessModeLiteral) -> Self:
log_level=self.log_level,
log_handler=self.log_handler,
)


# TODO: wrap delete methods here
11 changes: 8 additions & 3 deletions src/zarr/storage/memory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from logging import getLogger
from typing import TYPE_CHECKING, Self

from zarr.abc.store import ByteRangeRequest, Store
Expand All @@ -14,6 +15,9 @@
from zarr.core.common import AccessModeLiteral


logger = getLogger(__name__)


class MemoryStore(Store):
"""
In-memory store for testing purposes.
Expand Down Expand Up @@ -137,7 +141,7 @@ async def delete(self, key: str) -> None:
try:
del self._store_dict[key]
except KeyError:
pass
logger.debug("Key %s does not exist.", key)

async def set_partial_values(self, key_start_values: Iterable[tuple[str, int, bytes]]) -> None:
# docstring inherited
Expand All @@ -150,9 +154,10 @@ async def list(self) -> AsyncGenerator[str, None]:

async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
# docstring inherited
for key in self._store_dict:
# note: we materialize all dict keys into a list here so we can mutate the dict in-place (e.g. in delete_prefix)
for key in list(self._store_dict):
if key.startswith(prefix):
yield key.removeprefix(prefix)
yield key

async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
# docstring inherited
Expand Down
7 changes: 4 additions & 3 deletions src/zarr/storage/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:

async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
# docstring inherited
find_str = f"{self.path}/{prefix}"
for onefile in await self.fs._find(find_str, detail=False, maxdepth=None, withdirs=False):
yield onefile.removeprefix(find_str)
for onefile in await self.fs._find(
f"{self.path}/{prefix}", detail=False, maxdepth=None, withdirs=False
):
yield onefile.removeprefix(f"{self.path}/")
2 changes: 1 addition & 1 deletion src/zarr/storage/zip.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
# docstring inherited
async for key in self.list():
if key.startswith(prefix):
yield key.removeprefix(prefix)
yield key

async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
# docstring inherited
Expand Down
5 changes: 2 additions & 3 deletions src/zarr/testing/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,7 @@ async def test_list(self, store: S) -> None:
async def test_list_prefix(self, store: S) -> None:
"""
Test that the `list_prefix` method works as intended. Given a prefix, it should return
all the keys in storage that start with this prefix. Keys should be returned with the shared
prefix removed.
all the keys in storage that start with this prefix.
"""
prefixes = ("", "a/", "a/b/", "a/b/c/")
data = self.buffer_cls.from_bytes(b"")
Expand All @@ -264,7 +263,7 @@ async def test_list_prefix(self, store: S) -> None:
expected: tuple[str, ...] = ()
for key in store_dict:
if key.startswith(prefix):
expected += (key.removeprefix(prefix),)
expected += (key,)
expected = tuple(sorted(expected))
assert observed == expected

Expand Down
3 changes: 3 additions & 0 deletions tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ def test_array_creation_existing_node(
new_dtype = "float32"

if exists_ok:
if not store.supports_deletes:
# TODO: confirm you get the expected error here
pytest.skip("store does not support deletes")
arr_new = Array.create(
spath / "extant",
shape=new_shape,
Expand Down
23 changes: 22 additions & 1 deletion tests/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,10 @@ def test_group_open(store: Store, zarr_format: ZarrFormat, exists_ok: bool) -> N
with pytest.raises(ContainsGroupError):
Group.from_store(store, attributes=attrs, zarr_format=zarr_format, exists_ok=exists_ok)
else:
if not store.supports_deletes:
pytest.skip(
"Store does not support deletes but `exists_ok` is True, requiring deletes to override a group"
)
group_created_again = Group.from_store(
store, attributes=new_attrs, zarr_format=zarr_format, exists_ok=exists_ok
)
Expand Down Expand Up @@ -703,6 +707,8 @@ def test_group_creation_existing_node(
new_attributes = {"new": True}

if exists_ok:
if not store.supports_deletes:
pytest.skip("store does not support deletes but exists_ok is True")
node_new = Group.from_store(
spath / "extant",
attributes=new_attributes,
Expand Down Expand Up @@ -1075,7 +1081,9 @@ async def test_require_group(store: LocalStore | MemoryStore, zarr_format: ZarrF
assert foo_group.attrs == {"foo": 100}

# test that we can get the group using require_group and overwrite=True
foo_group = await root.require_group("foo", overwrite=True)
if store.supports_deletes:
foo_group = await root.require_group("foo", overwrite=True)
assert foo_group.attrs == {}

_ = await foo_group.create_array(
"bar", shape=(10,), dtype="uint8", chunk_shape=(2,), attributes={"foo": 100}
Expand Down Expand Up @@ -1354,3 +1362,16 @@ def test_group_deprecated_positional_args(method: str) -> None:
with pytest.warns(FutureWarning, match=r"Pass name=.*, data=.* as keyword args."):
arr = getattr(root, method)("foo_like", data, **kwargs)
assert arr.shape == data.shape


@pytest.mark.parametrize("store", ["local", "memory"], indirect=["store"])
def test_delitem_removes_children(store: Store, zarr_format: ZarrFormat) -> None:
# https://github.com/zarr-developers/zarr-python/issues/2191
g1 = zarr.group(store=store)
g1.create_group("0")
g1.create_group("0/0")
arr = g1.create_array("0/0/0", shape=(1,))
arr[:] = 1
del g1["0"]
with pytest.raises(KeyError):
g1["0/0"]
8 changes: 6 additions & 2 deletions tests/test_store/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,14 @@ async def test_logging_store_counter(store: Store) -> None:
arr[:] = 1

assert wrapped.counter["set"] == 2
assert wrapped.counter["get"] == 0 # 1 if overwrite=False
assert wrapped.counter["list"] == 0
assert wrapped.counter["list_dir"] == 0
assert wrapped.counter["list_prefix"] == 0
if store.supports_deletes:
assert wrapped.counter["get"] == 0 # 1 if overwrite=False
assert wrapped.counter["list_prefix"] == 1
else:
assert wrapped.counter["get"] == 1
assert wrapped.counter["list_prefix"] == 0


async def test_with_mode():
Expand Down
2 changes: 1 addition & 1 deletion tests/test_store/test_zip.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,4 +103,4 @@ async def test_with_mode(self, store: ZipStore) -> None:

@pytest.mark.parametrize("mode", ["a", "w"])
async def test_store_open_mode(self, store_kwargs: dict[str, Any], mode: str) -> None:
super().test_store_open_mode(store_kwargs, mode)
await super().test_store_open_mode(store_kwargs, mode)

0 comments on commit 858d4fb

Please sign in to comment.