Skip to content

Commit

Permalink
Fix types in commons
Browse files Browse the repository at this point in the history
  • Loading branch information
bonjourmauko committed Dec 1, 2022
1 parent 4c1e3dd commit d9d3501
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 33 deletions.
4 changes: 2 additions & 2 deletions openfisca_core/commons/misc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import TypeVar

from openfisca_core.types import Array
import numpy

T = TypeVar("T")

Expand Down Expand Up @@ -43,7 +43,7 @@ def empty_clone(original: T) -> T:
return new


def stringify_array(array: Array) -> str:
def stringify_array(array: numpy.ndarray) -> str:
"""Generates a clean string representation of a numpy array.
Args:
Expand Down
57 changes: 27 additions & 30 deletions openfisca_core/commons/rates.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
from typing import Optional
from typing import Optional, Sequence

import numpy

from openfisca_core.types import ArrayLike, Array


def average_rate(
target: Array[float],
varying: ArrayLike[float],
trim: Optional[ArrayLike[float]] = None,
) -> Array[float]:
target: numpy.ndarray,
varying: Sequence[float],
trim: Optional[Sequence[float]] = None,
) -> numpy.ndarray:
"""Computes the average rate of a target net income.
Given a ``target`` net income, and according to the ``varying`` gross
Expand Down Expand Up @@ -41,32 +39,31 @@ def average_rate(
"""

average_rate: Array[float]

average_rate = 1 - target / varying
rate: numpy.ndarray
rate = 1 - target / varying

if trim is not None:

average_rate = numpy.where(
average_rate <= max(trim),
average_rate,
rate = numpy.where(
rate <= max(trim),
rate,
numpy.nan,
)

average_rate = numpy.where(
average_rate >= min(trim),
average_rate,
rate = numpy.where(
rate >= min(trim),
rate,
numpy.nan,
)

return average_rate
return rate


def marginal_rate(
target: Array[float],
varying: Array[float],
trim: Optional[ArrayLike[float]] = None,
) -> Array[float]:
target: numpy.ndarray,
varying: numpy.ndarray,
trim: Optional[numpy.ndarray] = None,
) -> numpy.ndarray:
"""Computes the marginal rate of a target net income.
Given a ``target`` net income, and according to the ``varying`` gross
Expand Down Expand Up @@ -98,26 +95,26 @@ def marginal_rate(
"""

marginal_rate: Array[float]
rate: numpy.ndarray

marginal_rate = (
rate = (
+ 1
- (target[:-1] - target[1:])
/ (varying[:-1] - varying[1:])
)

if trim is not None:

marginal_rate = numpy.where(
marginal_rate <= max(trim),
marginal_rate,
rate = numpy.where(
rate <= max(trim),
rate,
numpy.nan,
)

marginal_rate = numpy.where(
marginal_rate >= min(trim),
marginal_rate,
rate = numpy.where(
rate >= min(trim),
rate,
numpy.nan,
)

return marginal_rate
return rate
1 change: 0 additions & 1 deletion openfisca_core/tracers/computation_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def display(
) -> str:
if isinstance(value, EnumArray):
value = value.decode_to_str()
raise ValueError(type(value))

return numpy.array2string(value, max_line_width = None)

Expand Down

0 comments on commit d9d3501

Please sign in to comment.