Skip to content

Commit

Permalink
Revert changes to taxscales.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Mauko Quiroga committed Jan 2, 2020
1 parent ef4ae1e commit 61e6358
Showing 1 changed file with 25 additions and 40 deletions.
65 changes: 25 additions & 40 deletions openfisca_core/taxscales.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from __future__ import annotations

import copy
import itertools
import logging
Expand All @@ -11,7 +9,6 @@
from numpy import (
around,
array,
asarray,
digitize,
dot,
finfo,
Expand Down Expand Up @@ -44,7 +41,7 @@ def __init__(
class_name: str,
method_name: str,
arg_name: str,
arg_value: Union[List, ndarray]
arg_value: ndarray
) -> None:
message = [
f"'{class_name}.{method_name}' can't be run with an empty '{arg_name}':\n",
Expand Down Expand Up @@ -109,7 +106,7 @@ def __repr__(self) -> Any:
f"{self.__class__.__name__}",
)

def calc(self, _tax_base: Union[ndarray[int], ndarray[float]], _right: bool) -> Any:
def calc(self, _tax_base: ndarray, _right: bool) -> Any:
raise NotImplementedError(
"Method 'calc' is not implemented for "
f"{self.__class__.__name__}",
Expand All @@ -129,7 +126,7 @@ def multiply_thresholds(

def bracket_indices(
self,
tax_base: Union[ndarray[int], ndarray[float]],
tax_base: ndarray,
factor: float = 1.0,
round_base_decimals: Optional[int] = None,
) -> Any:
Expand Down Expand Up @@ -247,20 +244,20 @@ def multiply_thresholds(

def bracket_indices(
self,
tax_base: Union[ndarray[int], ndarray[float]],
tax_base: ndarray,
factor: float = 1.0,
round_decimals: Optional[int] = None,
) -> ndarray[int]:
) -> ndarray:
"""
Compute the relevant bracket indices for the given tax bases.
:param ndarray tax_base: Array of the tax bases.
:param float factor: Factor to apply to the thresholds of the tax scales.
:param int round_decimals: Decimals to keep when rounding thresholds.
:param tax_base: Array of the tax bases.
:param factor: Factor to apply to the thresholds of the tax scales.
:param round_decimals: Decimals to keep when rounding thresholds.
:returns: Int array with relevant bracket indices for the given tax bases.
For instance:
:example:
>>> tax_scale = AbstractRateTaxScale()
>>> tax_scale.add_bracket(0, 0)
Expand All @@ -278,7 +275,7 @@ def bracket_indices(
self.thresholds,
)

if not size(asarray(tax_base)):
if not size(array(tax_base)):
raise EmptyArgumentError(
self.__class__.__name__,
"bracket_indices",
Expand Down Expand Up @@ -342,11 +339,7 @@ def add_bracket(self, threshold: int, amount: Union[int, float]) -> None:
self.thresholds.insert(i, threshold)
self.amounts.insert(i, amount)

def calc(
self,
tax_base: Union[ndarray[int], ndarray[float]],
right: bool = False,
) -> ndarray[float]:
def calc(self, tax_base: ndarray, right: bool = False) -> ndarray:
guarded_thresholds = array([-inf] + self.thresholds + [inf])
bracket_indices = digitize(tax_base, guarded_thresholds, right = right)
guarded_amounts = array([0] + self.amounts + [0])
Expand All @@ -366,23 +359,15 @@ class MarginalAmountTaxScale(SingleAmountTaxScale):
containing the input.
"""

def calc(
self,
tax_base: Union[ndarray[int], ndarray[float]],
_right: bool = False,
) -> ndarray[float]:
def calc(self, tax_base: ndarray, _right: bool = False) -> ndarray:
base1 = tile(tax_base, (len(self.thresholds), 1)).T
thresholds1 = tile(hstack((self.thresholds, inf)), (len(tax_base), 1))
a = max_(min_(base1, thresholds1[:, 1:]) - thresholds1[:, :-1], 0)
return dot(self.amounts, a.T > 0)


class LinearAverageRateTaxScale(AbstractRateTaxScale):
def calc(
self,
tax_base: Union[ndarray[int], ndarray[float]],
_right: bool = False,
) -> ndarray[float]:
def calc(self, tax_base: ndarray, _right: bool = False) -> ndarray:
if len(self.rates) == 1:
return tax_base * self.rates[0]

Expand Down Expand Up @@ -460,20 +445,20 @@ def add_tax_scale(self, tax_scale: AbstractRateTaxScale) -> None:

def calc(
self,
tax_base: Union[ndarray[int], ndarray[float]],
tax_base: ndarray,
factor: float = 1.0,
round_base_decimals: Optional[int] = None,
) -> ndarray[float]:
) -> ndarray:
"""
Compute the tax amount for the given tax bases by applying the taxscale.
:param ndarray tax_base: Array of the tax bases.
:param float factor: Factor to apply to the thresholds of the tax scale.
:param int round_base_decimals: Decimals to keep when rounding thresholds.
:param tax_base: Array of the tax bases.
:param factor: Factor to apply to the thresholds of the tax scale.
:param round_base_decimals: Decimals to keep when rounding thresholds.
:returns: Float array with tax amount for the given tax bases.
For instance:
:example:
>>> tax_scale = MarginalRateTaxScale()
>>> tax_scale.add_bracket(0, 0)
Expand Down Expand Up @@ -529,20 +514,20 @@ def combine_bracket(

def marginal_rates(
self,
tax_base: Union[ndarray[int], ndarray[float]],
tax_base: ndarray,
factor: float = 1.0,
round_base_decimals: Optional[int] = None,
) -> ndarray[float]:
) -> ndarray:
"""
Compute the marginal tax rates relevant for the given tax bases.
:param ndarray tax_base: Array of the tax bases.
:param float factor: Factor to apply to the thresholds of the tax scale.
:param int round_base_decimals: Decimals to keep when rounding thresholds.
:param tax_base: Array of the tax bases.
:param factor: Factor to apply to the thresholds of the tax scale.
:param round_base_decimals: Decimals to keep when rounding thresholds.
:returns: Float array with relevant marginal tax rate for the given tax bases.
For instance:
:example:
>>> tax_scale = MarginalRateTaxScale()
>>> tax_scale.add_bracket(0, 0)
Expand Down

0 comments on commit 61e6358

Please sign in to comment.