diff --git a/src/jax_loop_utils/metrics.py b/src/jax_loop_utils/metrics.py index 64f99b4..7c63c6b 100644 --- a/src/jax_loop_utils/metrics.py +++ b/src/jax_loop_utils/metrics.py @@ -57,17 +57,17 @@ def evaluate(variables_p, test_ds): """ from __future__ import annotations -from collections.abc import Mapping, Sequence -from typing import Any, TypeVar, Protocol -from absl import logging +from collections.abc import Mapping, Sequence +from typing import Any, Protocol, TypeVar -from jax_loop_utils.internal import utils -import jax_loop_utils.values -from jax_loop_utils.internal import flax import jax import jax.numpy as jnp import numpy as np +from absl import logging + +import jax_loop_utils.values +from jax_loop_utils.internal import flax, utils Array = jax.Array ArrayLike = jax.typing.ArrayLike @@ -663,7 +663,11 @@ def unreplicate(self: C) -> C: # Sentinel to make LastValue.__init__ support tree manipulations that use None. -_default = object() +class _DefaultSentinel: + pass + + +_default = _DefaultSentinel() @flax.struct.dataclass @@ -683,11 +687,11 @@ class LastValue(Metric): total: jnp.ndarray count: jnp.ndarray - def __init__( # pytype: disable=missing-parameter # jnp-array + def __init__( self, - total: jnp.ndarray | _default = _default, - count: jnp.ndarray | _default = _default, - value: jnp.ndarray | _default = _default, + total: jnp.ndarray | _DefaultSentinel = _default, + count: jnp.ndarray | _DefaultSentinel = _default, + value: jnp.ndarray | _DefaultSentinel = _default, ): """Backward compatibility constructor.