Skip to content

Commit

Permalink
Fix computation log types
Browse files Browse the repository at this point in the history
  • Loading branch information
bonjourmauko committed Dec 1, 2022
1 parent d9d3501 commit 40233ec
Show file tree
Hide file tree
Showing 19 changed files with 202 additions and 179 deletions.
14 changes: 7 additions & 7 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


@nox.session(python = ("3.9", "3.8", "3.7"), tags = ("lint", "style"))
@nox.parametrize("numpy", ("1.22", "1.20", "1.21"))
@nox.parametrize("numpy", ("1.23", "1.22", "1.21"))
def style(session, numpy):
"""Run tests."""

Expand All @@ -21,7 +21,7 @@ def style(session, numpy):


@nox.session(python = ("3.9", "3.8", "3.7"), tags = ("lint", "docs"))
@nox.parametrize("numpy", ("1.22", "1.20", "1.21"))
@nox.parametrize("numpy", ("1.23", "1.22", "1.21"))
def docs(session, numpy):
"""Run tests."""

Expand All @@ -35,7 +35,7 @@ def docs(session, numpy):


@nox.session(python = ("3.9", "3.8", "3.7"), tags = ("lint", "mypy"))
@nox.parametrize("numpy", ("1.22", "1.20", "1.21"))
@nox.parametrize("numpy", ("1.23", "1.22", "1.21"))
def mypy(session, numpy):
"""Run tests."""

Expand All @@ -49,7 +49,7 @@ def mypy(session, numpy):


@nox.session(python = ("3.9", "3.8", "3.7"), tags = ("lint", "mypy-hxc"))
@nox.parametrize("numpy", ("1.22", "1.20", "1.21"))
@nox.parametrize("numpy", ("1.23", "1.22", "1.21"))
def mypy_hxc(session, numpy):
"""Run tests."""

Expand All @@ -63,7 +63,7 @@ def mypy_hxc(session, numpy):


@nox.session(python = ("3.9", "3.8", "3.7"), tags = ("test", "test-core"))
@nox.parametrize("numpy", ("1.22", "1.20", "1.21"))
@nox.parametrize("numpy", ("1.23", "1.22", "1.21"))
def test_core(session, numpy):
"""Run tests."""

Expand All @@ -79,7 +79,7 @@ def test_core(session, numpy):


@nox.session(python = ("3.9", "3.8", "3.7"), tags = ("test", "test-country"))
@nox.parametrize("numpy", ("1.22", "1.20", "1.21"))
@nox.parametrize("numpy", ("1.23", "1.22", "1.21"))
def test_country(session, numpy):
"""Run tests."""

Expand All @@ -95,7 +95,7 @@ def test_country(session, numpy):


@nox.session(python = ("3.9", "3.8", "3.7"), tags = ("test", "test-extension"))
@nox.parametrize("numpy", ("1.22", "1.20", "1.21"))
@nox.parametrize("numpy", ("1.23", "1.22", "1.21"))
def test_extension(session, numpy):
"""Run tests."""

Expand Down
5 changes: 3 additions & 2 deletions openfisca_core/commons/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
* :func:`.average_rate`
* :func:`.concat`
* :func:`.empty_clone`
* :func:`.flatten`
* :func:`.marginal_rate`
* :func:`.stringify_array`
* :func:`.switch`
Expand Down Expand Up @@ -53,11 +54,11 @@
# Official Public API

from .formulas import apply_thresholds, concat, switch # noqa: F401
from .misc import empty_clone, stringify_array # noqa: F401
from .misc import empty_clone, flatten, stringify_array # noqa: F401
from .rates import average_rate, marginal_rate # noqa: F401

__all__ = ["apply_thresholds", "concat", "switch"]
__all__ = ["empty_clone", "stringify_array", *__all__]
__all__ = ["empty_clone", "flatten", "stringify_array", *__all__]
__all__ = ["average_rate", "marginal_rate", *__all__]

# Deprecated
Expand Down
11 changes: 5 additions & 6 deletions openfisca_core/commons/formulas.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Sequence
from typing import Any, Dict, Sequence, Union

import numpy

