move ensure_dtype_not_object from conventions to backends (#9828)
* move ensure_dtype_not_object from conventions to backends
* add whats-new.rst entry
kmuehlbauer authored Dec 2, 2024
1 parent 7fd572d commit 14a544c
Showing 7 changed files with 142 additions and 125 deletions.
4 changes: 2 additions & 2 deletions doc/whats-new.rst
@@ -47,8 +47,8 @@ Documentation

Internal Changes
~~~~~~~~~~~~~~~~


- Move non-CF related ``ensure_dtype_not_object`` from conventions to backends (:pull:`9828`).
  By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.

.. _whats-new.2024.11.0:

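For downstream code, the helper now lives in xarray.backends.common rather than xarray.conventions. A minimal sketch of the new import path (the example variable is hypothetical, for illustration only):

    import numpy as np
    from xarray import Variable
    from xarray.backends.common import ensure_dtype_not_object

    var = Variable(("x",), np.array(["foo", None], dtype=object))
    var = ensure_dtype_not_object(var, name="x")  # coerced to a variable-length str dtype, None filled with ""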
111 changes: 108 additions & 3 deletions xarray/backends/common.py
@@ -4,29 +4,36 @@
import os
import time
import traceback
from collections.abc import Iterable, Mapping, Sequence
from collections.abc import Hashable, Iterable, Mapping, Sequence
from glob import glob
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, overload
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, Union, overload

import numpy as np
import pandas as pd

from xarray.coding import strings, variables
from xarray.coding.variables import SerializationWarning
from xarray.conventions import cf_encoder
from xarray.core import indexing
from xarray.core.datatree import DataTree
from xarray.core.datatree import DataTree, Variable
from xarray.core.types import ReadBuffer
from xarray.core.utils import (
    FrozenDict,
    NdimSizeLenMixin,
    attempt_import,
    emit_user_level_warning,
    is_remote_uri,
)
from xarray.namedarray.parallelcompat import get_chunked_array_type
from xarray.namedarray.pycompat import is_chunked_array
from xarray.namedarray.utils import is_duck_dask_array

if TYPE_CHECKING:
    from xarray.core.dataset import Dataset
    from xarray.core.types import NestedSequence

T_Name = Union[Hashable, None]

# Create a logger object, but don't add any handlers. Leave that to user code.
logger = logging.getLogger(__name__)

@@ -527,13 +534,111 @@ def set_dimensions(self, variables, unlimited_dims=None):
            self.set_dimension(dim, length, is_unlimited)


def _infer_dtype(array, name=None):
    """Given an object array with no missing values, infer its dtype from all elements."""
    if array.dtype.kind != "O":
        raise TypeError("_infer_dtype must be called on a dtype=object array")

    if array.size == 0:
        return np.dtype(float)

    native_dtypes = set(np.vectorize(type, otypes=[object])(array.ravel()))
    if len(native_dtypes) > 1 and native_dtypes != {bytes, str}:
        raise ValueError(
            "unable to infer dtype on variable {!r}; object array "
            "contains mixed native types: {}".format(
                name, ", ".join(x.__name__ for x in native_dtypes)
            )
        )

    element = array[(0,) * array.ndim]
    # We use the base types to avoid subclasses of bytes and str (which might
    # not play nice with e.g. hdf5 datatypes), such as those from numpy
    if isinstance(element, bytes):
        return strings.create_vlen_dtype(bytes)
    elif isinstance(element, str):
        return strings.create_vlen_dtype(str)

    dtype = np.array(element).dtype
    if dtype.kind != "O":
        return dtype

    raise ValueError(
        f"unable to infer dtype on variable {name!r}; xarray "
        "cannot serialize arbitrary Python objects"
    )
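# Illustrative behavior of _infer_dtype (editor's sketch, not part of the diff;
# assumes plain NumPy object arrays):
#   _infer_dtype(np.array([1.0, 2.0], dtype=object))   -> dtype('float64')
#   _infer_dtype(np.array(["a", "bc"], dtype=object))  -> variable-length str dtype
#   _infer_dtype(np.array([], dtype=object))           -> dtype('float64')
#   _infer_dtype(np.array([1, "a"], dtype=object))     -> ValueError (mixed types)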


def _copy_with_dtype(data, dtype: np.typing.DTypeLike):
    """Create a copy of an array with the given dtype.

    We use this instead of np.array() to ensure that custom object dtypes end
    up on the resulting array.
    """
    result = np.empty(data.shape, dtype)
    result[...] = data
    return result
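# Editor's illustration (not part of the diff): the pre-allocate-and-assign
# pattern keeps a custom vlen dtype, metadata included, where np.array(...)
# might not carry it over, e.g.
#   vlen_str = strings.create_vlen_dtype(str)
#   arr = _copy_with_dtype(np.array(["a", "b"], dtype=object), dtype=vlen_str)
#   assert arr.dtype.metadata is not None  # the vlen marker survives the copy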


def ensure_dtype_not_object(var: Variable, name: T_Name = None) -> Variable:
    if var.dtype.kind == "O":
        dims, data, attrs, encoding = variables.unpack_for_encoding(var)

        # leave vlen dtypes unchanged
        if strings.check_vlen_dtype(data.dtype) is not None:
            return var

        if is_duck_dask_array(data):
            emit_user_level_warning(
                f"variable {name} has data in the form of a dask array with "
                "dtype=object, which means it is being loaded into memory "
                "to determine a data type that can be safely stored on disk. "
                "To avoid this, coerce this variable to a fixed-size dtype "
                "with astype() before saving it.",
                category=SerializationWarning,
            )
            data = data.compute()

        missing = pd.isnull(data)
        if missing.any():
            # nb. this will fail for dask.array data
            non_missing_values = data[~missing]
            inferred_dtype = _infer_dtype(non_missing_values, name)

            # There is no safe bit-pattern for NA in typical binary string
            # formats, so we can't set a fill_value. Unfortunately, this means
            # we can't distinguish between missing values and empty strings.
            fill_value: bytes | str
            if strings.is_bytes_dtype(inferred_dtype):
                fill_value = b""
            elif strings.is_unicode_dtype(inferred_dtype):
                fill_value = ""
            else:
                # insist on using float for numeric values
                if not np.issubdtype(inferred_dtype, np.floating):
                    inferred_dtype = np.dtype(float)
                fill_value = inferred_dtype.type(np.nan)

            data = _copy_with_dtype(data, dtype=inferred_dtype)
            data[missing] = fill_value
        else:
            data = _copy_with_dtype(data, dtype=_infer_dtype(data, name))

        assert data.dtype.kind != "O" or data.dtype.metadata
        var = Variable(dims, data, attrs, encoding, fastpath=True)
    return var
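# Editor's sketch of the behavior (hypothetical variable, not part of the diff):
#   var = Variable(("x",), np.array(["foo", None], dtype=object))
#   out = ensure_dtype_not_object(var, name="x")
#   out has a variable-length str dtype and the missing entry is filled with ""
#   (NA and the empty string become indistinguishable, as noted above).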


class WritableCFDataStore(AbstractWritableDataStore):
    __slots__ = ()

    def encode(self, variables, attributes):
        # All NetCDF files get CF encoded by default; without this, attempting
        # to write times, for example, would fail.
        variables, attributes = cf_encoder(variables, attributes)
        variables = {
            k: ensure_dtype_not_object(v, name=k) for k, v in variables.items()
        }
        variables = {k: self.encode_variable(v) for k, v in variables.items()}
        attributes = {k: self.encode_attribute(v) for k, v in attributes.items()}
        return variables, attributes
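To see this pipeline from the user side, a hedged sketch (the file name is illustrative; to_netcdf drives a WritableCFDataStore subclass such as the netCDF4 store):

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({"s": ("x", np.array(["a", "b"], dtype=object))})
    ds.to_netcdf("strings.nc")  # cf_encoder runs first, then ensure_dtype_not_object coerces the object dtype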
2 changes: 2 additions & 0 deletions xarray/backends/zarr.py
@@ -19,6 +19,7 @@
    _encode_variable_name,
    _normalize_path,
    datatree_from_dict_with_io_cleanup,
    ensure_dtype_not_object,
)
from xarray.backends.store import StoreBackendEntrypoint
from xarray.core import indexing
@@ -507,6 +508,7 @@ def encode_zarr_variable(var, needs_copy=True, name=None):
    """

    var = conventions.encode_cf_variable(var, name=name)
    var = ensure_dtype_not_object(var, name=name)

    # zarr allows unicode, but not variable-length strings, so it's both
    # simpler and more compact to always encode as UTF-8 explicitly.
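The zarr backend encodes variables through encode_zarr_variable rather than WritableCFDataStore.encode, hence the separate call above. A hedged user-level sketch of the same coercion on the zarr path (store path illustrative):

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({"s": ("x", np.array(["a", "b"], dtype=object))})
    ds.to_zarr("strings.zarr")  # object dtype coerced before zarr-specific encoding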
100 changes: 0 additions & 100 deletions xarray/conventions.py
@@ -6,7 +6,6 @@
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union

import numpy as np
import pandas as pd

from xarray.coding import strings, times, variables
from xarray.coding.variables import SerializationWarning, pop_to
@@ -50,41 +49,6 @@
T_DatasetOrAbstractstore = Union[Dataset, AbstractDataStore]


def _infer_dtype(array, name=None):
    """Given an object array with no missing values, infer its dtype from all elements."""
    if array.dtype.kind != "O":
        raise TypeError("_infer_dtype must be called on a dtype=object array")

    if array.size == 0:
        return np.dtype(float)

    native_dtypes = set(np.vectorize(type, otypes=[object])(array.ravel()))
    if len(native_dtypes) > 1 and native_dtypes != {bytes, str}:
        raise ValueError(
            "unable to infer dtype on variable {!r}; object array "
            "contains mixed native types: {}".format(
                name, ", ".join(x.__name__ for x in native_dtypes)
            )
        )

    element = array[(0,) * array.ndim]
    # We use the base types to avoid subclasses of bytes and str (which might
    # not play nice with e.g. hdf5 datatypes), such as those from numpy
    if isinstance(element, bytes):
        return strings.create_vlen_dtype(bytes)
    elif isinstance(element, str):
        return strings.create_vlen_dtype(str)

    dtype = np.array(element).dtype
    if dtype.kind != "O":
        return dtype

    raise ValueError(
        f"unable to infer dtype on variable {name!r}; xarray "
        "cannot serialize arbitrary Python objects"
    )


def ensure_not_multiindex(var: Variable, name: T_Name = None) -> None:
    # only the pandas multi-index dimension coordinate cannot be serialized (tuple values)
    if isinstance(var._data, indexing.PandasMultiIndexingAdapter):
@@ -99,67 +63,6 @@ def ensure_not_multiindex(var: Variable, name: T_Name = None) -> None:
            )


def _copy_with_dtype(data, dtype: np.typing.DTypeLike):
    """Create a copy of an array with the given dtype.

    We use this instead of np.array() to ensure that custom object dtypes end
    up on the resulting array.
    """
    result = np.empty(data.shape, dtype)
    result[...] = data
    return result


def ensure_dtype_not_object(var: Variable, name: T_Name = None) -> Variable:
    # TODO: move this from conventions to backends? (it's not CF related)
    if var.dtype.kind == "O":
        dims, data, attrs, encoding = variables.unpack_for_encoding(var)

        # leave vlen dtypes unchanged
        if strings.check_vlen_dtype(data.dtype) is not None:
            return var

        if is_duck_dask_array(data):
            emit_user_level_warning(
                f"variable {name} has data in the form of a dask array with "
                "dtype=object, which means it is being loaded into memory "
                "to determine a data type that can be safely stored on disk. "
                "To avoid this, coerce this variable to a fixed-size dtype "
                "with astype() before saving it.",
                category=SerializationWarning,
            )
            data = data.compute()

        missing = pd.isnull(data)
        if missing.any():
            # nb. this will fail for dask.array data
            non_missing_values = data[~missing]
            inferred_dtype = _infer_dtype(non_missing_values, name)

            # There is no safe bit-pattern for NA in typical binary string
            # formats, so we can't set a fill_value. Unfortunately, this means
            # we can't distinguish between missing values and empty strings.
            fill_value: bytes | str
            if strings.is_bytes_dtype(inferred_dtype):
                fill_value = b""
            elif strings.is_unicode_dtype(inferred_dtype):
                fill_value = ""
            else:
                # insist on using float for numeric values
                if not np.issubdtype(inferred_dtype, np.floating):
                    inferred_dtype = np.dtype(float)
                fill_value = inferred_dtype.type(np.nan)

            data = _copy_with_dtype(data, dtype=inferred_dtype)
            data[missing] = fill_value
        else:
            data = _copy_with_dtype(data, dtype=_infer_dtype(data, name))

        assert data.dtype.kind != "O" or data.dtype.metadata
        var = Variable(dims, data, attrs, encoding, fastpath=True)
    return var


def encode_cf_variable(
    var: Variable, needs_copy: bool = True, name: T_Name = None
) -> Variable:
@@ -196,9 +99,6 @@ def encode_cf_variable(
    ]:
        var = coder.encode(var, name=name)

    # TODO(kmuehlbauer): check if ensure_dtype_not_object can be moved to backends:
    var = ensure_dtype_not_object(var, name=name)

    for attr_name in CF_RELATED_DATA:
        pop_to(var.encoding, var.attrs, attr_name)
    return var
16 changes: 16 additions & 0 deletions xarray/tests/test_backends.py
@@ -1400,6 +1400,22 @@ def test_multiindex_not_implemented(self) -> None:
        with self.roundtrip(ds_reset) as actual:
            assert_identical(actual, ds_reset)

    @requires_dask
    def test_string_object_warning(self) -> None:
        original = Dataset(
            {
                "x": (
                    [
                        "y",
                    ],
                    np.array(["foo", "bar"], dtype=object),
                )
            }
        ).chunk()
        with pytest.warns(SerializationWarning, match="dask array with dtype=object"):
            with self.roundtrip(original) as actual:
                assert_identical(original, actual)


class NetCDFBase(CFEncodedBase):
    """Tests for all netCDF3 and netCDF4 backends."""
15 changes: 14 additions & 1 deletion xarray/tests/test_backends_common.py
@@ -1,8 +1,9 @@
from __future__ import annotations

import numpy as np
import pytest

from xarray.backends.common import robust_getitem
from xarray.backends.common import _infer_dtype, robust_getitem


class DummyFailure(Exception):
@@ -30,3 +31,15 @@ def test_robust_getitem() -> None:
    array = DummyArray(failures=3)
    with pytest.raises(DummyFailure):
        robust_getitem(array, ..., catch=DummyFailure, initial_delay=1, max_retries=2)


@pytest.mark.parametrize(
    "data",
    [
        np.array([["ab", "cdef", b"X"], [1, 2, "c"]], dtype=object),
        np.array([["x", 1], ["y", 2]], dtype="object"),
    ],
)
def test_infer_dtype_error_on_mixed_types(data):
    with pytest.raises(ValueError, match="unable to infer dtype on variable"):
        _infer_dtype(data, "test")
19 changes: 0 additions & 19 deletions xarray/tests/test_conventions.py
@@ -249,13 +249,6 @@ def test_emit_coordinates_attribute_in_encoding(self) -> None:
        assert enc["b"].attrs.get("coordinates") == "t"
        assert "coordinates" not in enc["b"].encoding

    @requires_dask
    def test_string_object_warning(self) -> None:
        original = Variable(("x",), np.array(["foo", "bar"], dtype=object)).chunk()
        with pytest.warns(SerializationWarning, match="dask array with dtype=object"):
            encoded = conventions.encode_cf_variable(original)
            assert_identical(original, encoded)


@requires_cftime
class TestDecodeCF:
@@ -593,18 +586,6 @@ def test_encoding_kwarg_fixed_width_string(self) -> None:
        pass


@pytest.mark.parametrize(
    "data",
    [
        np.array([["ab", "cdef", b"X"], [1, 2, "c"]], dtype=object),
        np.array([["x", 1], ["y", 2]], dtype="object"),
    ],
)
def test_infer_dtype_error_on_mixed_types(data):
    with pytest.raises(ValueError, match="unable to infer dtype on variable"):
        conventions._infer_dtype(data, "test")


class TestDecodeCFVariableWithArrayUnits:
    def test_decode_cf_variable_with_array_units(self) -> None:
        v = Variable(["t"], [1, 2, 3], {"units": np.array(["foobar"], dtype=object)})
