Skip to content

Commit

Permalink
Implement time_unit option for decode_cf_timedelta (#3)
Browse files Browse the repository at this point in the history
* Fix timedelta encoding overflow issue; always decode to ns resolution

* Implement time_unit for decode_cf_timedelta

* Reduce diff
  • Loading branch information
spencerkclark authored Jan 13, 2025
1 parent 0556376 commit ffc1828
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 25 deletions.
60 changes: 38 additions & 22 deletions xarray/coding/times.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,21 @@ def _check_date_for_units_since_refdate(
return pd.Timestamp("NaT")


def _check_timedelta_range(value, data_unit, time_unit):
if value > np.iinfo("int64").max or value < np.iinfo("int64").min:
OutOfBoundsTimedelta(f"Value {value} can't be represented as Timedelta.")
delta = value * np.timedelta64(1, data_unit)
if not np.isnan(delta):
# this will raise on dtype overflow for integer dtypes
if value.dtype.kind in "u" and not np.int64(delta) == value:
raise OutOfBoundsTimedelta(
"DType overflow in Datetime/Timedelta calculation."
)
# this will raise on overflow if delta cannot be represented with the
# resolutions supported by pandas.
pd.to_timedelta(delta)


def _align_reference_date_and_unit(
ref_date: pd.Timestamp, unit: NPDatetimeUnitOptions
) -> pd.Timestamp:
Expand Down Expand Up @@ -542,19 +557,6 @@ def decode_cf_datetime(
return reshape(dates, num_dates.shape)


def to_timedelta_unboxed(value, **kwargs):
# todo: check, if the procedure here is correct
result = pd.to_timedelta(value, **kwargs).to_numpy()
unique_timedeltas = np.unique(result[pd.notnull(result)])
unit = _netcdf_to_numpy_timeunit(_infer_time_units_from_diff(unique_timedeltas))
if unit not in {"s", "ms", "us", "ns"}:
# default to "ns", when not specified
unit = "ns"
result = result.astype(f"timedelta64[{unit}]")
assert np.issubdtype(result.dtype, "timedelta64")
return result


def to_datetime_unboxed(value, **kwargs):
result = pd.to_datetime(value, **kwargs).to_numpy()
assert np.issubdtype(result.dtype, "datetime64")
Expand Down Expand Up @@ -604,22 +606,36 @@ def _numbers_to_timedelta(
return flat_num.astype(f"timedelta64[{time_unit}]")


def decode_cf_timedelta(num_timedeltas, units: str) -> np.ndarray:
# todo: check, if this works as intended
def decode_cf_timedelta(
num_timedeltas, units: str, time_unit: str = "ns"
) -> np.ndarray:
"""Given an array of numeric timedeltas in netCDF format, convert it into a
numpy timedelta64 ["s", "ms", "us", "ns"] array.
"""
num_timedeltas = np.asarray(num_timedeltas)
unit = _netcdf_to_numpy_timeunit(units)

_check_timedelta_range(num_timedeltas.min(), unit, time_unit)
_check_timedelta_range(num_timedeltas.max(), unit, time_unit)

timedeltas = _numbers_to_timedelta(num_timedeltas, unit, "s", "timedelta")
timedeltas = pd.to_timedelta(ravel(timedeltas))

if np.isnat(timedeltas).all():
empirical_unit = time_unit
else:
empirical_unit = timedeltas.unit

if np.timedelta64(1, time_unit) > np.timedelta64(1, empirical_unit):
time_unit = empirical_unit

if time_unit not in {"s", "ms", "us", "ns"}:
raise ValueError(
f"time_unit must be one of 's', 'ms', 'us', or 'ns'. Got: {time_unit}"
)

as_unit = unit
if unit not in {"s", "ms", "us", "ns"}:
# default to "ns", when not specified
as_unit = "ns"
result = pd.to_timedelta(ravel(timedeltas)).as_unit(as_unit).to_numpy()
return reshape(result, timedeltas.shape)
result = timedeltas.as_unit(time_unit).to_numpy()
return reshape(result, num_timedeltas.shape)


def _unit_timedelta_cftime(units: str) -> timedelta:
Expand Down Expand Up @@ -700,7 +716,7 @@ def infer_timedelta_units(deltas) -> str:
{'days', 'hours', 'minutes' 'seconds'} (the first one that can evenly
divide all unique time deltas in `deltas`)
"""
deltas = to_timedelta_unboxed(ravel(np.asarray(deltas)))
deltas = ravel(deltas)
unique_timedeltas = np.unique(deltas[pd.notnull(deltas)])
return _infer_time_units_from_diff(unique_timedeltas)

Expand Down
37 changes: 34 additions & 3 deletions xarray/tests/test_coding_times.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
format_cftime_datetime,
infer_datetime_units,
infer_timedelta_units,
to_timedelta_unboxed,
)
from xarray.coding.variables import SerializationWarning
from xarray.conventions import _update_bounds_attributes, cf_encoder
Expand Down Expand Up @@ -635,7 +634,7 @@ def test_cf_timedelta(timedeltas, units, numbers) -> None:
if timedeltas == "NaT":
timedeltas = np.timedelta64("NaT", "ns")
else:
timedeltas = to_timedelta_unboxed(timedeltas)
timedeltas = pd.to_timedelta(timedeltas).to_numpy()
numbers = np.array(numbers)

expected = numbers
Expand All @@ -659,14 +658,46 @@ def test_cf_timedelta_2d() -> None:
units = "days"
numbers = np.atleast_2d([1, 2, 3])

timedeltas = np.atleast_2d(to_timedelta_unboxed(["1D", "2D", "3D"]))
timedeltas = np.atleast_2d(pd.to_timedelta(["1D", "2D", "3D"]).to_numpy())
expected = timedeltas

actual = decode_cf_timedelta(numbers, units)
assert_array_equal(expected, actual)
assert expected.dtype == actual.dtype


@pytest.mark.parametrize("encoding_unit", FREQUENCIES_TO_ENCODING_UNITS.values())
def test_decode_cf_timedelta_time_unit(time_unit, encoding_unit) -> None:
encoded = 1
encoding_unit_as_numpy = _netcdf_to_numpy_timeunit(encoding_unit)
if np.timedelta64(1, time_unit) > np.timedelta64(1, encoding_unit_as_numpy):
expected = np.timedelta64(encoded, encoding_unit_as_numpy)
else:
expected = np.timedelta64(encoded, encoding_unit_as_numpy).astype(
f"timedelta64[{time_unit}]"
)
result = decode_cf_timedelta(encoded, encoding_unit, time_unit)
assert result == expected
assert result.dtype == expected.dtype


def test_decode_cf_timedelta_time_unit_out_of_bounds(time_unit):
# Define a scale factor that will guarantee overflow with the given
# time_unit.
scale_factor = np.timedelta64(1, time_unit) // np.timedelta64(1, "ns")
encoded = scale_factor * 300 * 365
with pytest.raises(OutOfBoundsTimedelta):
decode_cf_timedelta(encoded, "days", time_unit)


def test_cf_timedelta_roundtrip_large_value(time_unit):
value = np.timedelta64(np.iinfo(np.int64).max, time_unit)
encoded, units = encode_cf_timedelta(value)
decoded = decode_cf_timedelta(encoded, units, time_unit=time_unit)
assert value == decoded
assert value.dtype == decoded.dtype


@pytest.mark.parametrize(
["deltas", "expected"],
[
Expand Down

0 comments on commit ffc1828

Please sign in to comment.