Expand Down Expand Up @@ -29,21 +29,20 @@ def apply_thresholds(
Examples:
>>> input = numpy.array([4, 5, 6, 7, 8])
>>> thresholds =
[5, 7]
>>> thresholds = [5, 7]
>>> choices = [10, 15, 20]
>>> apply_thresholds(input, thresholds, choices)
array([10, 10, 15, 15, 20])
"""

condlist: Sequence[numpy.bool_]
condlist: Sequence[Union[bool, numpy.bool_]]
condlist = [input <= threshold for threshold in thresholds]

if len(condlist) == len(choices) - 1:
# If a choice is provided for input > highest threshold, last condition
# must be true to return it.
condlist += numpy.array([True])
condlist += [True]

assert len(condlist) == len(choices), \
" ".join([
Expand Down Expand Up @@ -119,7 +118,7 @@ def switch(

condlist = [
conditions == condition
for condition in tuple(value_by_condition.keys())
for condition in value_by_condition.keys()
]

return numpy.select(condlist, tuple(value_by_condition.values()))
35 changes: 30 additions & 5 deletions openfisca_core/commons/misc.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from typing import TypeVar
from typing import Any, Iterator, Optional, Sequence, TypeVar

import numpy
import itertools

from openfisca_core import types

T = TypeVar("T")


def empty_clone(original: T) -> T:
"""Creates an empty instance of the same class of the original object.
"""Create an empty instance of the same class of the original object.
Args:
original: An object to clone.
Expand Down Expand Up @@ -43,8 +45,31 @@ def empty_clone(original: T) -> T:
return new


def stringify_array(array: numpy.ndarray) -> str:
"""Generates a clean string representation of a numpy array.
def flatten(seqs: Sequence[Sequence[T]]) -> Iterator[T]:
"""Flatten a sequence of sequences.
Args:
seqs: Any sequence of sequences.
Returns:
An iterator with the values.
Examples:
>>> list(flatten([(1, 2), (3, 4)]))
[1, 2, 3, 4]
>>> list(flatten(["ab", "cd"]))
['a', 'b', 'c', 'd']
.. versionadded:: 36.0.0
"""

return itertools.chain.from_iterable(seqs)


def stringify_array(array: Optional[types.Array[Any]]) -> str:
"""Generate a clean string representation of a numpy array.
Args:
array: An array.
Expand Down
4 changes: 2 additions & 2 deletions openfisca_core/indexed_enums/enum_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _forbidden_operation(self, other: Any) -> NoReturn:
__and__ = _forbidden_operation
__or__ = _forbidden_operation

def decode(self) -> numpy.object_:
def decode(self) -> numpy.ndarray:
"""
Return the array of enum items corresponding to self.
Expand All @@ -82,7 +82,7 @@ def decode(self) -> numpy.object_:
list(self.possible_values),
)

def decode_to_str(self) -> numpy.str_:
def decode_to_str(self) -> numpy.ndarray:
"""
Return the array of string identifiers corresponding to self.
Expand Down
24 changes: 12 additions & 12 deletions openfisca_core/populations/population.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from openfisca_core import periods, projectors
from openfisca_core.holders import Holder, MemoryUsage
from openfisca_core.projectors import Projector
from openfisca_core.types import Array, Entity, Period, Role, Simulation
from openfisca_core.types import Entity, Period, Role, Simulation

from . import config

Expand All @@ -21,14 +21,14 @@ class Population:
entity: Entity
_holders: Dict[str, Holder]
count: int
ids: Array[str]
ids: numpy.ndarray

def __init__(self, entity: Entity) -> None:
self.simulation = None
self.entity = entity
self._holders = {}
self.count = 0
self.ids = []
self.ids = numpy.array([])

def clone(self, simulation: Simulation) -> Population:
result = Population(self.entity)
Expand All @@ -38,14 +38,14 @@ def clone(self, simulation: Simulation) -> Population:
result.ids = self.ids
return result

def empty_array(self) -> Array[float]:
def empty_array(self) -> numpy.ndarray:
return numpy.zeros(self.count)

def filled_array(
self,
value: Union[float, bool],
dtype: Optional[numpy.dtype] = None,
) -> Union[Array[float], Array[bool]]:
) -> numpy.ndarray:
return numpy.full(self.count, value, dtype)

def __getattr__(self, attribute: str) -> Projector:
Expand All @@ -64,7 +64,7 @@ def get_index(self, id: str) -> int:

def check_array_compatible_with_entity(
self,
array: Array[float],
array: numpy.ndarray,
) -> None:
if self.count == array.size:
return None
Expand Down Expand Up @@ -95,7 +95,7 @@ def __call__(
variable_name: str,
period: Optional[Union[int, str, Period]] = None,
options: Optional[Sequence[str]] = None,
) -> Optional[Array[float]]:
) -> Optional[Sequence[float]]:
"""
Calculate the variable ``variable_name`` for the entity and the period ``period``, using the variable formula if it exists.
Expand Down Expand Up @@ -169,7 +169,7 @@ def get_memory_usage(
})

