Skip to content

Commit

Permalink
Fix zarr upstream tests (#9927)
Browse files: browse the repository at this point in the history
Co-authored-by: Matthew Iannucci <[email protected]>
  • Loading branch information
dcherian and mpiannucci authored Jan 7, 2025
1 parent 251329e commit 1e5045a
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 48 deletions.
21 changes: 10 additions & 11 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,10 +447,11 @@ def extract_zarr_variable_encoding(

safe_to_drop = {"source", "original_shape", "preferred_chunks"}
valid_encodings = {
"codecs",
"chunks",
"compressor",
"compressor", # TODO: delete when min zarr >=3
"compressors",
"filters",
"serializer",
"cache_metadata",
"write_empty_chunks",
}
Expand Down Expand Up @@ -480,6 +481,8 @@ def extract_zarr_variable_encoding(
mode=mode,
shape=shape,
)
if _zarr_v3() and chunks is None:
chunks = "auto"
encoding["chunks"] = chunks
return encoding

Expand Down Expand Up @@ -816,24 +819,20 @@ def open_store_variable(self, name):
)
attributes = dict(attributes)

# TODO: this should not be needed once
# https://github.com/zarr-developers/zarr-python/issues/1269 is resolved.
attributes.pop("filters", None)

encoding = {
"chunks": zarr_array.chunks,
"preferred_chunks": dict(zip(dimensions, zarr_array.chunks, strict=True)),
}

if _zarr_v3() and zarr_array.metadata.zarr_format == 3:
encoding["codecs"] = [x.to_dict() for x in zarr_array.metadata.codecs]
elif _zarr_v3():
if _zarr_v3():
encoding.update(
{
"compressor": zarr_array.metadata.compressor,
"filters": zarr_array.metadata.filters,
"compressors": zarr_array.compressors,
"filters": zarr_array.filters,
}
)
if self.zarr_group.metadata.zarr_format == 3:
encoding.update({"serializer": zarr_array.serializer})
else:
encoding.update(
{
Expand Down
67 changes: 30 additions & 37 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,21 +134,21 @@


@pytest.fixture(scope="module", params=ZARR_FORMATS)
def default_zarr_version(request) -> Generator[None, None]:
def default_zarr_format(request) -> Generator[None, None]:
if has_zarr_v3:
with zarr.config.set(default_zarr_version=request.param):
with zarr.config.set(default_zarr_format=request.param):
yield
else:
yield


def skip_if_zarr_format_3(reason: str):
if has_zarr_v3 and zarr.config["default_zarr_version"] == 3:
if has_zarr_v3 and zarr.config["default_zarr_format"] == 3:
pytest.skip(reason=f"Unsupported with zarr_format=3: {reason}")


def skip_if_zarr_format_2(reason: str):
if not has_zarr_v3 or (zarr.config["default_zarr_version"] == 2):
if not has_zarr_v3 or (zarr.config["default_zarr_format"] == 2):
pytest.skip(reason=f"Unsupported with zarr_format=2: {reason}")


Expand Down Expand Up @@ -2270,7 +2270,7 @@ def test_roundtrip_coordinates(self) -> None:


@requires_zarr
@pytest.mark.usefixtures("default_zarr_version")
@pytest.mark.usefixtures("default_zarr_format")
class ZarrBase(CFEncodedBase):
DIMENSION_KEY = "_ARRAY_DIMENSIONS"
zarr_version = 2
Expand Down Expand Up @@ -2439,7 +2439,7 @@ def test_warning_on_bad_chunks(self) -> None:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message=".*Zarr version 3 specification.*",
message=".*Zarr format 3 specification.*",
category=UserWarning,
)
with self.roundtrip(original, open_kwargs=kwargs) as actual:
Expand Down Expand Up @@ -2675,40 +2675,35 @@ def test_write_persistence_modes(self, group) -> None:
assert_identical(original, actual)

def test_compressor_encoding(self) -> None:
original = create_test_data()
# specify a custom compressor

if has_zarr_v3 and zarr.config.config["default_zarr_version"] == 3:
encoding_key = "codecs"
original = create_test_data()
if has_zarr_v3 and zarr.config.config["default_zarr_format"] == 3:
encoding_key = "compressors"
# all parameters need to be explicitly specified in order for the comparison to pass below
encoding = {
"serializer": zarr.codecs.BytesCodec(endian="little"),
encoding_key: (
zarr.codecs.BytesCodec(endian="little"),
zarr.codecs.BloscCodec(
cname="zstd",
clevel=3,
shuffle="shuffle",
typesize=8,
blocksize=0,
),
)
),
}
else:
from numcodecs.blosc import Blosc

encoding_key = "compressor"
encoding = {encoding_key: Blosc(cname="zstd", clevel=3, shuffle=2)}
encoding_key = "compressors" if has_zarr_v3 else "compressor"
comp = Blosc(cname="zstd", clevel=3, shuffle=2)
encoding = {encoding_key: (comp,) if has_zarr_v3 else comp}

save_kwargs = dict(encoding={"var1": encoding})

with self.roundtrip(original, save_kwargs=save_kwargs) as ds:
enc = ds["var1"].encoding[encoding_key]
if has_zarr_v3 and zarr.config.config["default_zarr_version"] == 3:
# TODO: figure out a cleaner way to do this comparison
codecs = zarr.core.metadata.v3.parse_codecs(enc)
assert codecs == encoding[encoding_key]
else:
assert enc == encoding[encoding_key]
assert enc == encoding[encoding_key]

def test_group(self) -> None:
original = create_test_data()
Expand Down Expand Up @@ -2846,14 +2841,12 @@ def test_check_encoding_is_consistent_after_append(self) -> None:
import numcodecs

encoding_value: Any
if has_zarr_v3 and zarr.config.config["default_zarr_version"] == 3:
if has_zarr_v3 and zarr.config.config["default_zarr_format"] == 3:
compressor = zarr.codecs.BloscCodec()
encoding_key = "codecs"
encoding_value = [zarr.codecs.BytesCodec(), compressor]
else:
compressor = numcodecs.Blosc()
encoding_key = "compressor"
encoding_value = compressor
encoding_key = "compressors" if has_zarr_v3 else "compressor"
encoding_value = (compressor,) if has_zarr_v3 else compressor

encoding = {"da": {encoding_key: encoding_value}}
ds.to_zarr(store_target, mode="w", encoding=encoding, **self.version_kwargs)
Expand Down Expand Up @@ -2995,7 +2988,7 @@ def test_no_warning_from_open_emptydim_with_chunks(self) -> None:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message=".*Zarr version 3 specification.*",
message=".*Zarr format 3 specification.*",
category=UserWarning,
)
with self.roundtrip(ds, open_kwargs=dict(chunks={"a": 1})) as ds_reload:
Expand Down Expand Up @@ -5479,7 +5472,7 @@ def test_dataarray_to_netcdf_no_name_pathlib(self) -> None:
@requires_zarr
class TestDataArrayToZarr:
def skip_if_zarr_python_3_and_zip_store(self, store) -> None:
if has_zarr_v3 and isinstance(store, zarr.storage.zip.ZipStore):
if has_zarr_v3 and isinstance(store, zarr.storage.ZipStore):
pytest.skip(
reason="zarr-python 3.x doesn't support reopening ZipStore with a new mode."
)
Expand Down Expand Up @@ -5786,7 +5779,7 @@ def test_extract_zarr_variable_encoding() -> None:
var = xr.Variable("x", [1, 2])
actual = backends.zarr.extract_zarr_variable_encoding(var)
assert "chunks" in actual
assert actual["chunks"] is None
assert actual["chunks"] == ("auto" if has_zarr_v3 else None)

