Skip to content

Commit

Permalink
improve type annotations in metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
garymm committed Dec 6, 2024
1 parent 3365d12 commit bdcaece
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions src/jax_loop_utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,17 @@ def evaluate(variables_p, test_ds):
"""

from __future__ import annotations
from collections.abc import Mapping, Sequence
from typing import Any, TypeVar, Protocol

from absl import logging
from collections.abc import Mapping, Sequence
from typing import Any, Protocol, TypeVar

from jax_loop_utils.internal import utils
import jax_loop_utils.values
from jax_loop_utils.internal import flax
import jax
import jax.numpy as jnp
import numpy as np
from absl import logging

import jax_loop_utils.values
from jax_loop_utils.internal import flax, utils

Array = jax.Array
ArrayLike = jax.typing.ArrayLike
Expand Down Expand Up @@ -663,7 +663,11 @@ def unreplicate(self: C) -> C:


# Sentinel to make LastValue.__init__ support tree manipulations that use None.
_default = object()
class _DefaultSentinel:
pass


_default = _DefaultSentinel()


@flax.struct.dataclass
Expand All @@ -683,11 +687,11 @@ class LastValue(Metric):
total: jnp.ndarray
count: jnp.ndarray

def __init__( # pytype: disable=missing-parameter # jnp-array
def __init__(
self,
total: jnp.ndarray | _default = _default,
count: jnp.ndarray | _default = _default,
value: jnp.ndarray | _default = _default,
total: jnp.ndarray | _DefaultSentinel = _default,
count: jnp.ndarray | _DefaultSentinel = _default,
value: jnp.ndarray | _DefaultSentinel = _default,
):
"""Backward compatibility constructor.
Expand Down

0 comments on commit bdcaece

Please sign in to comment.