@projectors.projectable
def has_role(self, role: Role) -> Optional[Array[bool]]:
def has_role(self, role: Role) -> Optional[Sequence[bool]]:
"""
Check if a person has a given role within its `GroupEntity`
Expand All @@ -195,10 +195,10 @@ def has_role(self, role: Role) -> Optional[Array[bool]]:
@projectors.projectable
def value_from_partner(
self,
array: Array[float],
array: numpy.ndarray,
entity: Projector,
role: Role,
) -> Optional[Array[float]]:
) -> Optional[numpy.ndarray]:
self.check_array_compatible_with_entity(array)
self.entity.check_role_validity(role)

Expand All @@ -218,9 +218,9 @@ def value_from_partner(
def get_rank(
self,
entity: Population,
criteria: Array[float],
criteria: Sequence[float],
condition: bool = True,
) -> Array[int]:
) -> numpy.ndarray:
"""
Get the rank of a person within an entity according to a criteria.
The person with rank 0 has the minimum value of criteria.
Expand Down
48 changes: 18 additions & 30 deletions openfisca_core/tracers/computation_log.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,31 @@
from __future__ import annotations

import typing
from typing import List, Optional, Union
from typing import Any, Optional, Sequence

import numpy

from .. import tracers
from openfisca_core.indexed_enums import EnumArray
import sys

if typing.TYPE_CHECKING:
from numpy.typing import ArrayLike
import numpy

Array = Union[EnumArray, ArrayLike]
from openfisca_core import commons, types


class ComputationLog:
_full_tracer: types.FullTracer

_full_tracer: tracers.FullTracer

def __init__(self, full_tracer: tracers.FullTracer) -> None:
def __init__(self, full_tracer: types.FullTracer) -> None:
self._full_tracer = full_tracer

def display(
self,
value: Optional[Array],
) -> str:
if isinstance(value, EnumArray):
def display(self, value: types.Array[Any]) -> str:
if isinstance(value, types.EnumArray):
value = value.decode_to_str()

return numpy.array2string(value, max_line_width = None)
return numpy.array2string(value, max_line_width = sys.maxsize)

def lines(
self,
aggregate: bool = False,
max_depth: Optional[int] = None,
) -> List[str]:
) -> Sequence[str]:
depth = 1

lines_by_tree = [
Expand All @@ -43,7 +34,7 @@ def lines(
in self._full_tracer.trees
]

return self._flatten(lines_by_tree)
return tuple(commons.flatten(lines_by_tree))

def print_log(self, aggregate = False, max_depth = None) -> None:
"""
Expand All @@ -67,11 +58,14 @@ def print_log(self, aggregate = False, max_depth = None) -> None:

def _get_node_log(
self,
node: tracers.TraceNode,
node: types.TraceNode,
depth: int,
aggregate: bool,
max_depth: Optional[int],
) -> List[str]:
) -> Sequence[str]:

node_log: Sequence[str]
children_log: Sequence[Sequence[str]]

if max_depth is not None and depth > max_depth:
return []
Expand All @@ -84,12 +78,12 @@ def _get_node_log(
in node.children
]

return node_log + self._flatten(children_logs)
return [*node_log, *commons.flatten(children_logs)]

def _print_line(
self,
depth: int,
node: tracers.TraceNode,
node: types.TraceNode,
aggregate: bool,
max_depth: Optional[int],
) -> str:
Expand All @@ -114,9 +108,3 @@ def _print_line(
formatted_value = self.display(value)

return f"{indent}{node.name}<{node.period}> >> {formatted_value}"

def _flatten(
self,
lists: List[List[str]],
) -> List[str]:
return [item for list_ in lists for item in list_]
Loading

0 comments on commit 40233ec

Please sign in to comment.