var = xr.Variable("x", [1, 2], encoding={"chunks": (1,)})
actual = backends.zarr.extract_zarr_variable_encoding(var)
Expand Down Expand Up @@ -6092,14 +6085,14 @@ def test_raise_writing_to_nczarr(self, mode) -> None:

@requires_netCDF4
@requires_dask
@pytest.mark.usefixtures("default_zarr_version")
@pytest.mark.usefixtures("default_zarr_format")
def test_pickle_open_mfdataset_dataset():
with open_example_mfdataset(["bears.nc"]) as ds:
assert_identical(ds, pickle.loads(pickle.dumps(ds)))


@requires_zarr
@pytest.mark.usefixtures("default_zarr_version")
@pytest.mark.usefixtures("default_zarr_format")
def test_zarr_closing_internal_zip_store():
store_name = "tmp.zarr.zip"
original_da = DataArray(np.arange(12).reshape((3, 4)))
Expand All @@ -6110,7 +6103,7 @@ def test_zarr_closing_internal_zip_store():


@requires_zarr
@pytest.mark.usefixtures("default_zarr_version")
@pytest.mark.usefixtures("default_zarr_format")
class TestZarrRegionAuto:
def test_zarr_region_auto_all(self, tmp_path):
x = np.arange(0, 50, 10)
Expand Down Expand Up @@ -6286,7 +6279,7 @@ def test_zarr_region_append(self, tmp_path):


