Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix zarr upstream tests #9927

Merged
merged 6 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,10 +447,11 @@ def extract_zarr_variable_encoding(

safe_to_drop = {"source", "original_shape", "preferred_chunks"}
valid_encodings = {
"codecs",
"chunks",
"compressor",
"compressor", # TODO: delete when min zarr >=3
"compressors",
"filters",
"serializer",
"cache_metadata",
"write_empty_chunks",
}
Expand Down Expand Up @@ -480,6 +481,8 @@ def extract_zarr_variable_encoding(
mode=mode,
shape=shape,
)
if _zarr_v3() and chunks is None:
chunks = "auto"
encoding["chunks"] = chunks
return encoding

Expand Down Expand Up @@ -816,24 +819,20 @@ def open_store_variable(self, name):
)
attributes = dict(attributes)

# TODO: this should not be needed once
# https://github.com/zarr-developers/zarr-python/issues/1269 is resolved.
attributes.pop("filters", None)

encoding = {
"chunks": zarr_array.chunks,
"preferred_chunks": dict(zip(dimensions, zarr_array.chunks, strict=True)),
}

if _zarr_v3() and zarr_array.metadata.zarr_format == 3:
encoding["codecs"] = [x.to_dict() for x in zarr_array.metadata.codecs]
elif _zarr_v3():
if _zarr_v3():
encoding.update(
{
"compressor": zarr_array.metadata.compressor,
"filters": zarr_array.metadata.filters,
"compressors": zarr_array.compressors,
"filters": zarr_array.filters,
}
)
if self.zarr_group.metadata.zarr_format == 3:
encoding.update({"serializer": zarr_array.serializer})
else:
encoding.update(
{
Expand Down
67 changes: 30 additions & 37 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,21 +134,21 @@


@pytest.fixture(scope="module", params=ZARR_FORMATS)
def default_zarr_version(request) -> Generator[None, None]:
def default_zarr_format(request) -> Generator[None, None]:
if has_zarr_v3:
with zarr.config.set(default_zarr_version=request.param):
with zarr.config.set(default_zarr_format=request.param):
yield
else:
yield


def skip_if_zarr_format_3(reason: str):
if has_zarr_v3 and zarr.config["default_zarr_version"] == 3:
if has_zarr_v3 and zarr.config["default_zarr_format"] == 3:
pytest.skip(reason=f"Unsupported with zarr_format=3: {reason}")


def skip_if_zarr_format_2(reason: str):
if not has_zarr_v3 or (zarr.config["default_zarr_version"] == 2):
if not has_zarr_v3 or (zarr.config["default_zarr_format"] == 2):
pytest.skip(reason=f"Unsupported with zarr_format=2: {reason}")


Expand Down Expand Up @@ -2270,7 +2270,7 @@ def test_roundtrip_coordinates(self) -> None:


@requires_zarr
@pytest.mark.usefixtures("default_zarr_version")
@pytest.mark.usefixtures("default_zarr_format")
class ZarrBase(CFEncodedBase):
DIMENSION_KEY = "_ARRAY_DIMENSIONS"
zarr_version = 2
Expand Down Expand Up @@ -2439,7 +2439,7 @@ def test_warning_on_bad_chunks(self) -> None:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message=".*Zarr version 3 specification.*",
message=".*Zarr format 3 specification.*",
category=UserWarning,
)
with self.roundtrip(original, open_kwargs=kwargs) as actual:
Expand Down Expand Up @@ -2675,40 +2675,35 @@ def test_write_persistence_modes(self, group) -> None:
assert_identical(original, actual)

def test_compressor_encoding(self) -> None:
original = create_test_data()
# specify a custom compressor

if has_zarr_v3 and zarr.config.config["default_zarr_version"] == 3:
encoding_key = "codecs"
original = create_test_data()
if has_zarr_v3 and zarr.config.config["default_zarr_format"] == 3:
encoding_key = "compressors"
# all parameters need to be explicitly specified in order for the comparison to pass below
encoding = {
"serializer": zarr.codecs.BytesCodec(endian="little"),
encoding_key: (
zarr.codecs.BytesCodec(endian="little"),
zarr.codecs.BloscCodec(
cname="zstd",
clevel=3,
shuffle="shuffle",
typesize=8,
blocksize=0,
),
)
),
}
else:
from numcodecs.blosc import Blosc

encoding_key = "compressor"
encoding = {encoding_key: Blosc(cname="zstd", clevel=3, shuffle=2)}
encoding_key = "compressors" if has_zarr_v3 else "compressor"
comp = Blosc(cname="zstd", clevel=3, shuffle=2)
encoding = {encoding_key: (comp,) if has_zarr_v3 else comp}
Comment on lines +2698 to +2700
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This all feels a bit silly. "compressors" instead of "compressor" and now a tuple instead of a single object.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we had to pluralize it to make things work with zarr v3. fwiw I proposed "compression" as an arity-ambiguous version, but that wasn't popular enough


save_kwargs = dict(encoding={"var1": encoding})

with self.roundtrip(original, save_kwargs=save_kwargs) as ds:
enc = ds["var1"].encoding[encoding_key]
if has_zarr_v3 and zarr.config.config["default_zarr_version"] == 3:
# TODO: figure out a cleaner way to do this comparison
codecs = zarr.core.metadata.v3.parse_codecs(enc)
assert codecs == encoding[encoding_key]
else:
assert enc == encoding[encoding_key]
assert enc == encoding[encoding_key]

def test_group(self) -> None:
original = create_test_data()
Expand Down Expand Up @@ -2846,14 +2841,12 @@ def test_check_encoding_is_consistent_after_append(self) -> None:
import numcodecs

