Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Byte attr support #9407

Merged
merged 1 commit into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions xarray/backends/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def check_name(name: Hashable):
check_name(k)


def _validate_attrs(dataset, invalid_netcdf=False):
def _validate_attrs(dataset, engine, invalid_netcdf=False):
"""`attrs` must have a string key and a value which is either: a number,
a string, an ndarray, a list/tuple of numbers/strings, or a numpy.bool_.

Expand All @@ -177,8 +177,8 @@ def _validate_attrs(dataset, invalid_netcdf=False):
`invalid_netcdf=True`.
"""

valid_types = (str, Number, np.ndarray, np.number, list, tuple)
if invalid_netcdf:
valid_types = (str, Number, np.ndarray, np.number, list, tuple, bytes)
if invalid_netcdf and engine == "h5netcdf":
valid_types += (np.bool_,)

def check_attr(name, value, valid_types):
Expand All @@ -202,6 +202,23 @@ def check_attr(name, value, valid_types):
f"{', '.join([vtype.__name__ for vtype in valid_types])}"
)

if isinstance(value, bytes) and engine == "h5netcdf":
try:
value.decode("utf-8")
except UnicodeDecodeError as e:
raise ValueError(
f"Invalid value provided for attribute '{name!r}': {value!r}. "
"Only binary data derived from UTF-8 encoded strings is allowed "
f"for the '{engine}' engine. Consider using the 'netcdf4' engine."
) from e

if b"\x00" in value:
raise ValueError(
f"Invalid value provided for attribute '{name!r}': {value!r}. "
f"Null characters are not permitted for the '{engine}' engine. "
"Consider using the 'netcdf4' engine."
)

# Check attrs on the dataset itself
for k, v in dataset.attrs.items():
check_attr(k, v, valid_types)
Expand Down Expand Up @@ -1353,7 +1370,7 @@ def to_netcdf(

# validate Dataset keys, DataArray names, and attr keys/values
_validate_dataset_names(dataset)
_validate_attrs(dataset, invalid_netcdf=invalid_netcdf and engine == "h5netcdf")
_validate_attrs(dataset, engine, invalid_netcdf)

try:
store_open = WRITEABLE_STORES[engine]
Expand Down
20 changes: 20 additions & 0 deletions xarray/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,26 @@ def d(request, backend, type) -> DataArray | Dataset:
raise ValueError


@pytest.fixture
def byte_attrs_dataset():
"""For testing issue #9407"""
null_byte = b"\x00"
other_bytes = bytes(range(1, 256))
ds = Dataset({"x": 1}, coords={"x_coord": [1]})
ds["x"].attrs["null_byte"] = null_byte
ds["x"].attrs["other_bytes"] = other_bytes

expected = ds.copy()
expected["x"].attrs["null_byte"] = ""
expected["x"].attrs["other_bytes"] = other_bytes.decode(errors="replace")

return {
"input": ds,
"expected": expected,
"h5netcdf_error": r"Invalid value provided for attribute .*: .*\. Null characters .*",
}
Comment on lines +142 to +159
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice!



@pytest.fixture(scope="module")
def create_test_datatree():
"""
Expand Down
11 changes: 11 additions & 0 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -1404,6 +1404,13 @@ def test_refresh_from_disk(self) -> None:
a.close()
b.close()

def test_byte_attrs(self, byte_attrs_dataset: dict[str, Any]) -> None:
# test for issue #9407
input = byte_attrs_dataset["input"]
expected = byte_attrs_dataset["expected"]
with self.roundtrip(input) as actual:
assert_identical(actual, expected)


_counter = itertools.count()

Expand Down Expand Up @@ -3861,6 +3868,10 @@ def test_decode_utf8_warning(self) -> None:
assert ds.title == title
assert "attribute 'title' of h5netcdf object '/'" in str(w[0].message)

def test_byte_attrs(self, byte_attrs_dataset: dict[str, Any]) -> None:
with pytest.raises(ValueError, match=byte_attrs_dataset["h5netcdf_error"]):
super().test_byte_attrs(byte_attrs_dataset)


@requires_h5netcdf
@requires_netCDF4
Expand Down
Loading