@requires_zarr
@pytest.mark.usefixtures("default_zarr_version")
@pytest.mark.usefixtures("default_zarr_format")
def test_zarr_region(tmp_path):
x = np.arange(0, 50, 10)
y = np.arange(0, 20, 2)
Expand Down Expand Up @@ -6315,7 +6308,7 @@ def test_zarr_region(tmp_path):

@requires_zarr
@requires_dask
@pytest.mark.usefixtures("default_zarr_version")
@pytest.mark.usefixtures("default_zarr_format")
def test_zarr_region_chunk_partial(tmp_path):
"""
Check that writing to partial chunks with `region` fails, assuming `safe_chunks=False`.
Expand All @@ -6336,7 +6329,7 @@ def test_zarr_region_chunk_partial(tmp_path):

@requires_zarr
@requires_dask
@pytest.mark.usefixtures("default_zarr_version")
@pytest.mark.usefixtures("default_zarr_format")
def test_zarr_append_chunk_partial(tmp_path):
t_coords = np.array([np.datetime64("2020-01-01").astype("datetime64[ns]")])
data = np.ones((10, 10))
Expand Down Expand Up @@ -6374,7 +6367,7 @@ def test_zarr_append_chunk_partial(tmp_path):

@requires_zarr
@requires_dask
@pytest.mark.usefixtures("default_zarr_version")
@pytest.mark.usefixtures("default_zarr_format")
def test_zarr_region_chunk_partial_offset(tmp_path):
# https://github.com/pydata/xarray/pull/8459#issuecomment-1819417545
store = tmp_path / "foo.zarr"
Expand All @@ -6394,7 +6387,7 @@ def test_zarr_region_chunk_partial_offset(tmp_path):

@requires_zarr
@requires_dask
@pytest.mark.usefixtures("default_zarr_version")
@pytest.mark.usefixtures("default_zarr_format")
def test_zarr_safe_chunk_append_dim(tmp_path):
store = tmp_path / "foo.zarr"
data = np.ones((20,))
Expand Down Expand Up @@ -6445,7 +6438,7 @@ def test_zarr_safe_chunk_append_dim(tmp_path):

@requires_zarr
@requires_dask
@pytest.mark.usefixtures("default_zarr_version")
@pytest.mark.usefixtures("default_zarr_format")
def test_zarr_safe_chunk_region(tmp_path):
store = tmp_path / "foo.zarr"

Expand Down

0 comments on commit 1e5045a

Please sign in to comment.