Skip to content

Commit

Permalink
test: fix types 2
Browse files Browse the repository at this point in the history
  • Loading branch information
bonjourmauko committed Dec 8, 2023
1 parent cf1b176 commit 237ac2c
Show file tree
Hide file tree
Showing 5 changed files with 203 additions and 53 deletions.
12 changes: 11 additions & 1 deletion openfisca_core/simulations/helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from __future__ import annotations

from typing import Type, Union

from openfisca_core.errors import SituationParsingError

from .typing import FullyDefinedParamsWithoutShortcut


def calculate_output_add(simulation, variable_name: str, period):
return simulation.calculate_add(variable_name, period)
Expand All @@ -9,7 +15,11 @@ def calculate_output_divide(simulation, variable_name: str, period):
return simulation.calculate_divide(variable_name, period)


def check_type(input, input_type, path=None):
def check_type(
input: FullyDefinedParamsWithoutShortcut,
input_type: Type[Union[dict, list, str]],
path: list[str] | None = None,
) -> None:
json_type_map = {
dict: "Object",
list: "Array",
Expand Down
15 changes: 8 additions & 7 deletions openfisca_core/simulations/simulation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

from openfisca_core.types import Population, TaxBenefitSystem, Variable
from typing import Dict, NamedTuple, Optional, Set
from typing import Dict, NamedTuple, Optional, Set, Union

import tempfile
import warnings
Expand All @@ -11,20 +10,22 @@
from openfisca_core import commons, errors, indexed_enums, periods, tracers
from openfisca_core import warnings as core_warnings

from .typing import GroupPopulation, SinglePopulation, TaxBenefitSystem, Variable


class Simulation:
"""
Represents a simulation, and handles the calculation logic
"""

tax_benefit_system: TaxBenefitSystem
populations: Dict[str, Population]
populations: Dict[str, Union[SinglePopulation, GroupPopulation]]
invalidated_caches: Set[Cache]

def __init__(
self,
tax_benefit_system: TaxBenefitSystem,
populations: Dict[str, Population],
populations: Dict[str, Union[SinglePopulation, GroupPopulation]],
):
"""
This constructor is reserved for internal use; see :any:`SimulationBuilder`,
Expand Down Expand Up @@ -530,7 +531,7 @@ def set_input(self, variable_name: str, period, value):
return
self.get_holder(variable_name).set_input(period, value)

def get_variable_population(self, variable_name: str) -> Population:
def get_variable_population(self, variable_name: str) -> GroupPopulation:
variable: Optional[Variable]

variable = self.tax_benefit_system.get_variable(
Expand All @@ -542,7 +543,7 @@ def get_variable_population(self, variable_name: str) -> Population:

return self.populations[variable.entity.key]

def get_population(self, plural: Optional[str] = None) -> Optional[Population]:
def get_population(self, plural: Optional[str] = None) -> Optional[GroupPopulation]:
return next(
(
population
Expand All @@ -555,7 +556,7 @@ def get_population(self, plural: Optional[str] = None) -> Optional[Population]:
def get_entity(
self,
plural: Optional[str] = None,
) -> Optional[Population]:
) -> Optional[GroupPopulation]:
population = self.get_population(plural)
return population and population.entity

Expand Down
144 changes: 106 additions & 38 deletions openfisca_core/simulations/simulation_builder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import typing
from collections.abc import Iterable, Sequence
from numpy.typing import NDArray as Array
from typing import Any, Optional

import copy

Expand All @@ -13,7 +15,22 @@
from . import helpers
from ._axis import _Axis
from .simulation import Simulation
from .typing import AxisParams, Entity, GroupPopulation, Role
from .typing import (
AxesParams,
AxisParams,
Entity,
FullyDefinedParams,
FullyDefinedParamsWithoutAxes,
FullyDefinedParamsWithoutShortcut,
GroupEntity,
GroupEntityParams,
GroupPopulation,
Role,
SingleEntity,
SingleEntityParams,
TaxBenefitSystem,
VariableParams,
)


class SimulationBuilder:
Expand All @@ -25,7 +42,7 @@ class SimulationBuilder:

#: JSON input - Memory of known input values. Indexed by variable or
#: axis name.
input_buffer: dict[str, dict[str, Array]] = {}
input_buffer: dict[str, dict[str, Array[Any]]] = {}

#: ?
populations: dict[str, GroupPopulation] = {}
Expand Down Expand Up @@ -63,17 +80,20 @@ class SimulationBuilder:
#: ?
axes_roles: dict[str, list[Role]]

def __init__(self):
self.input_buffer: dict[str, dict[str, Array]] = {}
self.entity_counts: dict[str, int] = {}
def __init__(self) -> None:
self.input_buffer = {}
self.entity_counts = {}
self.axes = [[]]
self.axes_entity_counts: dict[str, int] = {}
self.axes_entity_ids: dict[str, list[str]] = {}
self.axes_memberships: dict[str, list[int]] = {}
self.axes_roles: dict[str, list[int]] = {}
self.axes_entity_counts = {}
self.axes_entity_ids = {}
self.axes_memberships = {}
self.axes_roles = {}


def build_from_dict(self, tax_benefit_system, input_dict):
def build_from_dict(
self,
tax_benefit_system: TaxBenefitSystem,
input_dict: FullyDefinedParams | VariableParams,
) -> Simulation:
"""
Build a simulation from ``input_dict``
Expand All @@ -83,26 +103,36 @@ def build_from_dict(self, tax_benefit_system, input_dict):
:return: A :any:`Simulation`
"""

input_dict = self.explicit_singular_entities(tax_benefit_system, input_dict)
if any(
key in tax_benefit_system.entities_plural() for key in input_dict.keys()
):
return self.build_from_entities(tax_benefit_system, input_dict)
fully_defined_params = typing.cast(FullyDefinedParams, input_dict)
fully_defined_params_without_shortcut = self.explicit_singular_entities(
tax_benefit_system, fully_defined_params
)
return self.build_from_entities(
tax_benefit_system, fully_defined_params_without_shortcut
)
else:
return self.build_from_variables(tax_benefit_system, input_dict)
variable_params = typing.cast(VariableParams, input_dict)
return self.build_from_variables(tax_benefit_system, variable_params)

def build_from_entities(self, tax_benefit_system, input_dict):
def build_from_entities(
self,
tax_benefit_system: TaxBenefitSystem,
input_dict: FullyDefinedParamsWithoutShortcut,
) -> Simulation:
"""
Build a simulation from a Python dict ``input_dict`` fully specifying entities.
Examples:
>>> simulation_builder.build_from_entities({
'persons': {'Javier': { 'salary': {'2018-11': 2000}}},
'households': {'household': {'parents': ['Javier']}}
})
>>> {
... 'persons': {'Javier': { 'salary': {'2018-11': 2000}}},
... 'households': {'household': {'parents': ['Javier']}}
... }
"""
input_dict = copy.deepcopy(input_dict)
fully_defined_params = copy.deepcopy(input_dict)

simulation = Simulation(
tax_benefit_system, tax_benefit_system.instantiate_entities()
Expand All @@ -114,12 +144,20 @@ def build_from_entities(self, tax_benefit_system, input_dict):
variable_name, simulation.get_variable_population(variable_name).entity
)

helpers.check_type(input_dict, dict, ["error"])
axes = input_dict.pop("axes", None)
helpers.check_type(fully_defined_params, dict, ["error"])
axes = typing.cast(Optional[AxesParams], fully_defined_params.get("axes", None))
full_defined_params_without_axes = typing.cast(
FullyDefinedParamsWithoutAxes,
{
key: value
for key, value in fully_defined_params.items()
if key != "axes"
},
)

unexpected_entities = [
entity
for entity in input_dict
for entity in full_defined_params_without_axes
if entity not in tax_benefit_system.entities_plural()
]
if unexpected_entities:
Expand All @@ -137,20 +175,32 @@ def build_from_entities(self, tax_benefit_system, input_dict):
", ".join(tax_benefit_system.entities_plural()),
),
)
persons_json = input_dict.get(tax_benefit_system.person_entity.plural, None)

person_entity: SingleEntity = tax_benefit_system.person_entity

if person_entity.plural is None:
raise ValueError("#TODO")

persons_json = typing.cast(
Optional[SingleEntityParams],
full_defined_params_without_axes.get(person_entity.plural, None),
)

if not persons_json:
raise errors.SituationParsingError(
[tax_benefit_system.person_entity.plural],
[person_entity.plural],
"No {0} found. At least one {0} must be defined to run a simulation.".format(
tax_benefit_system.person_entity.key
person_entity.key
),
)

persons_ids = self.add_person_entity(simulation.persons.entity, persons_json)

for entity_class in tax_benefit_system.group_entities:
instances_json = input_dict.get(entity_class.plural)
instances_json: GroupEntityParams | None = (
full_defined_params_without_axes.get(entity_class.plural)
)

if instances_json is not None:
self.add_group_entity(
self.persons_plural, persons_ids, entity_class, instances_json
Expand Down Expand Up @@ -182,17 +232,17 @@ def build_from_entities(self, tax_benefit_system, input_dict):

return simulation

def build_from_variables(self, tax_benefit_system, input_dict):
def build_from_variables(
self, tax_benefit_system: TaxBenefitSystem, input_dict: VariableParams
) -> Simulation:
"""
Build a simulation from a Python dict ``input_dict`` describing variables values without expliciting entities.
This method uses :any:`build_default_simulation` to infer an entity structure
Example:
>>> {'salary': {'2016-10': 12000}}
>>> simulation_builder.build_from_variables(
{'salary': {'2016-10': 12000}}
)
"""
count = helpers._get_person_count(input_dict)
simulation = self.build_default_simulation(tax_benefit_system, count)
Expand Down Expand Up @@ -275,7 +325,9 @@ def join_with_persons(
def build(self, tax_benefit_system):
return Simulation(tax_benefit_system, self.populations)

def explicit_singular_entities(self, tax_benefit_system, input_dict):
def explicit_singular_entities(
self, tax_benefit_system: TaxBenefitSystem, input_dict: FullyDefinedParams
) -> FullyDefinedParamsWithoutShortcut:
"""
Preprocess ``input_dict`` to explicit entities defined using the single-entity shortcut
Expand Down Expand Up @@ -309,6 +361,9 @@ def add_person_entity(self, entity, instances_json):
"""
Add the simulation's instances of the persons entity as described in ``instances_json``.
"""
if entity.plural is None:
raise ValueError("#TODO")

helpers.check_type(instances_json, dict, [entity.plural])
entity_ids = list(map(str, instances_json.keys()))
self.persons_plural = entity.plural
Expand All @@ -321,21 +376,34 @@ def add_person_entity(self, entity, instances_json):

return self.get_ids(entity.plural)

def add_default_group_entity(self, persons_ids, entity):
def add_default_group_entity(
self, persons_ids: list[str], entity: GroupEntity
) -> None:
if entity.plural is None:
raise ValueError("#TODO")

persons_count = len(persons_ids)
roles = list(entity.flattened_roles)
self.entity_ids[entity.plural] = persons_ids
self.entity_counts[entity.plural] = persons_count
self.memberships[entity.plural] = list(
numpy.arange(0, persons_count, dtype=numpy.int32)
)
self.roles[entity.plural] = list(
numpy.repeat(entity.flattened_roles[0], persons_count)
)
self.roles[entity.plural] = [roles[0]] * persons_count

def add_group_entity(self, persons_plural, persons_ids, entity, instances_json):
def add_group_entity(
self,
persons_plural: str,
persons_ids: list[str],
entity: GroupEntity,
instances_json,
) -> None:
"""
Add all instances of one of the model's entities as described in ``instances_json``.
"""
if entity.plural is None:
raise ValueError("#TODO")

helpers.check_type(instances_json, dict, [entity.plural])
entity_ids = list(map(str, instances_json.keys()))

Expand Down Expand Up @@ -682,7 +750,7 @@ def expand_axes(self) -> None:
axis_index = axis.index
axis_period = axis.period or self.default_period
axis_name = axis.name
variable = axis_entity.get_variable(axis_name)
variable = axis_entity.get_variable(axis_name, check_existence=True)
array = self.get_input(axis_name, str(axis_period))
if array is None:
array = variable.default_array(
Expand Down
Loading

0 comments on commit 237ac2c

Please sign in to comment.