encoding_value: Any
if has_zarr_v3 and zarr.config.config["default_zarr_version"] == 3:
if has_zarr_v3 and zarr.config.config["default_zarr_format"] == 3:
compressor = zarr.codecs.BloscCodec()
encoding_key = "codecs"
encoding_value = [zarr.codecs.BytesCodec(), compressor]
else:
compressor = numcodecs.Blosc()
encoding_key = "compressor"
encoding_value = compressor
encoding_key = "compressors" if has_zarr_v3 else "compressor"
encoding_value = (compressor,) if has_zarr_v3 else compressor

encoding = {"da": {encoding_key: encoding_value}}
ds.to_zarr(store_target, mode="w", encoding=encoding, **self.version_kwargs)
Expand Down Expand Up @@ -2995,7 +2988,7 @@ def test_no_warning_from_open_emptydim_with_chunks(self) -> None:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message=".*Zarr version 3 specification.*",
message=".*Zarr format 3 specification.*",
category=UserWarning,
)
with self.roundtrip(ds, open_kwargs=dict(chunks={"a": 1})) as ds_reload:
Expand Down Expand Up @@ -5479,7 +5472,7 @@ def test_dataarray_to_netcdf_no_name_pathlib(self) -> None:
@requires_zarr
class TestDataArrayToZarr:
def skip_if_zarr_python_3_and_zip_store(self, store) -> None:
if has_zarr_v3 and isinstance(store, zarr.storage.zip.ZipStore):
if has_zarr_v3 and isinstance(store, zarr.storage.ZipStore):
pytest.skip(
reason="zarr-python 3.x doesn't support reopening ZipStore with a new mode."
)
Expand Down Expand Up @@ -5786,7 +5779,7 @@ def test_extract_zarr_variable_encoding() -> None:
var = xr.Variable("x", [1, 2])
actual = backends.zarr.extract_zarr_variable_encoding(var)
assert "chunks" in actual
assert actual["chunks"] is None
assert actual["chunks"] == ("auto" if has_zarr_v3 else None)

var = xr.Variable("x", [1, 2], encoding={"chunks": (1,)})
actual = backends.zarr.extract_zarr_variable_encoding(var)
Expand Down Expand Up @@ -6092,14 +6085,14 @@ def test_raise_writing_to_nczarr(self, mode) -> None:

@requires_netCDF4
@requires_dask
@pytest.mark.usefixtures("default_zarr_version")
@pytest.mark.usefixtures("default_zarr_format")
def test_pickle_open_mfdataset_dataset():
with open_example_mfdataset(["bears.nc"]) as ds:
assert_identical(ds, pickle.loads(pickle.dumps(ds)))


@requires_zarr
@pytest.mark.usefixtures("default_zarr_version")
@pytest.mark.usefixtures("default_zarr_format")
def test_zarr_closing_internal_zip_store():
store_name = "tmp.zarr.zip"
original_da = DataArray(np.arange(12).reshape((3, 4)))
Expand All @@ -6110,7 +6103,7 @@ def test_zarr_closing_internal_zip_store():


@requires_zarr
@pytest.mark.usefixtures("default_zarr_version")
@pytest.mark.usefixtures("default_zarr_format")
class TestZarrRegionAuto:
def test_zarr_region_auto_all(self, tmp_path):
x = np.arange(0, 50, 10)
Expand Down Expand Up @@ -6286,7 +6279,7 @@ def test_zarr_region_append(self, tmp_path):


@requires_zarr
@pytest.mark.usefixtures("default_zarr_version")
@pytest.mark.usefixtures("default_zarr_format")
def test_zarr_region(tmp_path):
x = np.arange(0, 50, 10)
y = np.arange(0, 20, 2)
Expand Down Expand Up @@ -6315,7 +6308,7 @@ def test_zarr_region(tmp_path):

@requires_zarr
@requires_dask
@pytest.mark.usefixtures("default_zarr_version")
@pytest.mark.usefixtures("default_zarr_format")
def test_zarr_region_chunk_partial(tmp_path):
"""
Check that writing to partial chunks with `region` fails, assuming `safe_chunks=False`.
Expand All @@ -6336,7 +6329,7 @@ def test_zarr_region_chunk_partial(tmp_path):

@requires_zarr
@requires_dask
@pytest.mark.usefixtures("default_zarr_version")
@pytest.mark.usefixtures("default_zarr_format")
def test_zarr_append_chunk_partial(tmp_path):
t_coords = np.array([np.datetime64("2020-01-01").astype("datetime64[ns]")])
data = np.ones((10, 10))
Expand Down Expand Up @@ -6374,7 +6367,7 @@ def test_zarr_append_chunk_partial(tmp_path):

@requires_zarr
@requires_dask
@pytest.mark.usefixtures("default_zarr_version")
@pytest.mark.usefixtures("default_zarr_format")
def test_zarr_region_chunk_partial_offset(tmp_path):
# https://github.com/pydata/xarray/pull/8459#issuecomment-1819417545
store = tmp_path / "foo.zarr"
Expand All @@ -6394,7 +6387,7 @@ def test_zarr_region_chunk_partial_offset(tmp_path):

@requires_zarr
@requires_dask
@pytest.mark.usefixtures("default_zarr_version")
@pytest.mark.usefixtures("default_zarr_format")
def test_zarr_safe_chunk_append_dim(tmp_path):
store = tmp_path / "foo.zarr"
data = np.ones((20,))
Expand Down Expand Up @@ -6445,7 +6438,7 @@ def test_zarr_safe_chunk_append_dim(tmp_path):

@requires_zarr
@requires_dask
@pytest.mark.usefixtures("default_zarr_version")
@pytest.mark.usefixtures("default_zarr_format")
def test_zarr_safe_chunk_region(tmp_path):
store = tmp_path / "foo.zarr"

Expand Down
Loading